Compare commits
98331abbd5...main (22 commits)

| SHA1 |
|------|
| f0699436c5 |
| af53111928 |
| b8654aa31f |
| be5c84bcff |
| 19fc9f3289 |
| 036e12349d |
| e0931daece |
| e55ec42ae5 |
| 189a0fad34 |
| d2b4610df9 |
| 0e78e5b06b |
| 38644594d2 |
| ef6e5ac2be |
| 33db5aeb10 |
| a2f750269d |
| a54eb224e0 |
| 006b4ee5d7 |
| b861ff055f |
| 512f988dd0 |
| 6e7b824b64 |
| 1050df780d |
| 7c3571b47d |
.claude/settings.local.json

```diff
@@ -2,7 +2,22 @@
   "permissions": {
     "allow": [
       "Bash(find:*)",
       "Bash(ruff:*)",
       "Bash(pytest:*)",
       "Bash(git status:*)",
       "Bash(git diff:*)",
       "Bash(git log:*)",
       "Bash(git branch:*)",
       "Bash(git add:*)",
       "Bash(git commit:*)",
       "Bash(git checkout:*)",
       "Bash(git merge:*)",
       "Bash(git tag:*)",
       "Bash(git show:*)",
       "Bash(docker:*)",
       "Bash(docker-compose:*)",
       "WebSearch"
-    ]
+    ],
+    "defaultMode": "bypassPermissions"
   }
 }
```
.env.example (new file, 35 lines)

```bash
# Smart Support -- Docker Compose environment variables
# Copy to .env and fill in your values

# PostgreSQL password (used by both postgres and backend services)
POSTGRES_PASSWORD=dev_password

# LLM provider: anthropic | openai | azure_openai | google
LLM_PROVIDER=anthropic
LLM_MODEL=claude-sonnet-4-6

# API keys (provide the one matching LLM_PROVIDER)
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
GOOGLE_API_KEY=

# Azure OpenAI (required when LLM_PROVIDER=azure_openai)
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=
AZURE_OPENAI_DEPLOYMENT=
AZURE_OPENAI_API_VERSION=2024-12-01-preview

# Optional: webhook URL for escalation notifications
WEBHOOK_URL=

# Session and interrupt TTL in minutes
SESSION_TTL_MINUTES=30
INTERRUPT_TTL_MINUTES=30

# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
# Leave empty to disable authentication (dev mode)
ADMIN_API_KEY=

# Optional: load a named agent template instead of agents.yaml
# Available templates: ecommerce, saas, generic
TEMPLATE_NAME=
```
CLAUDE.md (65 changed lines)

````diff
@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
 # - If any test fails, fix it before starting the new phase
 
 # 3. Create checkpoint to snapshot the starting state
-/everything-claude-code:checkpoint create [phase name]
+/ecc:checkpoint create "phase-name"
 
 # 4. Create the phase branch
 git checkout main
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
 3. Identify all tasks, acceptance criteria, and dependencies for this phase
 4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
 
-### Step 2: Develop Using Orchestrate Skill
+### Step 2: Develop Using ECC Skills
 
-Route to the correct orchestration mode based on work type:
+Route to the correct skill based on work type:
 
-| Work Type | Skill Command |
-|-----------|---------------|
-| New feature | `/everything-claude-code:orchestrate feature` |
-| Bug fix | `/everything-claude-code:orchestrate bugfix` |
-| Refactor | `/everything-claude-code:orchestrate refactor` |
+| Work Type | Skill Command | What It Does |
+|-----------|---------------|--------------|
+| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
+| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
+| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
+| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
+| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
 
 ALWAYS use the appropriate orchestrate skill. Never develop without it.
 
-A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
+A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
 
 ```
-# Within Phase 5:
-/everything-claude-code:orchestrate feature   # for demo script
-/everything-claude-code:orchestrate bugfix    # for error handling fixes
-/everything-claude-code:orchestrate refactor  # for code cleanup
+# Within a phase:
+/ecc:feature-dev "demo script"           # for new features
+/ecc:tdd                                 # for bug fixes (write failing test, then fix)
+/ecc:plan "consolidate error handling"   # for refactors (plan first, then TDD)
 ```
+
+For full multi-phase autonomous execution, use GSD:
+
+```
+/gsd:autonomous        # execute all remaining phases
+/gsd:execute-phase 6   # execute a specific phase
+```
 
 ### Step 3: Module Independence (CRITICAL)
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
 
 ```
 # 1. Run the verification skill -- must pass
-/everything-claude-code:verify
+/ecc:verify
 
 # 2. Verify the checkpoint -- validates all phase deliverables
-/everything-claude-code:checkpoint verify [phase name]
+/ecc:checkpoint verify "phase-name"
 ```
 
 The checkpoint verify validates:
@@ -222,11 +229,11 @@ git push origin main --tags
 All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
 
 A checkpoint includes:
-- `/everything-claude-code:checkpoint create` at phase start
-- `/everything-claude-code:checkpoint verify` at phase end
+- `/ecc:checkpoint create` at phase start
+- `/ecc:checkpoint verify` at phase end
 - All tests passing (80%+ coverage)
 - Phase dev log written and linked
-- `/everything-claude-code:verify` passed
+- `/ecc:verify` passed
 - Git tag `checkpoint/phase-{N}` created
 - Phase marked COMPLETED in four locations
 - Branch merged to main
@@ -238,10 +245,10 @@ A checkpoint includes:
 | Phase | Branch | Focus | Status |
 |-------|--------|-------|--------|
 | 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | COMPLETED (2026-03-30) |
-| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
-| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
-| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
-| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | NOT STARTED |
+| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | COMPLETED (2026-03-30) |
+| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | COMPLETED (2026-03-30) |
+| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | COMPLETED (2026-03-31) |
+| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | COMPLETED (2026-03-31) |
 
 Status values: `NOT STARTED` -> `IN PROGRESS` -> `COMPLETED (YYYY-MM-DD)`
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
 
 ### Hooks (ECC Plugin -- No Custom Hooks)
 
-All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
+All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
 
 | ECC Hook | Type | What It Does |
 |----------|------|-------------|
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
 - Architecture doc: `docs/ARCHITECTURE.md`
 - Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
 - Test command: `pytest --cov=app --cov-report=term-missing`
-- **Phase start:** `/everything-claude-code:checkpoint create [phase name]`
-- **Phase end:** `/everything-claude-code:checkpoint verify [phase name]`
-- Verify command: `/everything-claude-code:verify`
-- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
+- **Phase start:** `/ecc:checkpoint create "phase-name"`
+- **Phase end:** `/ecc:checkpoint verify "phase-name"`
+- Verify command: `/ecc:verify`
+- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`
````
README.md (269 changed lines; the previous README was in Chinese -- removed lines below are translated, and lines whose English wording matched the replacement are shown once as context)

````diff
@@ -1,159 +1,174 @@
 # Smart Support
 
 AI customer support action layer. Paste your API spec, get an AI agent that executes real actions.
 
 ## The Problem
 
 Existing support tools (Zendesk, Intercom, Ada) answer FAQs well but automation
 rates stall at 20-30%. The remaining 70% of tickets require agents to manually
 log into internal systems to look up orders, cancel orders, issue coupons.
 
 Smart Support fills that gap as the "action layer" -- it does not replace your
 existing support platform, it enables AI to directly call your internal systems.
 
 ## How It Works
 
 ```
 User message -> Chat UI -> FastAPI WebSocket -> LangGraph Supervisor -> Specialist Agent -> MCP Tools -> Your systems
                                                       |                       |
                                                Agent Registry           interrupt()
                                                (YAML config)          (human approval)
                                                       |
                                                 PostgresSaver
                                             (session persistence)
 ```
 
 1. User sends a message in the chat UI.
 2. LangGraph Supervisor classifies intent and routes to the right agent.
 3. Agent calls your internal systems via MCP tools.
 4. Write operations trigger a human-in-the-loop approval gate.
 5. All operations are logged with full replay and analytics.
 
 ## Key Features
 
-- **Multi-agent collaboration** -- different operations are handled by different agents, each with its own permission boundary and toolset
-- **Plug and play** -- paste an OpenAPI spec URL to auto-generate MCP tools and agent configs
-- **Human confirmation** -- all write operations (cancel, refund, modify) need human approval; reads execute directly
-- **Session context** -- multi-turn conversations; agents resolve references like "cancel that order"
-- **Real-time streaming** -- bidirectional WebSocket, token-by-token streaming
-- **Conversation replay** -- step through agent decisions, tool calls, and results
-- **Analytics** -- resolution rate, agent usage, escalation rate, cost per conversation
-- **YAML-driven config** -- agent definitions, personas, and vertical templates all configured via YAML
+- **Multi-agent routing** -- each operation goes to a specialist agent with its own tools and permissions
+- **Zero-config import** -- paste an OpenAPI 3.0 URL, agents are generated automatically
+- **Human-in-the-loop** -- all write operations (cancel, refund, modify) require approval; reads execute immediately
+- **Session context** -- multi-turn conversation with persistent state across reconnects
+- **Real-time streaming** -- WebSocket token streaming with live tool call visibility
+- **Conversation replay** -- step-by-step audit trail of every agent decision
+- **Analytics dashboard** -- resolution rate, agent usage, escalation rate, cost per conversation
+- **YAML-driven config** -- agents, personas, and vertical templates in a single file
 
 ## Tech Stack
 
 | Component | Technology |
 |-----------|-----------|
 | Backend | Python 3.11+, FastAPI |
-| Agent orchestration | LangGraph v1.1, langgraph-supervisor |
-| Tool integration | langchain-mcp-adapters, @tool |
-| State persistence | PostgreSQL + langgraph-checkpoint-postgres |
-| LLM | Claude Sonnet 4.6 (switchable to OpenAI, Google, etc.) |
-| Frontend | React |
+| Agent orchestration | LangGraph 1.x, langgraph-supervisor |
+| Session state | PostgreSQL 16 + langgraph-checkpoint-postgres |
+| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Azure OpenAI, Google) |
+| Frontend | React 19, TypeScript, Vite |
+| Testing | pytest (backend), vitest + happy-dom (frontend) |
 | Deployment | Docker Compose |
 
+## Quick Start
+
+```bash
+git clone <repo-url>
+cd smart-support
+
+# Configure your LLM API key
+cp .env.example .env
+# Edit .env: set LLM_PROVIDER and the corresponding API key
+#   anthropic    -> ANTHROPIC_API_KEY
+#   openai       -> OPENAI_API_KEY
+#   azure_openai -> AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_DEPLOYMENT
+#   google       -> GOOGLE_API_KEY
+
+# Start all services
+docker compose up -d
+
+# Open the app
+open http://localhost
+```
+
+### Local Development
+
+```bash
+# Start only PostgreSQL via Docker (exposed on port 5433)
+docker compose up postgres -d
+
+# Backend (in one terminal)
+cd backend
+pip install -e ".[dev]"
+uvicorn app.main:app --host 0.0.0.0 --port 8001 --reload
+
+# Frontend (in another terminal)
+cd frontend
+npm install
+npm run dev   # http://localhost:5173 (proxies /api and /ws to :8001)
+```
+
+See [Deployment Guide](docs/deployment.md) for production setup, HTTPS, and scaling.
+
 ## Project Structure
 
 ```
 smart-support/
 ├── backend/
 │   ├── app/
-│   │   ├── main.py        # FastAPI + WebSocket entry point
-│   │   ├── graph.py       # LangGraph Supervisor configuration
-│   │   ├── agents/        # Agent definitions + tools
-│   │   ├── registry.py    # YAML agent registry loader
-│   │   ├── openapi/       # OpenAPI parsing + MCP server generation
-│   │   ├── replay/        # Conversation replay API
-│   │   ├── analytics/     # Analytics queries + API
-│   │   └── callbacks.py   # Token usage accounting
-│   ├── agents.yaml        # Agent registry configuration
-│   ├── templates/         # Vertical industry templates
-│   └── tests/
-├── frontend/              # React chat UI + replay + dashboard
-├── docker-compose.yml     # PostgreSQL + app
-└── pyproject.toml
+│   │   ├── main.py            # FastAPI + WebSocket entry point
+│   │   ├── graph.py           # LangGraph Supervisor construction
+│   │   ├── graph_context.py   # Typed wrapper for graph + classifier + registry
+│   │   ├── ws_handler.py      # WebSocket message dispatch + rate limiting
+│   │   ├── ws_context.py      # WebSocket dependency bundle
+│   │   ├── auth.py            # API key authentication middleware
+│   │   ├── api_utils.py       # Shared API response helpers
+│   │   ├── safety.py          # Confirmation rules + MCP error taxonomy
+│   │   ├── agents/            # Agent definitions and tools
+│   │   ├── registry.py        # YAML agent registry loader
+│   │   ├── openapi/           # OpenAPI parser, classifier, and review API
+│   │   ├── replay/            # Conversation replay API
+│   │   └── analytics/         # Analytics queries and API
+│   ├── agents.yaml            # Agent registry configuration
+│   ├── templates/             # Vertical industry templates
+│   └── tests/                 # Unit, integration, and E2E tests
+├── frontend/
+│   ├── src/
+│   │   ├── pages/             # Chat, Replay, Dashboard, Review pages
+│   │   ├── components/        # NavBar, Layout, MetricCard, ReplayTimeline
+│   │   ├── hooks/             # useWebSocket with reconnect support
+│   │   └── api.ts             # Typed API client
+│   └── Dockerfile             # Multi-stage nginx build
+├── docs/                      # Architecture, deployment, guides
+├── docker-compose.yml         # Full-stack compose
+└── .env.example               # Environment variable template
 ```
 
-## Quick Start
+## API Endpoints
+
+| Method | Path | Auth | Description |
+|--------|------|------|-------------|
+| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
+| GET | `/api/health` | No | Health check |
+| GET | `/api/conversations` | API Key | List conversations (paginated) |
+| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
+| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
+| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
+| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
+| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
+| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
+| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
+
+Authentication is controlled by the `ADMIN_API_KEY` environment variable.
+API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
+
+## Running Tests
 
 ```bash
-# Start PostgreSQL and the app
-docker compose up
-
-# Open the chat UI
-open http://localhost:8000
+# Backend (516 tests, 94% coverage)
+cd backend
+pytest --cov=app --cov-report=term-missing
+
+# Frontend (23 tests, vitest + happy-dom)
+cd frontend
+npm test
 ```
 
-## Agent Configuration Example
+Backend coverage is enforced at 80%+.
 
-```yaml
-# agents.yaml
-agents:
-  - name: order_lookup
-    description: Look up order status and tracking info
-    permission: read
-    personality:
-      tone: professional
-      greeting: "Hello, let me help you look up your order."
-    tools:
-      - get_order_status
-      - get_tracking_info
-
-  - name: order_actions
-    description: Cancel and modify orders
-    permission: write  # triggers human confirmation
-    personality:
-      tone: careful
-      greeting: "I can help with order changes; every action will be confirmed with you first."
-    tools:
-      - cancel_order
-      - modify_order
-
-  - name: discount
-    description: Issue coupons and discount codes
-    permission: write
-    tools:
-      - apply_discount
-      - generate_coupon
-```
-
-## OpenAPI Auto-Integration
-
-No need to hand-write MCP connectors. Paste your API spec URL:
-
-1. The framework parses the OpenAPI 3.0 spec
-2. An LLM classifies each endpoint (read/write, customer parameters, agent grouping)
-3. An operator reviews the classifications
-4. The MCP server + agent YAML config are generated automatically
-5. New tools are available immediately
-
-## Security Design
-
-- **Human confirmation** -- all write operations require customer or operator approval
-- **SSRF protection** -- OpenAPI URL import blocks internal addresses and DNS-rebinding attacks
-- **Operation audit** -- every operation records agent, parameters, result, and timestamp
-- **Permission isolation** -- each agent can only access its configured toolset
-- **Interrupt timeout** -- operations unconfirmed after 30 minutes are auto-cancelled to prevent stale approvals
-
-## Development Phases
-
-| Phase | Weeks | Scope |
-|-------|-------|-------|
-| Phase 1 | Weeks 1-3 | Core framework: Chat UI + Supervisor + agent registry + interrupt flow |
-| Phase 2 | Weeks 3-4 | Multi-agent routing + webhook escalation + vertical templates |
-| Phase 3 | Weeks 4-6 | OpenAPI auto-discovery + MCP server generation + SSRF protection |
-| Phase 4 | Weeks 6-7 | Conversation replay + analytics dashboard |
-
-## Target Users
-
-Customer-experience leads at mid-sized e-commerce companies (500-5000 orders/day, 5-20 support agents).
-
-Their pain point: agents bounce between Zendesk and the Shopify admin, running lookups and actions by hand. Smart Support lets AI perform these operations directly, with humans approving only the critical steps.
-
-## Related Documents
-
-- [Design Doc](design-doc.md) -- problem definition, constraints, solution choices
-- [CEO Plan](ceo-plan.md) -- product vision, scope decisions
-- [Engineering Review Plan](eng-review-plan.md) -- architecture decisions, test strategy, failure modes
-- [Test Plan](eng-review-test-plan.md) -- test paths, edge cases, E2E flows
-- [TODOs](TODOS.md) -- work deferred to later phases
+## Documentation
+
+| Document | Description |
+|----------|-------------|
+| [Architecture](docs/ARCHITECTURE.md) | System design, component diagram, data flow, ADRs |
+| [Development Plan](docs/DEVELOPMENT-PLAN.md) | Phase breakdown, task checklists, and status |
+| [Agent Config Guide](docs/agent-config-guide.md) | agents.yaml format, fields, templates, routing logic |
+| [OpenAPI Import Guide](docs/openapi-import-guide.md) | Auto-discovery workflow, REST API, SSRF protection |
+| [Deployment Guide](docs/deployment.md) | Docker, local dev, production, HTTPS, backups, scaling |
+| [Demo Script](docs/demo-script.md) | Step-by-step live demo walkthrough (5 scenes) |
+| [UX Design System](docs/ux_design_system.md) | Color palette, typography, component patterns, CSS tokens |
 
 ## License
````
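The new README's API table accepts the admin key on the WebSocket endpoint as a `?token=` query parameter. A minimal client sketch, assuming the `websockets` package and the default `WS_PORT=8000`; the JSON message shape below is a placeholder -- the real protocol lives in `ws_handler.py`, which is not part of this diff:

```python
# Minimal WebSocket client sketch -- the message schema here is hypothetical;
# the actual protocol is defined in backend/app/ws_handler.py (not shown in this diff).
import asyncio
import json

import websockets  # assumed dependency: pip install websockets


async def chat(token: str) -> None:
    # Token is checked by verify_ws_token (skipped when ADMIN_API_KEY is unset).
    url = f"ws://localhost:8000/ws?token={token}"
    async with websockets.connect(url) as ws:
        # Hypothetical message shape for one user turn.
        await ws.send(json.dumps({"type": "message", "content": "Where is order 1042?"}))
        async for raw in ws:
            print(json.loads(raw))  # streamed tokens / tool events; shape not specified here


asyncio.run(chat(token="dev-key"))
```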
backend/.env.example

```diff
@@ -1,19 +1,34 @@
-# Database
+# Smart Support Backend -- environment variables
+# Copy to .env and fill in your values
 
+# Required: PostgreSQL connection string
 DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
 
-# LLM Provider: anthropic | openai | google
+# Required: LLM provider configuration
+# provider: anthropic | openai | google
 LLM_PROVIDER=anthropic
 LLM_MODEL=claude-sonnet-4-6
 
-# API Keys (set the one matching your LLM_PROVIDER)
+# API keys -- provide the one matching LLM_PROVIDER
 ANTHROPIC_API_KEY=
 OPENAI_API_KEY=
 GOOGLE_API_KEY=
 
-# Session
+# Optional: webhook endpoint for escalation notifications
+# The backend will POST a JSON payload when a conversation is escalated.
+WEBHOOK_URL=
+WEBHOOK_TIMEOUT_SECONDS=10
+WEBHOOK_MAX_RETRIES=3
+
+# Session management
 SESSION_TTL_MINUTES=30
 INTERRUPT_TTL_MINUTES=30
 
-# Server
+# Optional: load a named agent template instead of agents.yaml
+# Leave blank to use the default agents.yaml in the backend directory.
+# Available templates: ecommerce, saas, generic
+TEMPLATE_NAME=
+
+# Server binding
 WS_HOST=0.0.0.0
 WS_PORT=8000
```
backend/agents.yaml

```diff
@@ -20,6 +20,17 @@ agents:
     tools:
       - cancel_order
 
+  - name: discount
+    description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
+    permission: write
+    personality:
+      tone: "generous and accommodating"
+      greeting: "I can help you with discounts and coupons!"
+      escalation_message: "Let me connect you with our promotions team."
+    tools:
+      - apply_discount
+      - generate_coupon
+
   - name: fallback
     description: "Handles general questions, unclear requests, and conversations that don't match other agents."
     permission: read
```
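The `tools:` entries in this YAML are plain names that the backend resolves against a tool map (see the `_TOOL_MAP` change later in this diff). A minimal sketch of that lookup, assuming PyYAML; `load_agents` is an illustrative stand-in for the real loader in `app/registry.py`, which is not shown here:

```python
# Illustrative sketch only -- the real loader lives in backend/app/registry.py.
from __future__ import annotations

import yaml  # assumed dependency: pyyaml

# Stand-in for the _TOOL_MAP in app/agents: tool name -> @tool-decorated callable.
_TOOL_MAP: dict[str, object] = {}


def load_agents(path: str = "agents.yaml") -> list[dict]:
    """Parse agents.yaml and resolve each agent's tool names to tool objects."""
    with open(path) as f:
        config = yaml.safe_load(f)
    agents = []
    for spec in config["agents"]:
        # Fail fast on tool names that don't exist in the registry.
        unknown = [t for t in spec.get("tools", []) if t not in _TOOL_MAP]
        if unknown:
            raise ValueError(f"Agent {spec['name']!r} references unknown tools: {unknown}")
        spec["resolved_tools"] = [_TOOL_MAP[t] for t in spec.get("tools", [])]
        agents.append(spec)
    return agents
```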
backend/alembic.ini (new file, 149 lines)

```ini
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .


# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url =


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
```
backend/alembic/README (new file, 1 line)

```text
Generic single-database configuration.
```
backend/alembic/env.py (new file, 67 lines)

```python
"""Alembic environment configuration for smart-support."""

from __future__ import annotations

import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config, pool

from alembic import context

config = context.config

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# No SQLAlchemy ORM models -- we use raw DDL migrations
target_metadata = None


def _get_url() -> str:
    """Read DATABASE_URL from environment, falling back to alembic.ini."""
    return os.environ.get("DATABASE_URL", "") or config.get_main_option(
        "sqlalchemy.url", ""
    )


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the context with just a URL so that an Engine
    is not required.
    """
    url = _get_url()
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode with a live database connection."""
    configuration = config.get_section(config.config_ini_section, {})
    configuration["sqlalchemy.url"] = _get_url()

    connectable = engine_from_config(
        configuration,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
```
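Because `env.py` resolves the URL from `DATABASE_URL` before falling back to `alembic.ini`, migrations can be driven without editing the ini file. A sketch using Alembic's command API, assuming it runs from the `backend/` directory (that working-directory detail is my assumption):

```python
# Programmatic equivalent of running `alembic upgrade head` in backend/.
import os

from alembic import command
from alembic.config import Config

# DSN matches the default from backend/.env.example; adjust for your setup.
os.environ["DATABASE_URL"] = (
    "postgresql://smart_support:dev_password@localhost:5432/smart_support"
)

cfg = Config("alembic.ini")   # picks up script_location = %(here)s/alembic
command.upgrade(cfg, "head")  # env.py resolves the URL via _get_url()
```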
backend/alembic/script.py.mako (new file, 28 lines)

```mako
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
```
backend/alembic/versions/001_initial_schema.py (new file, 92 lines)

```python
"""Initial schema -- all application tables.

Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-04-06
"""

from __future__ import annotations

from alembic import op

revision: str = "a1b2c3d4e5f6"
down_revision: str | None = None
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None


def upgrade() -> None:
    op.execute(
        """
        CREATE TABLE IF NOT EXISTS conversations (
            thread_id TEXT PRIMARY KEY,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            total_tokens INTEGER NOT NULL DEFAULT 0,
            total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
            status TEXT NOT NULL DEFAULT 'active'
        )
        """
    )

    op.execute(
        """
        CREATE TABLE IF NOT EXISTS active_interrupts (
            interrupt_id TEXT PRIMARY KEY,
            thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
            action TEXT NOT NULL,
            params JSONB NOT NULL DEFAULT '{}',
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            resolved_at TIMESTAMPTZ,
            resolution TEXT
        )
        """
    )

    op.execute(
        """
        CREATE TABLE IF NOT EXISTS sessions (
            thread_id TEXT PRIMARY KEY,
            last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    )

    op.execute(
        """
        CREATE TABLE IF NOT EXISTS analytics_events (
            id BIGSERIAL PRIMARY KEY,
            thread_id TEXT NOT NULL,
            event_type TEXT NOT NULL,
            agent_name TEXT,
            tool_name TEXT,
            tokens_used INTEGER NOT NULL DEFAULT 0,
            cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
            duration_ms INTEGER,
            success BOOLEAN,
            error_message TEXT,
            metadata JSONB NOT NULL DEFAULT '{}',
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    )

    # Migration columns added in Phase 4
    op.execute(
        """
        ALTER TABLE conversations
            ADD COLUMN IF NOT EXISTS resolution_type TEXT,
            ADD COLUMN IF NOT EXISTS agents_used TEXT[],
            ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
            ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
        """
    )


def downgrade() -> None:
    op.execute("DROP TABLE IF EXISTS analytics_events")
    op.execute("DROP TABLE IF EXISTS sessions")
    op.execute("DROP TABLE IF EXISTS active_interrupts")
    op.execute("DROP TABLE IF EXISTS conversations")
```
```diff
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from langchain_core.tools import BaseTool
 
+from app.agents.discount import apply_discount, generate_coupon
 from app.agents.fallback import fallback_respond
 from app.agents.order_actions import cancel_order
 from app.agents.order_lookup import get_order_status, get_tracking_info
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
     "get_tracking_info": get_tracking_info,
     "cancel_order": cancel_order,
     "fallback_respond": fallback_respond,
+    "apply_discount": apply_discount,
+    "generate_coupon": generate_coupon,
 }
```
backend/app/agents/discount.py (new file, 79 lines)

```python
"""Discount agent tools -- apply discounts and generate coupons."""

from __future__ import annotations

import uuid

from langchain_core.tools import tool
from langgraph.types import interrupt


@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
    """Apply a discount to an order. Requires human approval before execution."""
    if discount_percent < 1 or discount_percent > 100:
        return {
            "status": "error",
            "order_id": order_id,
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }

    response = interrupt(
        {
            "action": "apply_discount",
            "order_id": order_id,
            "discount_percent": discount_percent,
            "message": (
                f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
            ),
        }
    )

    if isinstance(response, bool):
        approved = response
    elif isinstance(response, dict):
        approved = response.get("approved", False)
    else:
        approved = bool(response)

    if approved:
        return {
            "status": "applied",
            "order_id": order_id,
            "discount_percent": discount_percent,
            "message": (
                f"{discount_percent}% discount applied to order {order_id}."
            ),
        }
    return {
        "status": "declined",
        "order_id": order_id,
        "message": f"Discount for order {order_id} was declined.",
    }


@tool
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
    """Generate a coupon code with the specified discount percentage."""
    if discount_percent < 1 or discount_percent > 100:
        return {
            "status": "error",
            "message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
        }
    if expiry_days < 1:
        return {
            "status": "error",
            "message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
        }

    coupon_code = f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
    return {
        "status": "generated",
        "coupon_code": coupon_code,
        "discount_percent": discount_percent,
        "expiry_days": expiry_days,
        "message": (
            f"Coupon {coupon_code} generated: {discount_percent}% off, "
            f"valid for {expiry_days} days."
        ),
    }
```
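`apply_discount` pauses at `interrupt(...)` and only resumes once the graph is re-invoked with a resume value; the tool accepts either a bare bool or a `{"approved": ...}` dict. A sketch of the approval round-trip, assuming `graph` is a compiled LangGraph app with a checkpointer (the variable names and input shape are illustrative, not taken from this diff):

```python
# Approval round-trip sketch -- `graph` is a compiled LangGraph app with a
# checkpointer attached; names here are illustrative.
from langgraph.types import Command

config = {"configurable": {"thread_id": "demo-thread"}}

# First invocation pauses inside apply_discount at interrupt(...).
result = graph.invoke(
    {"messages": [("user", "Give me 20% off order 1042")]}, config=config
)

# The pending interrupt payload (action, order_id, message) is surfaced to a human.
# Resuming with a dict matches the tool's `response.get("approved", False)` branch.
result = graph.invoke(Command(resume={"approved": True}), config=config)
```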
backend/app/agents/fallback.py

```diff
@@ -1,4 +1,4 @@
-"""Fallback agent tools -- handles unmatched intents."""
+"""Fallback agent tools -- handles unmatched intents and clarification requests."""
 
 from __future__ import annotations
 
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
         "Here's what I can do:\n"
         "- Check order status (e.g., 'What is the status of order 1042?')\n"
         "- Get tracking information (e.g., 'Track order 1042')\n"
-        "- Cancel an order (e.g., 'Cancel order 1042')\n\n"
+        "- Cancel an order (e.g., 'Cancel order 1042')\n"
+        "- Apply discounts or generate coupons\n\n"
         "Could you please rephrase your request?"
     )
```
backend/app/analytics/__init__.py (new file, 3 lines)

```python
"""Analytics module -- event recording and dashboard queries."""

from __future__ import annotations
```
backend/app/analytics/api.py (new file, 60 lines)

```python
"""Analytics API router -- dashboard metrics endpoint."""

from __future__ import annotations

import re
from dataclasses import asdict
from typing import TYPE_CHECKING

from fastapi import APIRouter, Depends, HTTPException, Query, Request

from app.analytics.queries import get_analytics
from app.api_utils import envelope
from app.auth import require_admin_api_key

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

router = APIRouter(
    prefix="/api/v1/analytics",
    tags=["analytics"],
    dependencies=[Depends(require_admin_api_key)],
)

_RANGE_PATTERN = re.compile(r"^(\d+)d$")
_DEFAULT_RANGE = "7d"
_MAX_RANGE_DAYS = 365


async def _get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state."""
    return request.app.state.pool


def _parse_range(range_str: str) -> int:
    """Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
    match = _RANGE_PATTERN.match(range_str)
    if not match:
        raise HTTPException(
            status_code=400,
            detail="Invalid range format. Expected: '<N>d' e.g. '7d', '30d'.",
        )
    days = int(match.group(1))
    if days < 1 or days > _MAX_RANGE_DAYS:
        raise HTTPException(
            status_code=400,
            detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.",
        )
    return days


@router.get("")
async def analytics(
    request: Request,
    range: str = Query(default=_DEFAULT_RANGE, alias="range"),  # noqa: A002
) -> dict:
    """Return aggregated analytics metrics for the given time range."""
    range_days = _parse_range(range)
    pool = await _get_pool(request)
    result = await get_analytics(pool, range_days=range_days)
    return envelope(asdict(result))
```
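The router nests under `/api/v1/analytics` and enforces `X-API-Key` through `require_admin_api_key`. A client sketch, assuming `httpx` and a backend on port 8000 (port and key values are placeholders):

```python
# Query the analytics endpoint -- port and key values are assumptions for this sketch.
import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/analytics",
    params={"range": "30d"},           # validated by _parse_range: '<N>d', 1..365 days
    headers={"X-API-Key": "dev-key"},  # header can be omitted if ADMIN_API_KEY is unset
)
resp.raise_for_status()
payload = resp.json()  # envelope shape: {"success": ..., "data": ..., "error": ...}
print(payload["data"]["resolution_rate"])
```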
backend/app/analytics/event_recorder.py (new file, 97 lines)

```python
"""Analytics event recorder -- Protocol and implementations."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from psycopg.types.json import Json

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

_INSERT_SQL = """
    INSERT INTO analytics_events
        (thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd,
         duration_ms, success, error_message, metadata)
    VALUES
        (%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
         %(tokens_used)s, %(cost_usd)s, %(duration_ms)s, %(success)s,
         %(error_message)s, %(metadata)s)
"""


@runtime_checkable
class AnalyticsRecorder(Protocol):
    """Protocol for recording analytics events."""

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None: ...


class NoOpAnalyticsRecorder:
    """No-op implementation for testing or when the DB is unavailable."""

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Do nothing."""


class PostgresAnalyticsRecorder:
    """Postgres-backed analytics recorder -- INSERTs into analytics_events."""

    def __init__(self, pool: AsyncConnectionPool) -> None:
        self._pool = pool

    async def record(
        self,
        *,
        thread_id: str,
        event_type: str,
        agent_name: str | None = None,
        tool_name: str | None = None,
        tokens_used: int = 0,
        cost_usd: float = 0.0,
        duration_ms: int | None = None,
        success: bool | None = None,
        error_message: str | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Insert one analytics event row."""
        params: dict[str, Any] = {
            "thread_id": thread_id,
            "event_type": event_type,
            "agent_name": agent_name,
            "tool_name": tool_name,
            "tokens_used": tokens_used,
            "cost_usd": cost_usd,
            "duration_ms": duration_ms,
            "success": success,
            "error_message": error_message,
            "metadata": Json(metadata or {}),
        }
        async with self._pool.connection() as conn:
            await conn.execute(_INSERT_SQL, params)
```
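A minimal wiring sketch for the Postgres recorder, assuming a `psycopg_pool` async pool; the DSN and event values are illustrative:

```python
# Record one analytics event -- DSN and event values are illustrative.
import asyncio

from psycopg_pool import AsyncConnectionPool

from app.analytics.event_recorder import PostgresAnalyticsRecorder


async def main() -> None:
    async with AsyncConnectionPool(
        "postgresql://smart_support:dev_password@localhost:5432/smart_support"
    ) as pool:
        recorder = PostgresAnalyticsRecorder(pool)
        await recorder.record(
            thread_id="demo-thread",
            event_type="tool_call",
            agent_name="order_lookup",
            tool_name="get_order_status",
            duration_ms=42,
            success=True,
        )


asyncio.run(main())
```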
backend/app/analytics/models.py (new file, 38 lines)

```python
"""Value objects for analytics dashboard."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class AgentUsage:
    """Agent usage statistics within a time range."""

    agent: str
    count: int
    percentage: float


@dataclass(frozen=True)
class InterruptStats:
    """Interrupt approval/rejection statistics within a time range."""

    total: int = 0
    approved: int = 0
    rejected: int = 0
    expired: int = 0


@dataclass(frozen=True)
class AnalyticsResult:
    """Full analytics result for a given time range."""

    range: str
    total_conversations: int
    resolution_rate: float
    escalation_rate: float
    avg_turns_per_conversation: float
    avg_cost_per_conversation_usd: float
    agent_usage: tuple[AgentUsage, ...]
    interrupt_stats: InterruptStats
```
backend/app/analytics/queries.py (new file, 184 lines)

```python
"""Analytics query functions -- all async, take pool + range_days."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

_RESOLUTION_RATE_SQL = """
    SELECT
        CASE WHEN COUNT(*) = 0 THEN 0.0
             ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
        END AS rate
    FROM conversations
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""

_ESCALATION_RATE_SQL = """
    SELECT
        CASE WHEN COUNT(*) = 0 THEN 0.0
             ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
        END AS rate
    FROM conversations
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""

_TOTAL_CONVERSATIONS_SQL = """
    SELECT COUNT(*) AS total
    FROM conversations
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""

_AVG_TURNS_SQL = """
    SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
    FROM conversations
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""

_COST_PER_CONVERSATION_SQL = """
    SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
    FROM conversations
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""

_AGENT_USAGE_SQL = """
    SELECT
        agent,
        COUNT(*) AS count,
        ROUND(COUNT(*) * 100.0 / NULLIF(SUM(COUNT(*)) OVER (), 0), 2) AS percentage
    FROM (
        SELECT UNNEST(agents_used) AS agent
        FROM conversations
        WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
          AND agents_used IS NOT NULL
    ) sub
    GROUP BY agent
    ORDER BY count DESC
"""

_INTERRUPT_STATS_SQL = """
    SELECT
        COUNT(*) FILTER (WHERE event_type = 'interrupt') AS total,
        COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = TRUE) AS approved,
        COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = FALSE
                         AND error_message IS NULL) AS rejected,
        COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
    FROM analytics_events
    WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""


async def resolution_rate(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the fraction of resolved conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_RESOLUTION_RATE_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("rate") or 0.0)


async def escalation_rate(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the fraction of escalated conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_ESCALATION_RATE_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("rate") or 0.0)


async def _total_conversations(pool: AsyncConnectionPool, range_days: int) -> int:
    """Return the total number of conversations in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_TOTAL_CONVERSATIONS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0
    return int(row.get("total") or 0)


async def _avg_turns(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the average turn count per conversation in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_AVG_TURNS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("avg_turns") or 0.0)


async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> float:
    """Return the average cost per conversation in the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_COST_PER_CONVERSATION_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return 0.0
    return float(row.get("avg_cost") or 0.0)


async def agent_usage(
    pool: AsyncConnectionPool, range_days: int
) -> tuple[AgentUsage, ...]:
    """Return per-agent usage statistics for the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
        rows = await cursor.fetchall()
    if not rows:
        return ()
    return tuple(
        AgentUsage(
            agent=row["agent"],
            count=int(row["count"]),
            percentage=float(row["percentage"]),
        )
        for row in rows
    )


async def interrupt_stats(
    pool: AsyncConnectionPool, range_days: int
) -> InterruptStats:
    """Return interrupt approval/rejection statistics for the given range."""
    async with pool.connection() as conn:
        cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
        row = await cursor.fetchone()
    if not row:
        return InterruptStats()
    return InterruptStats(
        total=int(row.get("total") or 0),
        approved=int(row.get("approved") or 0),
        rejected=int(row.get("rejected") or 0),
        expired=int(row.get("expired") or 0),
    )


async def get_analytics(
    pool: AsyncConnectionPool, range_days: int
) -> AnalyticsResult:
    """Aggregate all analytics metrics into a single AnalyticsResult."""
    res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
        resolution_rate(pool, range_days),
        escalation_rate(pool, range_days),
        cost_per_conversation(pool, range_days),
        agent_usage(pool, range_days),
        interrupt_stats(pool, range_days),
        _total_conversations(pool, range_days),
        _avg_turns(pool, range_days),
    )
    return AnalyticsResult(
        range=f"{range_days}d",
        total_conversations=total,
        resolution_rate=res_rate,
        escalation_rate=esc_rate,
        avg_turns_per_conversation=avg_t,
        avg_cost_per_conversation_usd=cost,
        agent_usage=usage,
        interrupt_stats=i_stats,
    )
```
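These query functions read rows with `row.get(...)` and `row["agent"]`, which assumes the pool hands back dict rows rather than psycopg's default tuples. A sketch of a pool configuration that satisfies that assumption (the pool is presumably configured this way in `main.py`, which is not part of this diff):

```python
# The query functions above assume dict rows; psycopg returns tuples by default.
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

pool = AsyncConnectionPool(
    "postgresql://smart_support:dev_password@localhost:5432/smart_support",
    kwargs={"row_factory": dict_row},  # every pooled connection yields dict rows
    open=False,                        # open explicitly later with `await pool.open()`
)
```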
backend/app/api_utils.py (new file, 10 lines)

```python
"""Shared API response helpers."""

from __future__ import annotations

from typing import Any


def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    """Wrap API response data in a standard envelope format."""
    return {"success": success, "data": data, "error": error}
```
backend/app/auth.py (new file, 72 lines)

```python
"""API key authentication for admin endpoints and WebSocket connections."""

from __future__ import annotations

import secrets
from typing import Annotated

import structlog
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
from fastapi.security import APIKeyHeader

logger = structlog.get_logger()

_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)


def _get_admin_api_key(request: Request) -> str:
    """Retrieve the configured admin API key from app settings.

    Returns empty string if settings are not configured (test/dev mode).
    """
    settings = getattr(request.app.state, "settings", None)
    if settings is None:
        return ""
    key = getattr(settings, "admin_api_key", "")
    return key if isinstance(key, str) else ""


async def require_admin_api_key(
    request: Request,
    api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
) -> None:
    """Dependency that enforces API key authentication on admin endpoints.

    Skips validation when no admin_api_key is configured (dev mode).
    """
    expected = _get_admin_api_key(request)
    if not expected:
        return

    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing X-API-Key header",
        )
    if not secrets.compare_digest(api_key, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key",
        )


async def verify_ws_token(
    ws: WebSocket,
    token: str | None = Query(default=None),
) -> None:
    """Verify WebSocket connection token from query parameter.

    Skips validation when no admin_api_key is configured (dev mode).
    Usage: ws://host/ws?token=<api_key>
    """
    settings = ws.app.state.settings
    expected = settings.admin_api_key
    if not expected:
        return

    if token is None or not secrets.compare_digest(token, expected):
        await ws.close(code=4001, reason="Unauthorized")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or missing WebSocket token",
        )
```
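`require_admin_api_key` is a no-op until a key is configured, which makes the dev-mode behavior easy to verify. A test sketch with FastAPI's `TestClient`; the `SimpleNamespace` stand-in for `Settings` is my assumption, mirroring how the dependency reads `app.state.settings`:

```python
# Behavior sketch for require_admin_api_key -- SimpleNamespace stands in for Settings.
from types import SimpleNamespace

from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

from app.auth import require_admin_api_key

app = FastAPI()
app.state.settings = SimpleNamespace(admin_api_key="secret")


@app.get("/admin", dependencies=[Depends(require_admin_api_key)])
def admin() -> dict:
    return {"ok": True}


client = TestClient(app)
assert client.get("/admin").status_code == 401                                  # no header
assert client.get("/admin", headers={"X-API-Key": "wrong"}).status_code == 403  # bad key
assert client.get("/admin", headers={"X-API-Key": "secret"}).status_code == 200
```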
```diff
@@ -17,7 +17,7 @@ class Settings(BaseSettings):
 
     database_url: str
 
-    llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
+    llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
     llm_model: str = "claude-sonnet-4-6"
 
     session_ttl_minutes: int = 30
@@ -26,8 +26,22 @@ class Settings(BaseSettings):
     ws_host: str = "0.0.0.0"
     ws_port: int = 8000
 
     webhook_url: str = ""
+    webhook_timeout_seconds: int = 10
+    webhook_max_retries: int = 3
+
+    template_name: str = ""
+
+    log_format: str = "console"  # "console" for dev, "json" for production
+
+    admin_api_key: str = ""
 
     anthropic_api_key: str = ""
     openai_api_key: str = ""
+    azure_openai_api_key: str = ""
+    azure_openai_endpoint: str = ""
+    azure_openai_api_version: str = "2024-12-01-preview"
+    azure_openai_deployment: str = ""
     google_api_key: str = ""
 
     @model_validator(mode="after")
@@ -35,6 +49,7 @@ class Settings(BaseSettings):
         key_map = {
             "anthropic": self.anthropic_api_key,
             "openai": self.openai_api_key,
+            "azure_openai": self.azure_openai_api_key,
             "google": self.google_api_key,
         }
         key = key_map.get(self.llm_provider, "")
@@ -43,4 +58,13 @@ class Settings(BaseSettings):
                 f"API key for provider '{self.llm_provider}' is required. "
                 f"Set the corresponding environment variable."
             )
+        if self.llm_provider == "azure_openai":
+            if not self.azure_openai_endpoint:
+                raise ValueError(
+                    "AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
+                )
+            if not self.azure_openai_deployment:
+                raise ValueError(
+                    "AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
+                )
         return self
```
backend/app/conversation_tracker.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

_ENSURE_SQL = """
INSERT INTO conversations
    (thread_id, created_at, last_activity)
VALUES
    (%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""

_RECORD_TURN_SQL = """
UPDATE conversations
SET
    turn_count = turn_count + 1,
    agents_used = CASE
        WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
        THEN agents_used || ARRAY[%(agent_name)s]::text[]
        ELSE agents_used
    END,
    total_tokens = total_tokens + %(tokens)s,
    total_cost_usd = total_cost_usd + %(cost)s,
    last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""

_RESOLVE_SQL = """
UPDATE conversations
SET
    resolution_type = %(resolution_type)s,
    ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""


@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Protocol for tracking conversation lifecycle and metrics."""

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Create conversation row if it does not already exist."""
        ...

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count and update aggregated metrics."""
        ...

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Mark conversation as resolved with a resolution type."""
        ...


class NoOpConversationTracker:
    """No-op implementation -- used in tests or when DB is unavailable."""

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Do nothing."""

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Do nothing."""

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Do nothing."""


class PostgresConversationTracker:
    """Postgres-backed conversation tracker."""

    async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
        """Insert conversation row; do nothing if already exists (ON CONFLICT DO NOTHING)."""
        params = {"thread_id": thread_id}
        async with pool.connection() as conn:
            await conn.execute(_ENSURE_SQL, params)

    async def record_turn(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None:
        """Increment turn count, append agent if new, update token/cost totals."""
        params = {
            "thread_id": thread_id,
            "agent_name": agent_name,
            "tokens": tokens,
            "cost": cost,
        }
        async with pool.connection() as conn:
            await conn.execute(_RECORD_TURN_SQL, params)

    async def resolve(
        self,
        pool: AsyncConnectionPool,
        thread_id: str,
        resolution_type: str,
    ) -> None:
        """Set resolution_type and ended_at on the conversation row."""
        params = {
            "thread_id": thread_id,
            "resolution_type": resolution_type,
        }
        async with pool.connection() as conn:
            await conn.execute(_RESOLVE_SQL, params)
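A minimal call sequence for the Postgres tracker, assuming an existing psycopg AsyncConnectionPool; the agent name, token, and cost figures are placeholders:

# Sketch only -- pool is an existing AsyncConnectionPool; values are made up.
from psycopg_pool import AsyncConnectionPool

from app.conversation_tracker import PostgresConversationTracker

tracker = PostgresConversationTracker()


async def track_one_turn(pool: AsyncConnectionPool, thread_id: str) -> None:
    await tracker.ensure_conversation(pool, thread_id)  # idempotent insert
    await tracker.record_turn(pool, thread_id, agent_name="order_lookup", tokens=512, cost=0.004)
    await tracker.resolve(pool, thread_id, resolution_type="self_served")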
backend/app/db.py
@@ -2,6 +2,7 @@

 from __future__ import annotations

+from pathlib import Path
 from typing import TYPE_CHECKING

 from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
@@ -34,6 +35,40 @@ CREATE TABLE IF NOT EXISTS active_interrupts (
 );
 """

+_ANALYTICS_EVENTS_DDL = """
+CREATE TABLE IF NOT EXISTS analytics_events (
+    id BIGSERIAL PRIMARY KEY,
+    thread_id TEXT NOT NULL,
+    event_type TEXT NOT NULL,
+    agent_name TEXT,
+    tool_name TEXT,
+    tokens_used INTEGER NOT NULL DEFAULT 0,
+    cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
+    duration_ms INTEGER,
+    success BOOLEAN,
+    error_message TEXT,
+    metadata JSONB NOT NULL DEFAULT '{}',
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+"""
+
+_SESSIONS_DDL = """
+CREATE TABLE IF NOT EXISTS sessions (
+    thread_id TEXT PRIMARY KEY,
+    last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+"""
+
+_CONVERSATIONS_MIGRATION_DDL = """
+ALTER TABLE conversations
+    ADD COLUMN IF NOT EXISTS resolution_type TEXT,
+    ADD COLUMN IF NOT EXISTS agents_used TEXT[],
+    ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
+    ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ;
+"""
+

 async def create_pool(settings: Settings) -> AsyncConnectionPool:
     """Create an async connection pool with the required psycopg settings."""
@@ -54,8 +89,22 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
     return checkpointer


+def run_alembic_migrations(database_url: str) -> None:
+    """Run Alembic migrations to head."""
+    from alembic.config import Config
+
+    from alembic import command
+
+    alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
+    alembic_cfg.set_main_option("sqlalchemy.url", database_url)
+    command.upgrade(alembic_cfg, "head")
+
+
 async def setup_app_tables(pool: AsyncConnectionPool) -> None:
-    """Create application-specific tables (conversations, active_interrupts)."""
+    """Create application-specific tables and apply migrations."""
     async with pool.connection() as conn:
         await conn.execute(_CONVERSATIONS_DDL)
         await conn.execute(_INTERRUPTS_DDL)
+        await conn.execute(_SESSIONS_DDL)
+        await conn.execute(_ANALYTICS_EVENTS_DDL)
+        await conn.execute(_CONVERSATIONS_MIGRATION_DDL)
backend/app/escalation.py (new file, 140 lines)
@@ -0,0 +1,140 @@
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Protocol

import httpx
import structlog
from pydantic import BaseModel

logger = structlog.get_logger()


class EscalationPayload(BaseModel, frozen=True):
    """Immutable payload sent to the escalation webhook."""

    thread_id: str
    reason: str
    conversation_summary: str
    metadata: dict = {}


@dataclass(frozen=True)
class EscalationResult:
    """Immutable result of an escalation attempt."""

    success: bool
    status_code: int | None
    attempts: int
    error: str | None


class EscalationService(Protocol):
    """Protocol for escalation implementations."""

    async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...


class WebhookEscalator:
    """Sends escalation requests via HTTP POST with exponential backoff retry."""

    def __init__(
        self,
        url: str,
        timeout_seconds: int = 10,
        max_retries: int = 3,
    ) -> None:
        self._url = url
        self._timeout = timeout_seconds
        self._max_retries = max_retries

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        """POST the escalation payload to the configured webhook URL."""
        if not self._url:
            return EscalationResult(
                success=False,
                status_code=None,
                attempts=0,
                error="Webhook URL not configured",
            )

        last_error: str | None = None

        async with httpx.AsyncClient(timeout=self._timeout) as client:
            for attempt in range(1, self._max_retries + 1):
                try:
                    response = await client.post(
                        self._url,
                        json=payload.model_dump(),
                    )

                    if 200 <= response.status_code < 300:
                        logger.info(
                            "Escalation succeeded for thread %s (attempt %d)",
                            payload.thread_id,
                            attempt,
                        )
                        return EscalationResult(
                            success=True,
                            status_code=response.status_code,
                            attempts=attempt,
                            error=None,
                        )

                    last_error = f"HTTP {response.status_code}"
                    logger.warning(
                        "Escalation attempt %d/%d failed: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )

                except httpx.TimeoutException:
                    last_error = "Request timed out"
                    logger.warning(
                        "Escalation attempt %d/%d timed out",
                        attempt,
                        self._max_retries,
                    )
                except httpx.RequestError as exc:
                    last_error = str(exc)
                    logger.warning(
                        "Escalation attempt %d/%d error: %s",
                        attempt,
                        self._max_retries,
                        last_error,
                    )

                # Exponential backoff: skip delay after last attempt
                if attempt < self._max_retries:
                    delay = 2**attempt
                    await asyncio.sleep(delay)

        logger.error(
            "Escalation failed for thread %s after %d attempts: %s",
            payload.thread_id,
            self._max_retries,
            last_error,
        )
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=self._max_retries,
            error=last_error,
        )


class NoOpEscalator:
    """Escalator that does nothing -- used when webhook URL is not configured."""

    async def escalate(self, payload: EscalationPayload) -> EscalationResult:
        logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
        return EscalationResult(
            success=False,
            status_code=None,
            attempts=0,
            error="Escalation disabled",
        )
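A usage sketch with placeholder URL and payload text; with max_retries=3 the retry loop waits 2 s and then 4 s between attempts before reporting failure:

# Illustrative only -- the URL and payload values are not from the repo.
import asyncio

from app.escalation import EscalationPayload, WebhookEscalator


async def demo() -> None:
    escalator = WebhookEscalator(url="https://example.com/hooks/support", max_retries=3)
    result = await escalator.escalate(
        EscalationPayload(
            thread_id="thread-123",
            reason="customer asked for a human",
            conversation_summary="User requested cancellation of order #42 twice.",
        )
    )
    print(result.success, result.attempts, result.error)


asyncio.run(demo())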
backend/app/graph.py
@@ -4,27 +4,46 @@ from __future__ import annotations

 from typing import TYPE_CHECKING

-from langgraph.prebuilt import create_react_agent
+from langchain.agents import create_agent
 from langgraph_supervisor import create_supervisor

 from app.agents import get_tools_by_names
+from app.graph_context import GraphContext

 if TYPE_CHECKING:
     from langchain_core.language_models import BaseChatModel
     from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
     from langgraph.graph.state import CompiledStateGraph

+    from app.intent import IntentClassifier
     from app.registry import AgentRegistry

 import structlog

 logger = structlog.get_logger()

 SUPERVISOR_PROMPT = (
     "You are a customer support supervisor. "
-    "Route customer requests to the appropriate agent based on their description. "
-    "For order status and tracking queries, use the order_lookup agent. "
-    "For order modifications like cancellations, use the order_actions agent. "
-    "For anything else, use the fallback agent."
+    "Route customer requests to the appropriate agent based on their description.\n\n"
+    "Available agents and their roles:\n"
+    "{agent_descriptions}\n\n"
+    "Routing rules:\n"
+    "- For order status and tracking queries, use the order_lookup agent.\n"
+    "- For order modifications like cancellations, use the order_actions agent.\n"
+    "- For discounts, promotions, or coupon codes, use the discount agent.\n"
+    "- For anything else or when uncertain, use the fallback agent.\n"
+    "- If the user's request involves multiple actions, execute them in order.\n"
+    "- If a previous intent classification is provided, follow it.\n"
 )


+def _format_agent_descriptions(registry: AgentRegistry) -> str:
+    """Build agent description text for the supervisor prompt."""
+    lines = []
+    for agent in registry.list_agents():
+        lines.append(f"- {agent.name}: {agent.description}")
+    return "\n".join(lines)
+
+
 def build_agent_nodes(
     registry: AgentRegistry,
     llm: BaseChatModel,
@@ -41,11 +60,11 @@ def build_agent_nodes(
             f"Permission level: {agent_config.permission}."
         )

-        agent_node = create_react_agent(
+        agent_node = create_agent(
             model=llm,
             tools=tools,
             name=agent_config.name,
-            prompt=system_prompt,
+            system_prompt=system_prompt,
         )
         agent_nodes.append(agent_node)

@@ -56,15 +75,29 @@ def build_graph(
     registry: AgentRegistry,
     llm: BaseChatModel,
     checkpointer: AsyncPostgresSaver,
-) -> CompiledStateGraph:
-    """Build and compile the LangGraph supervisor graph."""
+    intent_classifier: IntentClassifier | None = None,
+) -> GraphContext:
+    """Build and compile the LangGraph supervisor graph.
+
+    Returns a GraphContext that bundles the compiled graph with its
+    associated registry and intent classifier.
+    """
     agent_nodes = build_agent_nodes(registry, llm)
+    agent_descriptions = _format_agent_descriptions(registry)
+
+    prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)

     workflow = create_supervisor(
-        agent_nodes,
+        agents=agent_nodes,
         model=llm,
-        prompt=SUPERVISOR_PROMPT,
+        prompt=prompt,
         output_mode="full_history",
     )

-    return workflow.compile(checkpointer=checkpointer)
+    compiled = workflow.compile(checkpointer=checkpointer)
+
+    return GraphContext(
+        graph=compiled,
+        registry=registry,
+        intent_classifier=intent_classifier,
+    )
backend/app/graph_context.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from langgraph.graph.state import CompiledStateGraph

    from app.intent import ClassificationResult, IntentClassifier
    from app.registry import AgentRegistry


@dataclass(frozen=True)
class GraphContext:
    """Bundles the compiled LangGraph graph with its associated services.

    Replaces the previous pattern of monkey-patching attributes onto the
    third-party CompiledStateGraph instance.
    """

    graph: CompiledStateGraph
    registry: AgentRegistry
    intent_classifier: IntentClassifier | None = None

    async def classify_intent(self, message: str) -> ClassificationResult | None:
        """Classify user intent using the attached classifier.

        Returns None if no classifier is configured.
        """
        if self.intent_classifier is None:
            return None

        agents = self.registry.list_agents()
        return await self.intent_classifier.classify(message, agents)
backend/app/intent.py (new file, 119 lines)
@@ -0,0 +1,119 @@
"""Intent classification using LLM structured output."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

from pydantic import BaseModel

if TYPE_CHECKING:
    from langchain_core.language_models import BaseChatModel

    from app.registry import AgentConfig

import structlog

logger = structlog.get_logger()

CLASSIFICATION_PROMPT = (
    "You are an intent classifier for a customer support system.\n"
    "Given a user message, determine which agent(s) should handle it.\n\n"
    "Available agents:\n{agent_list}\n\n"
    "Rules:\n"
    "- If the message clearly maps to one agent, return a single intent.\n"
    "- If the message contains multiple distinct requests, return multiple intents "
    "in execution order.\n"
    "- If the message is vague or doesn't match any agent, set is_ambiguous=True "
    "and provide a clarification_question.\n"
    "- Never route to the fallback agent unless truly ambiguous.\n"
    "- confidence should be between 0.0 and 1.0.\n"
)

AMBIGUITY_THRESHOLD = 0.5


class IntentTarget(BaseModel, frozen=True):
    """A single classified intent targeting a specific agent."""

    agent_name: str
    confidence: float
    reasoning: str


class ClassificationResult(BaseModel, frozen=True):
    """Result of intent classification -- may contain multiple intents."""

    intents: tuple[IntentTarget, ...]
    is_ambiguous: bool = False
    clarification_question: str | None = None


class IntentClassifier(Protocol):
    """Protocol for intent classification implementations."""

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult: ...


def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
    """Format agent descriptions for the classification prompt."""
    lines = []
    for agent in agents:
        lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
    return "\n".join(lines)


class LLMIntentClassifier:
    """Classifies user intent using LLM structured output."""

    def __init__(self, llm: BaseChatModel) -> None:
        self._llm = llm

    async def classify(
        self,
        message: str,
        available_agents: tuple[AgentConfig, ...],
    ) -> ClassificationResult:
        """Classify user message into one or more agent intents."""
        agent_list = _build_agent_list(available_agents)
        system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)

        structured_llm = self._llm.with_structured_output(ClassificationResult)

        try:
            result = await structured_llm.ainvoke(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": message},
                ]
            )
        except Exception:
            logger.exception("Intent classification failed, returning ambiguous")
            return ClassificationResult(
                intents=(),
                is_ambiguous=True,
                clarification_question="I'm not sure I understood. Could you please rephrase?",
            )

        if not isinstance(result, ClassificationResult):
            return ClassificationResult(
                intents=(),
                is_ambiguous=True,
                clarification_question="I'm not sure I understood. Could you please rephrase?",
            )

        # Apply ambiguity threshold
        if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
            return ClassificationResult(
                intents=result.intents,
                is_ambiguous=True,
                clarification_question=(
                    result.clarification_question
                    or "I'm not sure I understood. Could you please rephrase?"
                ),
            )

        return result
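A classification sketch for a two-part message, assuming an existing chat model llm and a loaded registry (neither is defined in this diff):

# Assumes `llm` (any BaseChatModel) and `registry` (a loaded AgentRegistry).
from app.intent import LLMIntentClassifier

classifier = LLMIntentClassifier(llm)


async def route(message: str) -> None:
    result = await classifier.classify(message, available_agents=registry.list_agents())
    if result.is_ambiguous:
        print(result.clarification_question)
    else:
        for intent in result.intents:  # already in execution order
            print(intent.agent_name, intent.confidence, intent.reasoning)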
backend/app/interrupt_manager.py (new file, 268 lines)
@@ -0,0 +1,268 @@
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.

Provides both in-memory (InterruptManager) and PostgreSQL-backed
(PgInterruptManager) implementations behind a common Protocol.
"""

from __future__ import annotations

import time
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool


@dataclass(frozen=True)
class InterruptRecord:
    """Immutable record of a pending interrupt."""

    interrupt_id: str
    thread_id: str
    action: str
    params: dict
    created_at: float
    ttl_seconds: int


@dataclass(frozen=True)
class InterruptStatus:
    """Current status of a tracked interrupt."""

    is_expired: bool
    remaining_seconds: float
    record: InterruptRecord


class InterruptManagerProtocol(Protocol):
    """Protocol for interrupt TTL management."""

    def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
    def check_status(self, thread_id: str) -> InterruptStatus | None: ...
    def resolve(self, thread_id: str) -> None: ...
    def has_pending(self, thread_id: str) -> bool: ...
    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...


def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
    """Generate a WebSocket message prompting the user to retry an expired action."""
    return {
        "type": "interrupt_expired",
        "thread_id": expired_record.thread_id,
        "action": expired_record.action,
        "message": (
            f"The approval request for '{expired_record.action}' has expired "
            f"after {expired_record.ttl_seconds // 60} minutes. "
            f"Would you like to try again?"
        ),
    }


class InterruptManager:
    """In-memory interrupt manager for single-worker development.

    Complements SessionManager -- this tracks interrupt-specific TTL
    while SessionManager handles session-level TTL.
    """

    def __init__(self, ttl_seconds: int = 1800) -> None:
        self._ttl_seconds = ttl_seconds
        self._interrupts: dict[str, InterruptRecord] = {}

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        """Register a new pending interrupt with TTL tracking."""
        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        self._interrupts = {**self._interrupts, thread_id: record}
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        """Check the TTL status of a pending interrupt."""
        record = self._interrupts.get(thread_id)
        if record is None:
            return None
        elapsed = time.time() - record.created_at
        remaining = max(0.0, record.ttl_seconds - elapsed)
        is_expired = elapsed > record.ttl_seconds
        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    def resolve(self, thread_id: str) -> None:
        """Remove a resolved interrupt from tracking."""
        self._interrupts = {
            k: v for k, v in self._interrupts.items() if k != thread_id
        }

    def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
        """Find and remove all expired interrupts. Returns the expired records."""
        now = time.time()
        expired: list[InterruptRecord] = []
        active: dict[str, InterruptRecord] = {}
        for thread_id, record in self._interrupts.items():
            if now - record.created_at > record.ttl_seconds:
                expired.append(record)
            else:
                active[thread_id] = record
        self._interrupts = active
        return tuple(expired)

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        """Generate a WebSocket message prompting the user to retry an expired action."""
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        """Check if a thread has a pending (non-expired) interrupt."""
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired


# Alias for explicit naming
InMemoryInterruptManager = InterruptManager


class PgInterruptManager:
    """PostgreSQL-backed interrupt manager for multi-worker production.

    Uses the existing active_interrupts table defined in db.py.
    """

    def __init__(
        self,
        pool: AsyncConnectionPool,
        ttl_seconds: int = 1800,
    ) -> None:
        self._pool = pool
        self._ttl_seconds = ttl_seconds

    def register(
        self,
        thread_id: str,
        action: str,
        params: dict,
    ) -> InterruptRecord:
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._register(thread_id, action, params)
        )

    async def _register(
        self, thread_id: str, action: str, params: dict
    ) -> InterruptRecord:
        import json

        record = InterruptRecord(
            interrupt_id=uuid.uuid4().hex,
            thread_id=thread_id,
            action=action,
            params=dict(params),
            created_at=time.time(),
            ttl_seconds=self._ttl_seconds,
        )
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
                VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
                ON CONFLICT (thread_id) WHERE resolved_at IS NULL
                DO UPDATE SET
                    interrupt_id = %(iid)s,
                    action = %(action)s,
                    params = %(params)s,
                    created_at = NOW(),
                    resolved_at = NULL
                """,
                {
                    "iid": record.interrupt_id,
                    "tid": thread_id,
                    "action": action,
                    "params": json.dumps(params),
                },
            )
        return record

    def check_status(self, thread_id: str) -> InterruptStatus | None:
        import asyncio

        return asyncio.get_event_loop().run_until_complete(
            self._check_status(thread_id)
        )

    async def _check_status(self, thread_id: str) -> InterruptStatus | None:
        async with self._pool.connection() as conn:
            cursor = await conn.execute(
                """
                SELECT interrupt_id, action, params, created_at
                FROM active_interrupts
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                ORDER BY created_at DESC LIMIT 1
                """,
                {"tid": thread_id},
            )
            row = await cursor.fetchone()

        if row is None:
            return None

        created_at = row["created_at"].timestamp()
        elapsed = time.time() - created_at
        remaining = max(0.0, self._ttl_seconds - elapsed)
        is_expired = elapsed > self._ttl_seconds

        record = InterruptRecord(
            interrupt_id=row["interrupt_id"],
            thread_id=thread_id,
            action=row["action"],
            params=row["params"] if isinstance(row["params"], dict) else {},
            created_at=created_at,
            ttl_seconds=self._ttl_seconds,
        )

        return InterruptStatus(
            is_expired=is_expired,
            remaining_seconds=remaining,
            record=record,
        )

    def resolve(self, thread_id: str) -> None:
        import asyncio

        asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))

    async def _resolve(self, thread_id: str) -> None:
        async with self._pool.connection() as conn:
            await conn.execute(
                """
                UPDATE active_interrupts
                SET resolved_at = NOW(), resolution = 'resolved'
                WHERE thread_id = %(tid)s AND resolved_at IS NULL
                """,
                {"tid": thread_id},
            )

    def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
        return _build_retry_prompt(expired_record)

    def has_pending(self, thread_id: str) -> bool:
        status = self.check_status(thread_id)
        if status is None:
            return False
        return not status.is_expired
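The in-memory TTL flow in miniature; a 1-second TTL is chosen here only to make expiry observable (the default is 1800 seconds):

# Demonstration values only.
import time

from app.interrupt_manager import InterruptManager

manager = InterruptManager(ttl_seconds=1)
manager.register("thread-1", action="cancel_order", params={"order_id": 42})
assert manager.has_pending("thread-1")

time.sleep(1.1)  # let the interrupt expire
expired = manager.cleanup_expired()  # removes and returns expired records
print(manager.generate_retry_prompt(expired[0])["message"])
# ttl_seconds // 60 renders as "0 minutes" for this demo TTL
assert not manager.has_pending("thread-1")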
backend/app/llm.py
@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
             api_key=settings.openai_api_key,
         )

+    if provider == "azure_openai":
+        from langchain_openai import AzureChatOpenAI
+
+        return AzureChatOpenAI(
+            azure_deployment=settings.azure_openai_deployment,
+            azure_endpoint=settings.azure_openai_endpoint,
+            api_key=settings.azure_openai_api_key,
+            api_version=settings.azure_openai_api_version,
+        )
+
     if provider == "google":
         from langchain_google_genai import ChatGoogleGenerativeAI

@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
             google_api_key=settings.google_api_key,
         )

-    raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
+    raise ValueError(
+        f"Unknown LLM provider: '{provider}'. "
+        "Use 'anthropic', 'openai', 'azure_openai', or 'google'."
+    )
backend/app/logging_config.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""Structured logging configuration using structlog."""

from __future__ import annotations

import logging
import sys

import structlog


def configure_logging(log_format: str = "console") -> None:
    """Configure structlog with stdlib integration.

    Args:
        log_format: "console" for human-readable dev output,
            "json" for machine-parseable production output.
    """
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
    ]

    if log_format == "json":
        renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
    else:
        renderer = structlog.dev.ConsoleRenderer()

    structlog.configure(
        processors=[
            *shared_processors,
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
    )

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)
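A minimal sketch of the two output modes:

# Same logger call, two renderings depending on log_format.
import structlog

from app.logging_config import configure_logging

configure_logging("json")  # or "console" for colorized dev output
log = structlog.get_logger()
log.info("escalation_sent", thread_id="thread-1", attempts=2)
# json mode emits one JSON object per line, roughly:
# {"event": "escalation_sent", "thread_id": "thread-1", "attempts": 2, "level": "info", ...}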
backend/app/main.py
@@ -2,79 +2,211 @@

 from __future__ import annotations

-import logging
+import asyncio
+import contextlib
 from contextlib import asynccontextmanager
 from pathlib import Path
 from typing import TYPE_CHECKING

-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles

+from app.analytics.api import router as analytics_router
+from app.analytics.event_recorder import PostgresAnalyticsRecorder
+from app.api_utils import envelope
 from app.callbacks import TokenUsageCallbackHandler
 from app.config import Settings
-from app.db import create_checkpointer, create_pool, setup_app_tables
+from app.conversation_tracker import PostgresConversationTracker
+from app.db import create_checkpointer, create_pool, run_alembic_migrations
+from app.escalation import NoOpEscalator, WebhookEscalator
 from app.graph import build_graph
+from app.intent import LLMIntentClassifier
+from app.interrupt_manager import InterruptManager
 from app.llm import create_llm
+from app.logging_config import configure_logging
+from app.openapi.review_api import router as openapi_router
 from app.registry import AgentRegistry
+from app.replay.api import router as replay_router
 from app.session_manager import SessionManager
+from app.ws_context import WebSocketContext
 from app.ws_handler import dispatch_message

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator

-logger = logging.getLogger(__name__)
+import structlog
+
+logger = structlog.get_logger()

 AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
+FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
+
+
+async def _interrupt_cleanup_loop(
+    interrupt_manager: InterruptManager,
+    interval: int = 60,
+) -> None:
+    """Periodically remove expired interrupts in the background.
+
+    Runs until cancelled. Catches all exceptions to prevent the task
+    from dying unexpectedly.
+    """
+    while True:
+        await asyncio.sleep(interval)
+        try:
+            expired = interrupt_manager.cleanup_expired()
+            if expired:
+                logger.info(
+                    "Cleaned up %d expired interrupt(s)",
+                    len(expired),
+                )
+        except Exception:
+            logger.exception("Error during interrupt cleanup")


 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     settings = Settings()
+    configure_logging(settings.log_format)

     pool = await create_pool(settings)
     checkpointer = await create_checkpointer(pool)
-    await setup_app_tables(pool)
+    run_alembic_migrations(settings.database_url)

+    # Load agents from template or default YAML
+    if settings.template_name:
+        registry = AgentRegistry.load_template(settings.template_name)
+    else:
+        registry = AgentRegistry.load(AGENTS_YAML)
-    registry = AgentRegistry.load(AGENTS_YAML)
+
     llm = create_llm(settings)
-    graph = build_graph(registry, llm, checkpointer)
+    intent_classifier = LLMIntentClassifier(llm)
+    graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)

     session_manager = SessionManager(
         session_ttl_seconds=settings.session_ttl_minutes * 60,
     )
+    interrupt_manager = InterruptManager(
+        ttl_seconds=settings.interrupt_ttl_minutes * 60,
+    )

-    app.state.graph = graph
+    # Configure escalation
+    if settings.webhook_url:
+        escalator = WebhookEscalator(
+            url=settings.webhook_url,
+            timeout_seconds=settings.webhook_timeout_seconds,
+            max_retries=settings.webhook_max_retries,
+        )
+    else:
+        escalator = NoOpEscalator()
+
+    app.state.graph_ctx = graph_ctx
     app.state.session_manager = session_manager
+    app.state.interrupt_manager = interrupt_manager
+    app.state.escalator = escalator
     app.state.settings = settings
+    app.state.pool = pool
+    app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
+    app.state.conversation_tracker = PostgresConversationTracker()

     logger.info(
-        "Smart Support started: %d agents loaded, LLM=%s/%s",
+        "Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
         len(registry),
         settings.llm_provider,
         settings.llm_model,
+        settings.template_name or "(default)",
     )

+    cleanup_task = asyncio.create_task(
+        _interrupt_cleanup_loop(interrupt_manager),
+    )
+
     yield

+    cleanup_task.cancel()
+    with contextlib.suppress(asyncio.CancelledError):
+        await cleanup_task
+
     await pool.close()


-app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
+_VERSION = "0.6.0"
+
+app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
+
+app.include_router(openapi_router)
+app.include_router(replay_router)
+app.include_router(analytics_router)
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap HTTPException in standard envelope format."""
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=envelope(None, success=False, error=exc.detail),
+    )
+
+
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Wrap validation errors in standard envelope format."""
+    return JSONResponse(
+        status_code=422,
+        content=envelope(None, success=False, error=str(exc)),
+    )
+
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
+    """Catch-all handler -- never leak stack traces."""
+    logger.exception("Unhandled exception: %s", exc)
+    return JSONResponse(
+        status_code=500,
+        content=envelope(None, success=False, error="Internal server error"),
+    )
+
+
+@app.get("/api/v1/health")
+def health_check() -> dict:
+    """Health check endpoint for load balancers and monitoring."""
+    return {"status": "ok", "version": _VERSION}


 @app.websocket("/ws")
-async def websocket_endpoint(ws: WebSocket) -> None:
-    await ws.accept()
-    graph = app.state.graph
-    session_manager = app.state.session_manager
+async def websocket_endpoint(
+    ws: WebSocket,
+    token: str | None = Query(default=None),
+) -> None:
+    settings = app.state.settings
+
+    # Verify WebSocket token when admin_api_key is configured
+    if settings.admin_api_key:
+        import secrets as _secrets
+
+        if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
+            await ws.close(code=4001, reason="Unauthorized")
+            return
+
+    await ws.accept()
     callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)

+    ws_ctx = WebSocketContext(
+        graph_ctx=app.state.graph_ctx,
+        session_manager=app.state.session_manager,
+        callback_handler=callback_handler,
+        interrupt_manager=app.state.interrupt_manager,
+        analytics_recorder=app.state.analytics_recorder,
+        conversation_tracker=app.state.conversation_tracker,
+        pool=app.state.pool,
+    )
+
     try:
         while True:
             raw_data = await ws.receive_text()
-            await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
+            await dispatch_message(ws, ws_ctx, raw_data)
     except WebSocketDisconnect:
         logger.info("WebSocket client disconnected")
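A client-side sketch of the new token gate; the websockets package and the message shape are assumptions, since the dispatch protocol lives in ws_handler.py outside this diff:

# Assumed client sketch -- the JSON message schema is illustrative.
import asyncio

import websockets  # third-party client library


async def main() -> None:
    uri = "ws://localhost:8000/ws?token=my-admin-key"  # token must match ADMIN_API_KEY
    async with websockets.connect(uri) as ws:
        await ws.send('{"type": "message", "text": "where is order 42?"}')
        print(await ws.recv())


asyncio.run(main())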
backend/app/openapi/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# OpenAPI auto-discovery module
# Parses OpenAPI specs, classifies endpoints via LLM, generates tools
backend/app/openapi/classifier.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""OpenAPI endpoint classifier.

Classifies endpoints into read/write access types and identifies
customer-identifying parameters. Provides a rule-based heuristic
classifier and an LLM-backed classifier with heuristic fallback.
"""

from __future__ import annotations

import json
import re
from typing import Protocol

import structlog

from app.openapi.models import ClassificationResult, EndpointInfo

logger = structlog.get_logger()

_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

# Parameter names that identify the customer/order context
_CUSTOMER_PARAM_PATTERNS = re.compile(
    r"(customer_id|order_id|user_id|account_id|client_id|subscriber_id)",
    re.IGNORECASE,
)


class ClassifierProtocol(Protocol):
    """Protocol for endpoint classifiers."""

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]: ...


class HeuristicClassifier:
    """Rule-based endpoint classifier.

    GET -> read, no interrupt.
    POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
    """

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using HTTP method heuristics."""
        if not endpoints:
            return ()
        return tuple(_classify_one(ep) for ep in endpoints)


def _classify_one(ep: EndpointInfo) -> ClassificationResult:
    """Classify a single endpoint using heuristics."""
    access_type = "write" if ep.method in _WRITE_METHODS else "read"
    needs_interrupt = ep.method in _INTERRUPT_METHODS
    customer_params = _detect_customer_params(ep)
    agent_group = "write_agent" if access_type == "write" else "read_agent"
    return ClassificationResult(
        endpoint=ep,
        access_type=access_type,
        customer_params=customer_params,
        agent_group=agent_group,
        confidence=0.7,
        needs_interrupt=needs_interrupt,
    )


def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
    """Extract parameter names that identify the customer/order context."""
    return tuple(
        p.name
        for p in ep.parameters
        if _CUSTOMER_PARAM_PATTERNS.fullmatch(p.name)
    )


class LLMClassifier:
    """LLM-backed endpoint classifier with heuristic fallback.

    Uses an LLM to classify endpoints with higher accuracy.
    Falls back to HeuristicClassifier on any LLM error.
    """

    def __init__(self, llm: object) -> None:
        self._llm = llm
        self._fallback = HeuristicClassifier()

    async def classify(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Classify endpoints using LLM with heuristic fallback."""
        if not endpoints:
            return ()
        try:
            return await self._classify_with_llm(endpoints)
        except Exception:
            logger.warning(
                "LLM classification failed, falling back to heuristic",
                exc_info=True,
            )
            return await self._fallback.classify(endpoints)

    async def _classify_with_llm(
        self, endpoints: tuple[EndpointInfo, ...]
    ) -> tuple[ClassificationResult, ...]:
        """Attempt LLM-based classification."""
        prompt = _build_classification_prompt(endpoints)
        response = await self._llm.ainvoke(prompt)
        parsed = _parse_llm_response(response.content, endpoints)
        return parsed


def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
    """Build a prompt for classifying endpoints."""
    items = []
    for i, ep in enumerate(endpoints):
        items.append(
            f"{i}: {ep.method} {ep.path} - {ep.summary or ep.operation_id}"
        )
    endpoint_list = "\n".join(items)
    return (
        "Classify each API endpoint as 'read' or 'write'. "
        "For each, determine if it needs human interrupt approval, "
        "identify customer-identifying parameters, and assign an agent_group.\n\n"
        f"Endpoints:\n{endpoint_list}\n\n"
        "Respond with a JSON array with one object per endpoint:\n"
        '[{"access_type": "read|write", "agent_group": "...", '
        '"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
    )


def _parse_llm_response(
    content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
    """Parse LLM JSON response into ClassificationResult instances.

    Raises ValueError if the response cannot be parsed or is mismatched.
    """
    # Extract JSON array from response
    match = re.search(r"\[.*\]", content, re.DOTALL)
    if not match:
        raise ValueError(f"No JSON array found in LLM response: {content!r}")

    items = json.loads(match.group())
    if not isinstance(items, list) or len(items) != len(endpoints):
        raise ValueError(
            f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
        )

    results: list[ClassificationResult] = []
    for ep, item in zip(endpoints, items, strict=True):
        raw_access = item.get("access_type", "read")
        access_type = raw_access if raw_access in {"read", "write"} else "read"
        confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
        raw_group = str(item.get("agent_group", "support"))
        agent_group = raw_group if raw_group.strip() else "support"
        results.append(
            ClassificationResult(
                endpoint=ep,
                access_type=access_type,
                customer_params=tuple(item.get("customer_params", [])),
                agent_group=agent_group,
                confidence=confidence,
                needs_interrupt=bool(item.get("needs_interrupt", False)),
            )
        )
    return tuple(results)
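A heuristic-pass sketch; the EndpointInfo and ParameterInfo constructors are assumed from the fields this module reads (method, path, operation_id, parameters), since models.py is outside this diff:

# Assumed constructors -- field names inferred from usage above.
import asyncio

from app.openapi.classifier import HeuristicClassifier
from app.openapi.models import EndpointInfo, ParameterInfo

order_id = ParameterInfo(name="order_id", location="path", schema_type="string", required=True)
endpoints = (
    EndpointInfo(method="GET", path="/orders/{order_id}", operation_id="getOrder",
                 parameters=(order_id,)),
    EndpointInfo(method="POST", path="/orders/{order_id}/cancel", operation_id="cancelOrder",
                 parameters=(order_id,)),
)

results = asyncio.run(HeuristicClassifier().classify(endpoints))
for r in results:
    print(r.endpoint.method, r.access_type, r.needs_interrupt, r.customer_params)
# GET  -> read,  needs_interrupt=False, customer_params=('order_id',)
# POST -> write, needs_interrupt=True,  customer_params=('order_id',)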
backend/app/openapi/fetcher.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""OpenAPI spec fetcher with SSRF protection.

Fetches OpenAPI spec documents from remote URLs, validates them against
SSRF policy, and parses JSON or YAML format automatically.
"""

from __future__ import annotations

import json

import yaml

from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy

_MAX_SIZE_BYTES = 10 * 1024 * 1024  # 10MB


async def fetch_spec(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> dict:
    """Fetch an OpenAPI spec from a URL and return as a dict.

    Auto-detects JSON or YAML format from content-type header or URL extension.
    Enforces a 10MB size limit.

    Raises:
        SSRFError: If the URL is blocked by SSRF policy.
        ValueError: If the response is too large or cannot be parsed.
    """
    from app.openapi.ssrf import safe_fetch

    response = await safe_fetch(url, policy=policy)
    response.raise_for_status()

    content = response.text
    if len(content.encode("utf-8")) > _MAX_SIZE_BYTES:
        raise ValueError(
            f"Response too large: {len(content.encode('utf-8'))} bytes "
            f"(max {_MAX_SIZE_BYTES} bytes)"
        )

    content_type = response.headers.get("content-type", "")
    return _parse_content(content, content_type, url)


def _parse_content(content: str, content_type: str, url: str) -> dict:
    """Parse content as JSON or YAML based on content-type or URL extension."""
    if _is_yaml_format(content_type, url):
        return _parse_yaml(content)
    if _is_json_format(content_type, url):
        return _parse_json(content)
    # Fall back: try JSON first, then YAML
    try:
        return _parse_json(content)
    except ValueError:
        return _parse_yaml(content)


def _is_yaml_format(content_type: str, url: str) -> bool:
    """Check if the content is YAML format."""
    yaml_types = {"application/x-yaml", "text/yaml", "application/yaml"}
    if any(t in content_type for t in yaml_types):
        return True
    lower_url = url.lower().split("?")[0]
    return lower_url.endswith(".yaml") or lower_url.endswith(".yml")


def _is_json_format(content_type: str, url: str) -> bool:
    """Check if the content is JSON format."""
    if "application/json" in content_type:
        return True
    lower_url = url.lower().split("?")[0]
    return lower_url.endswith(".json")


def _parse_json(content: str) -> dict:
    """Parse content as JSON, raising ValueError on failure."""
    try:
        result = json.loads(content)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON: {exc}") from exc
    if not isinstance(result, dict):
        raise ValueError(f"Expected a JSON object, got {type(result).__name__}")
    return result


def _parse_yaml(content: str) -> dict:
    """Parse content as YAML, raising ValueError on failure."""
    try:
        result = yaml.safe_load(content)
    except yaml.YAMLError as exc:
        raise ValueError(f"Invalid YAML: {exc}") from exc
    if not isinstance(result, dict):
        raise ValueError(f"Expected a YAML mapping, got {type(result).__name__}")
    return result
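A fetch sketch against a public spec URL (illustrative); the call is still subject to the default SSRF policy defined in ssrf.py, which is outside this diff:

# Illustrative URL only.
import asyncio

from app.openapi.fetcher import fetch_spec

spec = asyncio.run(fetch_spec("https://petstore3.swagger.io/api/v3/openapi.json"))
print(spec["openapi"], len(spec.get("paths", {})))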
164
backend/app/openapi/generator.py
Normal file
164
backend/app/openapi/generator.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tool code generator for classified OpenAPI endpoints.
|
||||
|
||||
Generates Python source code for LangChain @tool functions and
|
||||
YAML agent configurations from classification results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import keyword
|
||||
import re
|
||||
|
||||
import yaml
|
||||
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
|
||||
|
||||
_INDENT = " "
|
||||
|
||||
|
||||
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
|
||||
"""Generate a LangChain @tool function for a classified endpoint.
|
||||
|
||||
Returns a GeneratedTool with the function source code as a string.
|
||||
"""
|
||||
ep = classification.endpoint
|
||||
func_name = _to_snake_case(ep.operation_id)
|
||||
params = _collect_params(ep)
|
||||
sig = _build_signature(params, ep.request_body_schema)
|
||||
docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
|
||||
interrupt_comment = _interrupt_comment(classification)
|
||||
http_call = _build_http_call(ep, base_url, params)
|
||||
|
||||
lines = [
|
||||
"@tool",
|
||||
f"async def {func_name}({sig}) -> str:",
|
||||
f'{_INDENT}"""{docstring}"""',
|
||||
]
|
||||
if interrupt_comment:
|
||||
lines.append(f"{_INDENT}{interrupt_comment}")
|
||||
lines += [
|
||||
f"{_INDENT}async with httpx.AsyncClient() as client:",
|
||||
f"{_INDENT}{_INDENT}{http_call}",
|
||||
f"{_INDENT}{_INDENT}return response.text",
|
||||
]
|
||||
|
||||
code = "\n".join(lines)
|
||||
return GeneratedTool(
|
||||
function_name=func_name,
|
||||
endpoint=ep,
|
||||
classification=classification,
|
||||
code=code,
|
||||
)
|
||||
|
||||
|
||||
def generate_agent_yaml(
|
||||
classifications: tuple[ClassificationResult, ...],
|
||||
base_url: str,
|
||||
) -> str:
|
||||
"""Generate an agents.yaml string from a set of classification results.
|
||||
|
||||
Groups tools by agent_group, creating one agent entry per group.
|
||||
"""
|
||||
if not classifications:
|
||||
return yaml.safe_dump({"agents": []})
|
||||
|
||||
groups: dict[str, dict] = {}
|
||||
for clf in classifications:
|
||||
group = clf.agent_group
|
||||
func_name = _to_snake_case(clf.endpoint.operation_id)
|
||||
if group not in groups:
|
||||
permission = "read" if clf.access_type == "read" else "write"
|
||||
groups[group] = {
|
||||
"name": group,
|
||||
"description": f"Agent for {group} operations",
|
||||
"permission": permission,
|
||||
"tools": [],
|
||||
}
|
||||
groups[group]["tools"].append(func_name)
|
||||
|
||||
return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
|
||||
|
||||
|
||||
# --- Private helpers ---
|
||||
|
||||
|
||||
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
|
||||
"""Return path params first, then query params."""
|
||||
path_params = [p for p in ep.parameters if p.location == "path"]
|
||||
other_params = [p for p in ep.parameters if p.location != "path"]
|
||||
return path_params + other_params
|
||||
|
||||
|
||||
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
|
||||
"""Build a Python function signature string from parameters."""
|
||||
parts: list[str] = []
|
||||
for p in params:
|
||||
py_type = _schema_type_to_python(p.schema_type)
|
||||
if p.required:
|
||||
parts.append(f"{p.name}: {py_type}")
|
||||
else:
|
||||
parts.append(f"{p.name}: {py_type} | None = None")
|
||||
if body_schema:
|
||||
parts.append("body: dict | None = None")
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
def _build_http_call(
|
||||
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
|
||||
) -> str:
|
||||
"""Build the httpx client call line."""
|
||||
method = ep.method.lower()
|
||||
path = ep.path
|
||||
|
||||
# Replace path parameters with f-string expressions
|
||||
for p in params:
|
||||
if p.location == "path":
|
||||
path = path.replace(f"{{{p.name}}}", f"{{{p.name}}}")
|
||||
|
||||
url_expr = f'f"{base_url}{path}"'
|
||||
|
||||
query_params = [p for p in params if p.location == "query"]
|
||||
extra_args = []
|
||||
if query_params:
|
||||
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
|
||||
extra_args.append(f"params={qp_dict}")
|
||||
|
||||
if ep.request_body_schema and method in ("post", "put", "patch"):
|
||||
extra_args.append("json=body")
|
||||
|
||||
args_str = ", ".join([url_expr] + extra_args)
|
||||
return f"response = await client.{method}({args_str})"
|
||||
|
||||
|
||||
def _interrupt_comment(classification: ClassificationResult) -> str:
|
||||
"""Return a comment line if the endpoint requires interrupt/approval."""
|
||||
if classification.needs_interrupt:
|
||||
return "# INTERRUPT: requires human approval before execution"
|
||||
return ""
|
||||
|
||||
|
||||
def _schema_type_to_python(schema_type: str) -> str:
|
||||
"""Map OpenAPI schema type to Python type annotation."""
|
||||
mapping = {
|
||||
"string": "str",
|
||||
"integer": "int",
|
||||
"number": "float",
|
||||
"boolean": "bool",
|
||||
"array": "list",
|
||||
"object": "dict",
|
||||
}
|
||||
return mapping.get(schema_type, "str")
|
||||
|
||||
|
||||
def _sanitize_docstring(text: str) -> str:
|
||||
"""Escape triple-quotes and newlines to prevent docstring injection."""
|
||||
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
|
||||
|
||||
|
||||
def _to_snake_case(name: str) -> str:
|
||||
"""Convert operationId to a valid snake_case Python identifier."""
|
||||
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
|
||||
result = clean.lower() or "unnamed_tool"
|
||||
if keyword.iskeyword(result):
|
||||
result = f"{result}_tool"
|
||||
return result
|
||||
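For orientation, a minimal sketch of what `generate_agent_yaml` emits for a single read endpoint; the endpoint values are illustrative, and the models are those defined in `backend/app/openapi/models.py` below:

```python
from app.openapi.generator import generate_agent_yaml
from app.openapi.models import ClassificationResult, EndpointInfo

ep = EndpointInfo(
    path="/orders/{order_id}", method="GET", operation_id="get-order",
    summary="Get order", description="",
)
clf = ClassificationResult(
    endpoint=ep, access_type="read", customer_params=("order_id",),
    agent_group="orders", confidence=0.9, needs_interrupt=False,
)
print(generate_agent_yaml((clf,), "https://api.example.com"))
# Roughly:
# agents:
# - name: orders
#   description: Agent for orders operations
#   permission: read
#   tools:
#   - get_order
```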
111
backend/app/openapi/importer.py
Normal file
@@ -0,0 +1,111 @@
"""Import orchestrator for OpenAPI auto-discovery pipeline.

Orchestrates: fetch -> validate -> parse -> classify
Each stage updates the job status and calls the on_progress callback.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import replace

import structlog

from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
from app.openapi.models import ImportJob
from app.openapi.parser import parse_endpoints
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
from app.openapi.validator import validate_spec

logger = structlog.get_logger()

ProgressCallback = Callable[[str, ImportJob], None] | None


class ImportOrchestrator:
    """Orchestrates the full OpenAPI import pipeline.

    Stages:
    1. fetching -- download and parse spec from URL
    2. validating -- check spec structure
    3. parsing -- extract endpoint definitions
    4. classifying -- classify endpoints for agent routing
    5. done / failed
    """

    def __init__(
        self,
        classifier: ClassifierProtocol | None = None,
        policy: SSRFPolicy = DEFAULT_POLICY,
    ) -> None:
        self._classifier = classifier or HeuristicClassifier()
        self._policy = policy

    async def start_import(
        self,
        url: str,
        job_id: str,
        on_progress: ProgressCallback,
    ) -> ImportJob:
        """Run the full import pipeline for a spec URL.

        Returns an ImportJob reflecting final status (done or failed).
        on_progress is called with (stage_name, current_job) at each stage.
        Passing None for on_progress is safe.
        """
        job = ImportJob(
            job_id=job_id,
            status="pending",
            spec_url=url,
        )

        try:
            # Stage 1: fetch
            job = _update(job, status="fetching")
            _notify(on_progress, "fetching", job)
            spec_dict = await fetch_spec(url, self._policy)

            # Stage 2: validate
            job = _update(job, status="validating")
            _notify(on_progress, "validating", job)
            errors = validate_spec(spec_dict)
            if errors:
                raise ValueError(f"Invalid spec: {'; '.join(errors)}")

            # Stage 3: parse
            job = _update(job, status="parsing")
            _notify(on_progress, "parsing", job)
            endpoints = parse_endpoints(spec_dict)

            # Stage 4: classify
            job = _update(job, status="classifying", total_endpoints=len(endpoints))
            _notify(on_progress, "classifying", job)
            classifications = await self._classifier.classify(endpoints)

            # Done
            job = _update(
                job,
                status="done",
                total_endpoints=len(endpoints),
                classified_count=len(classifications),
            )
            _notify(on_progress, "done", job)
            return job

        except Exception as exc:
            logger.exception("Import pipeline failed for job %s", job_id)
            job = _update(job, status="failed", error_message=str(exc))
            _notify(on_progress, "failed", job)
            return job


def _update(job: ImportJob, **kwargs: object) -> ImportJob:
    """Return a new ImportJob with updated fields (immutable update)."""
    return replace(job, **kwargs)


def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
    """Call the progress callback if provided."""
    if callback is not None:
        callback(stage, job)
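A usage sketch for the orchestrator; the spec URL and the print-based callback are illustrative only:

```python
import asyncio

from app.openapi.importer import ImportOrchestrator


async def main() -> None:
    orchestrator = ImportOrchestrator()

    def on_progress(stage: str, job) -> None:
        # Called at every stage transition: fetching -> ... -> done/failed
        print(f"{job.job_id}: {stage} ({job.total_endpoints} endpoints)")

    job = await orchestrator.start_import(
        url="https://petstore3.swagger.io/api/v3/openapi.json",
        job_id="job-1",
        on_progress=on_progress,
    )
    print(job.status, job.classified_count)


asyncio.run(main())
```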
67
backend/app/openapi/models.py
Normal file
@@ -0,0 +1,67 @@
"""Data models for OpenAPI auto-discovery module.

Frozen dataclasses for all value objects to ensure immutability.
"""

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class ParameterInfo:
    """Describes a single endpoint parameter."""

    name: str
    location: str  # "path", "query", "header", "cookie"
    required: bool
    schema_type: str  # "string", "integer", "boolean", etc.
    description: str = ""


@dataclass(frozen=True)
class EndpointInfo:
    """Describes a single API endpoint."""

    path: str
    method: str  # uppercase: GET, POST, PUT, DELETE, PATCH
    operation_id: str
    summary: str
    description: str
    parameters: tuple[ParameterInfo, ...] = field(default_factory=tuple)
    request_body_schema: dict | None = None
    response_schema: dict | None = None


@dataclass(frozen=True)
class ClassificationResult:
    """Result of classifying an endpoint for agent routing."""

    endpoint: EndpointInfo
    access_type: str  # "read" or "write"
    customer_params: tuple[str, ...]  # param names that identify the customer
    agent_group: str  # which agent group handles this endpoint
    confidence: float  # 0.0 to 1.0
    needs_interrupt: bool  # requires human approval before execution


@dataclass(frozen=True)
class ImportJob:
    """Tracks the state of an OpenAPI import job."""

    job_id: str
    status: str  # "pending", "fetching", "validating", "parsing", "classifying", "done", "failed"
    spec_url: str
    total_endpoints: int = 0
    classified_count: int = 0
    error_message: str | None = None


@dataclass(frozen=True)
class GeneratedTool:
    """A generated LangChain tool from a classified endpoint."""

    function_name: str
    endpoint: EndpointInfo
    classification: ClassificationResult
    code: str
152
backend/app/openapi/parser.py
Normal file
@@ -0,0 +1,152 @@
"""OpenAPI spec endpoint parser.

Extracts all endpoint definitions from a parsed OpenAPI spec dict,
resolving $ref references from components.
"""

from __future__ import annotations

import re

from app.openapi.models import EndpointInfo, ParameterInfo

_HTTP_METHODS = ("get", "post", "put", "patch", "delete", "head", "options")


def parse_endpoints(spec_dict: dict) -> tuple[EndpointInfo, ...]:
    """Parse all endpoints from a validated OpenAPI spec dict.

    Returns an immutable tuple of EndpointInfo instances.
    """
    paths = spec_dict.get("paths", {})
    if not isinstance(paths, dict) or not paths:
        return ()

    endpoints: list[EndpointInfo] = []
    for path, path_item in paths.items():
        if not isinstance(path_item, dict):
            continue
        for method in _HTTP_METHODS:
            operation = path_item.get(method)
            if operation is None:
                continue
            endpoint = _parse_operation(path, method.upper(), operation, spec_dict)
            endpoints.append(endpoint)

    return tuple(endpoints)


def _parse_operation(
    path: str,
    method: str,
    operation: dict,
    spec_dict: dict,
) -> EndpointInfo:
    """Parse a single operation dict into an EndpointInfo."""
    operation_id = operation.get("operationId") or _generate_operation_id(path, method)
    summary = operation.get("summary", "")
    description = operation.get("description", "")

    parameters = _parse_parameters(operation.get("parameters", []), spec_dict)
    request_body_schema = _parse_request_body(operation.get("requestBody"), spec_dict)
    response_schema = _parse_response_schema(operation.get("responses", {}), spec_dict)

    return EndpointInfo(
        path=path,
        method=method,
        operation_id=operation_id,
        summary=summary,
        description=description,
        parameters=tuple(parameters),
        request_body_schema=request_body_schema,
        response_schema=response_schema,
    )


def _parse_parameters(
    params_list: list,
    spec_dict: dict,
) -> list[ParameterInfo]:
    """Parse list of parameter dicts into ParameterInfo instances."""
    result: list[ParameterInfo] = []
    for param in params_list:
        if not isinstance(param, dict):
            continue
        schema = param.get("schema", {})
        schema_type = schema.get("type", "string") if isinstance(schema, dict) else "string"
        result.append(
            ParameterInfo(
                name=param.get("name", ""),
                location=param.get("in", "query"),
                required=bool(param.get("required", False)),
                schema_type=schema_type,
                description=param.get("description", ""),
            )
        )
    return result


def _parse_request_body(request_body: dict | None, spec_dict: dict) -> dict | None:
    """Extract schema from requestBody, resolving $ref if present."""
    if not isinstance(request_body, dict):
        return None
    content = request_body.get("content", {})
    if not isinstance(content, dict):
        return None
    # Prefer application/json
    for media_type in ("application/json", *content.keys()):
        media = content.get(media_type)
        if not isinstance(media, dict):
            continue
        schema = media.get("schema")
        if schema:
            return _resolve_ref(schema, spec_dict)
    return None


def _parse_response_schema(responses: dict, spec_dict: dict) -> dict | None:
    """Extract schema from the first 2xx response."""
    if not isinstance(responses, dict):
        return None
    for status_code in ("200", "201", "202", "204"):
        response = responses.get(status_code)
        if not isinstance(response, dict):
            continue
        content = response.get("content", {})
        if not isinstance(content, dict):
            continue
        for media_type in ("application/json", *content.keys()):
            media = content.get(media_type)
            if not isinstance(media, dict):
                continue
            schema = media.get("schema")
            if schema:
                return _resolve_ref(schema, spec_dict)
    return None


def _resolve_ref(schema: object, spec_dict: dict) -> dict:
    """Resolve a $ref to its target schema, or return the schema as-is."""
    if not isinstance(schema, dict):
        return {}
    ref = schema.get("$ref")
    if not ref:
        return schema
    # Only handle local refs like #/components/schemas/Foo
    if not isinstance(ref, str) or not ref.startswith("#/"):
        return schema
    parts = ref.lstrip("#/").split("/")
    target: object = spec_dict
    for part in parts:
        if not isinstance(target, dict):
            return schema
        target = target.get(part)
    return target if isinstance(target, dict) else schema


def _generate_operation_id(path: str, method: str) -> str:
    """Generate a snake_case operation_id from path and method."""
    # Replace path-parameter segments like {id} with a placeholder,
    # then collapse remaining non-alphanumerics into underscores
    clean = re.sub(r"\{[^}]+\}", "by_param", path)
    clean = re.sub(r"[^a-zA-Z0-9]+", "_", clean).strip("_")
    return f"{method.lower()}_{clean}" if clean else method.lower()
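A behavior sketch for the parser against a minimal hand-written spec; the values are illustrative and exercise the local `$ref` resolution shown above:

```python
from app.openapi.parser import parse_endpoints

spec = {
    "openapi": "3.0.3",
    "info": {"title": "Demo", "version": "1.0"},
    "paths": {
        "/pets/{pet_id}": {
            "get": {
                "operationId": "get-pet",
                "parameters": [
                    {"name": "pet_id", "in": "path", "required": True,
                     "schema": {"type": "integer"}},
                ],
                "responses": {"200": {"content": {"application/json": {
                    "schema": {"$ref": "#/components/schemas/Pet"}}}}},
            }
        }
    },
    "components": {"schemas": {"Pet": {"type": "object"}}},
}

(ep,) = parse_endpoints(spec)
assert ep.method == "GET"
assert ep.parameters[0].schema_type == "integer"
assert ep.response_schema == {"type": "object"}  # local $ref resolved
```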
282
backend/app/openapi/review_api.py
Normal file
@@ -0,0 +1,282 @@
"""FastAPI router for OpenAPI import review workflow.

Exposes endpoints for:
- Starting an import job (triggers background pipeline)
- Querying job status
- Reviewing and editing classifications
- Approving a job to trigger tool generation
"""

from __future__ import annotations

import asyncio
import re
import uuid
from typing import Literal

import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, field_validator

from app.auth import require_admin_api_key
from app.openapi.generator import generate_agent_yaml, generate_tool_code
from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob

logger = structlog.get_logger()

router = APIRouter(
    prefix="/api/v1/openapi",
    tags=["openapi"],
    dependencies=[Depends(require_admin_api_key)],
)

# In-memory store: job_id -> job dict. Handlers mutate it on the single
# event loop; _store_lock is available for multi-step read-modify-write.
_job_store: dict[str, dict] = {}
_store_lock = asyncio.Lock()

# Shared orchestrator instance
_orchestrator = ImportOrchestrator()


# --- Request / Response schemas ---


class ImportRequest(BaseModel):
    url: str

    @field_validator("url")
    @classmethod
    def url_must_be_valid(cls, value: str) -> str:
        stripped = value.strip()
        if not stripped:
            raise ValueError("url must not be empty")
        if not stripped.startswith(("http://", "https://")):
            raise ValueError("url must start with http:// or https://")
        return stripped


class JobResponse(BaseModel):
    job_id: str
    status: str
    spec_url: str
    total_endpoints: int = 0
    classified_count: int = 0
    error_message: str | None = None


class ClassificationResponse(BaseModel):
    index: int
    access_type: str
    needs_interrupt: bool
    agent_group: str
    confidence: float
    customer_params: list[str]
    endpoint: dict


class UpdateClassificationRequest(BaseModel):
    access_type: Literal["read", "write"]
    needs_interrupt: bool
    agent_group: str

    @field_validator("agent_group")
    @classmethod
    def agent_group_must_be_safe(cls, value: str) -> str:
        if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
            raise ValueError(
                "agent_group must be non-empty and contain only "
                "alphanumeric characters, underscores, or hyphens"
            )
        return value


# --- Helpers ---


def _job_to_response(job: dict) -> dict:
    return {
        "job_id": job["job_id"],
        "status": job["status"],
        "spec_url": job["spec_url"],
        "total_endpoints": job.get("total_endpoints", 0),
        "classified_count": job.get("classified_count", 0),
        "error_message": job.get("error_message"),
    }


def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
    ep = clf.endpoint
    return {
        "index": idx,
        "access_type": clf.access_type,
        "needs_interrupt": clf.needs_interrupt,
        "agent_group": clf.agent_group,
        "confidence": clf.confidence,
        "customer_params": list(clf.customer_params),
        "endpoint": {
            "path": ep.path,
            "method": ep.method,
            "operation_id": ep.operation_id,
            "summary": ep.summary,
            "description": ep.description,
        },
    }


async def _run_import(job_id: str, url: str) -> None:
    """Run the import pipeline as a background task."""

    def on_progress(stage: str, result_job: ImportJob) -> None:
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result_job.status,
                "total_endpoints": result_job.total_endpoints,
                "classified_count": result_job.classified_count,
                "error_message": result_job.error_message,
            }

    try:
        result = await _orchestrator.start_import(
            url=url, job_id=job_id, on_progress=on_progress,
        )
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": result.status,
                "total_endpoints": result.total_endpoints,
                "classified_count": result.classified_count,
                "error_message": result.error_message,
            }
    except Exception:
        logger.exception("Background import failed for job %s", job_id)
        if job_id in _job_store:
            _job_store[job_id] = {
                **_job_store[job_id],
                "status": "failed",
                "error_message": "Import failed. Please check the URL and try again.",
            }


# --- Endpoints ---


@router.post("/import", status_code=202)
async def start_import(
    request: ImportRequest,
    background_tasks: BackgroundTasks,
) -> dict:
    """Start an OpenAPI import job for the given spec URL."""
    job_id = str(uuid.uuid4())
    job: dict = {
        "job_id": job_id,
        "status": "pending",
        "spec_url": request.url,
        "total_endpoints": 0,
        "classified_count": 0,
        "error_message": None,
        "classifications": [],
    }
    _job_store[job_id] = job
    background_tasks.add_task(_run_import, job_id, request.url)
    return _job_to_response(job)


@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
    """Get the status of an import job."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    return _job_to_response(job)


@router.get("/jobs/{job_id}/classifications")
async def get_classifications(job_id: str) -> list:
    """Get all classifications for an import job."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
    classifications: list[ClassificationResult] = job.get("classifications", [])
    return [
        _classification_to_response(i, clf)
        for i, clf in enumerate(classifications)
    ]


@router.put("/jobs/{job_id}/classifications/{idx}")
async def update_classification(
    job_id: str,
    idx: int,
    request: UpdateClassificationRequest,
) -> dict:
    """Update a specific classification by index."""
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    classifications: list[ClassificationResult] = job.get("classifications", [])
    if idx < 0 or idx >= len(classifications):
        raise HTTPException(
            status_code=404,
            detail=f"Classification index {idx} out of range",
        )

    original = classifications[idx]
    updated = ClassificationResult(
        endpoint=original.endpoint,
        access_type=request.access_type,
        customer_params=original.customer_params,
        agent_group=request.agent_group,
        confidence=original.confidence,
        needs_interrupt=request.needs_interrupt,
    )
    new_classifications = list(classifications)
    new_classifications[idx] = updated
    _job_store[job_id] = {**job, "classifications": new_classifications}

    return _classification_to_response(idx, updated)


@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
    """Approve a job's classifications and trigger tool generation.

    Generates Python tool code for each classified endpoint and
    produces an agent YAML configuration snippet.
    """
    job = _job_store.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")

    classifications: list[ClassificationResult] = job.get("classifications", [])
    if not classifications:
        raise HTTPException(
            status_code=400,
            detail="No classifications to approve. Import must complete first.",
        )

    base_url = job["spec_url"].rsplit("/", 1)[0]
    generated_tools = []
    for clf in classifications:
        tool = generate_tool_code(clf, base_url)
        generated_tools.append({
            "function_name": tool.function_name,
            "agent_group": clf.agent_group,
            "code": tool.code,
        })

    agent_yaml = generate_agent_yaml(tuple(classifications), base_url)

    updated_job = {
        **job,
        "status": "approved",
        "generated_tools": generated_tools,
        "agent_yaml": agent_yaml,
    }
    _job_store[job_id] = updated_job

    response = _job_to_response(updated_job)
    response["generated_tools_count"] = len(generated_tools)
    return response
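A client-side sketch of the review workflow end to end; the host, port, admin header name, and spec URL are all assumptions (the actual auth mechanism behind `require_admin_api_key` is not shown in this diff):

```python
import time

import httpx

ADMIN_HEADERS = {"X-Admin-API-Key": "changeme"}  # hypothetical header name/value
BASE = "http://localhost:8000/api/v1/openapi"    # local dev host assumed

with httpx.Client(headers=ADMIN_HEADERS, timeout=30) as client:
    job = client.post(f"{BASE}/import",
                      json={"url": "https://example.com/openapi.json"}).json()
    job_id = job["job_id"]

    # Poll the background pipeline until it settles
    while client.get(f"{BASE}/jobs/{job_id}").json()["status"] not in ("done", "failed"):
        time.sleep(0.5)

    # Flip the first endpoint to a gated write, then approve the job
    client.put(
        f"{BASE}/jobs/{job_id}/classifications/0",
        json={"access_type": "write", "needs_interrupt": True, "agent_group": "orders"},
    )
    print(client.post(f"{BASE}/jobs/{job_id}/approve").json())
```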
167
backend/app/openapi/ssrf.py
Normal file
@@ -0,0 +1,167 @@
"""SSRF protection module.

Validates URLs before making external HTTP requests.
Blocks private IPs, loopback, and link-local ranges, and re-validates
redirect targets to mitigate DNS-rebinding tricks.
"""

from __future__ import annotations

import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse

import httpx


class SSRFError(Exception):
    """Raised when a URL fails SSRF validation."""


@dataclass(frozen=True)
class SSRFPolicy:
    """Configuration for SSRF protection."""

    allowed_schemes: frozenset[str] = frozenset({"http", "https"})
    allowed_hosts: frozenset[str] | None = None  # None = all public hosts allowed
    max_redirects: int = 5
    timeout_seconds: float = 30.0


_BLOCKED_NETWORKS = (
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
    ipaddress.ip_network("127.0.0.0/8"),
    ipaddress.ip_network("169.254.0.0/16"),
    ipaddress.ip_network("0.0.0.0/32"),
    ipaddress.ip_network("100.64.0.0/10"),  # Carrier-Grade NAT
    ipaddress.ip_network("198.18.0.0/15"),  # Benchmarking
    ipaddress.ip_network("240.0.0.0/4"),  # Reserved
    ipaddress.ip_network("255.255.255.255/32"),  # Broadcast
    # IPv6
    ipaddress.ip_network("::1/128"),
    ipaddress.ip_network("fe80::/10"),
    ipaddress.ip_network("fc00::/7"),
    ipaddress.ip_network("::/128"),
    ipaddress.ip_network("::ffff:0:0/96"),  # IPv4-mapped IPv6
    ipaddress.ip_network("2001:db8::/32"),  # Documentation
)

DEFAULT_POLICY = SSRFPolicy()


def is_private_ip(ip_str: str) -> bool:
    """Check if an IP address is private/reserved."""
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # Invalid IP treated as blocked

    return any(addr in network for network in _BLOCKED_NETWORKS)


def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
    """Validate a URL against SSRF policy.

    Returns the validated URL string.
    Raises SSRFError if the URL is blocked.
    """
    parsed = urlparse(url)

    # Check scheme
    if parsed.scheme not in policy.allowed_schemes:
        raise SSRFError(
            f"URL scheme '{parsed.scheme}' is not allowed. "
            f"Allowed: {', '.join(sorted(policy.allowed_schemes))}"
        )

    # Check hostname exists
    hostname = parsed.hostname
    if not hostname:
        raise SSRFError("URL has no hostname")

    # Check allowed hosts whitelist
    if policy.allowed_hosts is not None and hostname not in policy.allowed_hosts:
        raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list")

    # DNS resolution -- resolve before making any request
    resolved_ips = resolve_hostname(hostname)
    if not resolved_ips:
        raise SSRFError(f"Could not resolve hostname '{hostname}'")

    # Check all resolved IPs against blocked networks
    for ip_str in resolved_ips:
        if is_private_ip(ip_str):
            raise SSRFError(
                f"Host '{hostname}' resolves to private/reserved IP {ip_str}. "
                "Request blocked for SSRF protection."
            )

    return url


def resolve_hostname(hostname: str) -> list[str]:
    """Resolve hostname to IP addresses via DNS."""
    try:
        results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
        return list({result[4][0] for result in results})
    except socket.gaierror:
        return []


async def safe_fetch(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    method: str = "GET",
    headers: dict[str, str] | None = None,
) -> httpx.Response:
    """Fetch a URL with SSRF protection.

    Validates the URL (scheme, host allowlist, DNS resolution against
    blocked networks) before the request, and re-validates every redirect
    target before following it. Note that DNS is resolved separately from
    the request itself; fully closing the rebinding window would require
    pinning the connection to the validated IP at the transport layer.
    """
    validate_url(url, policy)

    # Make the request with redirect following disabled so we can check each hop
    async with httpx.AsyncClient(
        follow_redirects=False,
        timeout=httpx.Timeout(policy.timeout_seconds),
    ) as client:
        current_url = url
        for _redirect_count in range(policy.max_redirects + 1):
            response = await client.request(
                method,
                current_url,
                headers=headers,
            )

            if response.is_redirect:
                redirect_url = str(response.next_request.url) if response.next_request else None
                if not redirect_url:
                    raise SSRFError("Redirect with no target URL")
                # Validate the redirect target
                validate_url(redirect_url, policy)
                current_url = redirect_url
                continue

            return response

        raise SSRFError(
            f"Too many redirects (max {policy.max_redirects}). "
            "Possible redirect loop or evasion attempt."
        )


async def safe_fetch_text(
    url: str,
    *,
    policy: SSRFPolicy = DEFAULT_POLICY,
    headers: dict[str, str] | None = None,
) -> str:
    """Fetch a URL and return text content with SSRF protection."""
    response = await safe_fetch(url, policy=policy, headers=headers)
    response.raise_for_status()
    return response.text
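A behavior sketch for the guards; the hostnames and addresses are illustrative:

```python
from app.openapi.ssrf import SSRFError, SSRFPolicy, is_private_ip, validate_url

assert is_private_ip("10.1.2.3")         # RFC 1918
assert is_private_ip("169.254.169.254")  # link-local (cloud metadata endpoint)
assert not is_private_ip("93.184.216.34")

try:
    validate_url("http://127.0.0.1:8080/spec.json")
except SSRFError as exc:
    print(exc)  # resolves to a private/reserved IP -> blocked

# Optional allowlist: only these hosts pass; everything else is rejected
strict = SSRFPolicy(allowed_hosts=frozenset({"api.example.com"}))
try:
    validate_url("https://evil.example.net/spec.json", strict)
except SSRFError:
    pass  # host not in the allowed hosts list
```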
51
backend/app/openapi/validator.py
Normal file
@@ -0,0 +1,51 @@
"""OpenAPI spec validator.

Validates an OpenAPI spec dict for required fields and basic structural
correctness. Returns a list of human-readable error strings.
"""

from __future__ import annotations

_SUPPORTED_VERSIONS = ("3.0.", "3.1.")
_REQUIRED_FIELDS = ("openapi", "info", "paths")


def validate_spec(spec_dict: object) -> list[str]:
    """Validate an OpenAPI spec dict.

    Returns a list of error strings. An empty list means the spec is valid.
    Does not raise; all errors are captured and returned.
    """
    if not isinstance(spec_dict, dict):
        return [f"Spec must be a dict, got {type(spec_dict).__name__}"]

    errors: list[str] = []

    # Check required top-level fields
    for field in _REQUIRED_FIELDS:
        if field not in spec_dict:
            errors.append(f"Missing required field: '{field}'")

    # Validate openapi version if present
    if "openapi" in spec_dict:
        version = spec_dict["openapi"]
        if not isinstance(version, str):
            errors.append(f"'openapi' must be a string, got {type(version).__name__}")
        elif not any(version.startswith(prefix) for prefix in _SUPPORTED_VERSIONS):
            errors.append(
                f"Unsupported OpenAPI version '{version}'. "
                f"Supported versions start with: {', '.join(_SUPPORTED_VERSIONS)}"
            )

    # Validate info object if present
    if "info" in spec_dict and isinstance(spec_dict["info"], dict):
        info = spec_dict["info"]
        for sub_field in ("title", "version"):
            if sub_field not in info:
                errors.append(f"Missing required field in 'info': '{sub_field}'")

    # Validate paths object if present
    if "paths" in spec_dict and not isinstance(spec_dict["paths"], dict):
        errors.append(f"'paths' must be a dict, got {type(spec_dict['paths']).__name__}")

    return errors
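A quick sketch of the error-list contract; input dicts are illustrative:

```python
from app.openapi.validator import validate_spec

# A minimal valid 3.1 spec produces no errors
assert validate_spec({
    "openapi": "3.1.0",
    "info": {"title": "T", "version": "1"},
    "paths": {},
}) == []

# A Swagger 2.0 spec with an empty info and list-typed paths
# collects every problem instead of raising on the first one
errors = validate_spec({"openapi": "2.0", "info": {}, "paths": []})
for e in errors:
    print(e)
# Unsupported OpenAPI version '2.0'. ...
# Missing required field in 'info': 'title'
# Missing required field in 'info': 'version'
# 'paths' must be a dict, got list
```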
@@ -100,5 +100,41 @@ class AgentRegistry:
    def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
        return tuple(a for a in self._agents.values() if a.permission == permission)

    @classmethod
    def load_template(
        cls,
        template_name: str,
        templates_dir: str | Path | None = None,
    ) -> AgentRegistry:
        """Load agent configurations from a named template."""
        if templates_dir is None:
            templates_dir = Path(__file__).parent.parent / "templates"
        templates_dir = Path(templates_dir)

        yaml_path = templates_dir / f"{template_name}.yaml"
        if not yaml_path.exists():
            available = cls.list_templates(templates_dir)
            raise FileNotFoundError(
                f"Template '{template_name}' not found. "
                f"Available: {', '.join(available) if available else 'none'}"
            )
        return cls.load(yaml_path)

    @classmethod
    def list_templates(
        cls,
        templates_dir: str | Path | None = None,
    ) -> tuple[str, ...]:
        """List available template names from the templates directory."""
        if templates_dir is None:
            templates_dir = Path(__file__).parent.parent / "templates"
        templates_dir = Path(templates_dir)

        if not templates_dir.is_dir():
            return ()
        return tuple(
            sorted(p.stem for p in templates_dir.glob("*.yaml"))
        )

    def __len__(self) -> int:
        return len(self._agents)
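A usage sketch for template loading; the registry's module path and the presence of an `ecommerce.yaml` template are assumptions:

```python
from app.agent_registry import AgentRegistry  # module path assumed

print(AgentRegistry.list_templates())   # e.g. ('ecommerce', 'generic', 'saas')

registry = AgentRegistry.load_template("ecommerce")
print(len(registry), "agents loaded")

# Unknown names fail fast with the available options listed
try:
    AgentRegistry.load_template("does-not-exist")
except FileNotFoundError as exc:
    print(exc)
```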
3
backend/app/replay/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Replay module -- conversation replay API and transformer."""

from __future__ import annotations
125
backend/app/replay/api.py
Normal file
@@ -0,0 +1,125 @@
"""Replay API router -- conversation listing and step-by-step replay."""

from __future__ import annotations

import re
from typing import TYPE_CHECKING, Annotated

from fastapi import APIRouter, Depends, HTTPException, Query, Request

from app.api_utils import envelope
from app.auth import require_admin_api_key

_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

if TYPE_CHECKING:
    from psycopg_pool import AsyncConnectionPool

router = APIRouter(
    prefix="/api/v1",
    tags=["replay"],
    dependencies=[Depends(require_admin_api_key)],
)

_COUNT_CONVERSATIONS_SQL = """
SELECT COUNT(*) FROM conversations
"""

_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
FROM conversations
ORDER BY last_activity DESC
LIMIT %(limit)s OFFSET %(offset)s
"""

_GET_CHECKPOINTS_SQL = """
SELECT thread_id, checkpoint_id, checkpoint, metadata
FROM checkpoints
WHERE thread_id = %(thread_id)s
ORDER BY checkpoint_id ASC
"""


async def get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state."""
    return request.app.state.pool


@router.get("/conversations")
async def list_conversations(
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """List conversations with pagination."""
    pool = await get_pool(request)
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
        count_row = await count_cursor.fetchone()
        total = count_row[0] if count_row else 0

        cursor = await conn.execute(
            _LIST_CONVERSATIONS_SQL,
            {"limit": per_page, "offset": offset},
        )
        rows = await cursor.fetchall()

    return envelope({
        "conversations": [dict(row) for row in rows],
        "total": total,
        "page": page,
        "per_page": per_page,
    })


@router.get("/replay/{thread_id}")
async def get_replay(
    thread_id: str,
    request: Request,
    page: Annotated[int, Query(ge=1)] = 1,
    per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
    """Return paginated replay steps for a conversation thread."""
    from app.replay.transformer import transform_checkpoints

    if not _THREAD_ID_PATTERN.match(thread_id):
        raise HTTPException(status_code=400, detail="Invalid thread_id format")

    pool = await get_pool(request)
    async with pool.connection() as conn:
        cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
        rows = await cursor.fetchall()

    if not rows:
        raise HTTPException(status_code=404, detail="Thread not found")

    all_steps = transform_checkpoints([dict(row) for row in rows])
    total_steps = len(all_steps)
    start = (page - 1) * per_page
    end = start + per_page
    page_steps = all_steps[start:end]

    data = {
        "thread_id": thread_id,
        "total_steps": total_steps,
        "page": page,
        "per_page": per_page,
        "steps": [
            {
                "step": s.step,
                "type": s.type.value,
                "timestamp": s.timestamp,
                "content": s.content,
                "agent": s.agent,
                "tool": s.tool,
                "params": s.params,
                "result": s.result,
                "reasoning": s.reasoning,
                "tokens": s.tokens,
                "duration_ms": s.duration_ms,
            }
            for s in page_steps
        ],
    }
    return envelope(data)
52
backend/app/replay/models.py
Normal file
@@ -0,0 +1,52 @@
"""Value objects for conversation replay."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class StepType(str, Enum):
    """Types of steps in a conversation replay."""

    user_message = "user_message"
    supervisor_routing = "supervisor_routing"
    tool_call = "tool_call"
    tool_result = "tool_result"
    agent_response = "agent_response"
    interrupt = "interrupt"


@dataclass(frozen=True)
class ReplayStep:
    """A single step in a conversation replay."""

    step: int
    type: StepType
    timestamp: str
    content: str = ""
    agent: str | None = None
    tool: str | None = None
    params: dict | None = None
    result: dict | None = None
    reasoning: str | None = None
    tokens: int | None = None
    duration_ms: int | None = None

    def __post_init__(self) -> None:
        # Store defensive (shallow) copies so a caller mutating the dicts
        # it passed in cannot change this step after construction
        if self.params is not None:
            object.__setattr__(self, "params", dict(self.params))
        if self.result is not None:
            object.__setattr__(self, "result", dict(self.result))


@dataclass(frozen=True)
class ReplayPage:
    """A paginated page of replay steps for a conversation thread."""

    thread_id: str
    total_steps: int
    page: int
    per_page: int
    steps: tuple[ReplayStep, ...]
116
backend/app/replay/transformer.py
Normal file
@@ -0,0 +1,116 @@
"""Transforms PostgresSaver checkpoint rows into a ReplayStep list."""

from __future__ import annotations

import json

import structlog

from app.replay.models import ReplayStep, StepType

logger = structlog.get_logger()

_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"


def _extract_messages(row: dict) -> list[dict]:
    """Safely extract messages list from a checkpoint row."""
    checkpoint = row.get("checkpoint")
    if not checkpoint or not isinstance(checkpoint, dict):
        return []
    channel_values = checkpoint.get("channel_values")
    if not channel_values or not isinstance(channel_values, dict):
        return []
    messages = channel_values.get("messages")
    if not messages or not isinstance(messages, list):
        return []
    return messages


def _step_from_message(msg: dict, step_number: int) -> ReplayStep | None:
    """Convert a single message dict to a ReplayStep. Returns None for unknown types."""
    msg_type = msg.get("type", "")
    timestamp = msg.get("created_at") or _EMPTY_TIMESTAMP
    content = msg.get("content") or ""
    if isinstance(content, list):
        # LangChain may encode content as a list of parts
        content = " ".join(
            part.get("text", "") if isinstance(part, dict) else str(part)
            for part in content
        )

    if msg_type == "human":
        return ReplayStep(
            step=step_number,
            type=StepType.user_message,
            timestamp=timestamp,
            content=content,
        )

    if msg_type == "ai":
        tool_calls = msg.get("tool_calls") or []
        if tool_calls:
            first = tool_calls[0]
            return ReplayStep(
                step=step_number,
                type=StepType.tool_call,
                timestamp=timestamp,
                content=content,
                tool=first.get("name"),
                params=dict(first.get("args") or {}),
            )
        return ReplayStep(
            step=step_number,
            type=StepType.agent_response,
            timestamp=timestamp,
            content=content,
            agent=msg.get("name"),
        )

    if msg_type == "tool":
        raw = content
        result: dict | None = None
        try:
            result = json.loads(raw)
        except (ValueError, TypeError):
            result = {"raw": raw}
        return ReplayStep(
            step=step_number,
            type=StepType.tool_result,
            timestamp=timestamp,
            tool=msg.get("name"),
            result=result,
        )

    logger.debug("Skipping unknown message type: %s", msg_type)
    return None


def transform_checkpoints(rows: list[dict]) -> list[ReplayStep]:
    """Transform a list of checkpoint rows into an ordered list of ReplaySteps.

    Steps are numbered sequentially starting from 1 across all rows.
    Unknown or malformed messages are silently skipped.
    """
    steps: list[ReplayStep] = []
    step_number = 1

    for row in rows:
        try:
            messages = _extract_messages(row)
        except Exception:  # noqa: BLE001
            logger.exception("Error extracting messages from checkpoint row")
            continue

        for msg in messages:
            try:
                step = _step_from_message(msg, step_number)
            except Exception:  # noqa: BLE001
                logger.exception("Error converting message to ReplayStep")
                step = None

            if step is not None:
                steps.append(step)
                step_number += 1

    return steps
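A transform sketch for one checkpoint row carrying a human message, a tool call, and a tool result; the message payloads are illustrative:

```python
from app.replay.transformer import transform_checkpoints

row = {
    "checkpoint": {
        "channel_values": {
            "messages": [
                {"type": "human", "content": "Where is my order?"},
                {"type": "ai", "content": "",
                 "tool_calls": [{"name": "get_order", "args": {"order_id": "42"}}]},
                {"type": "tool", "name": "get_order",
                 "content": '{"status": "shipped"}'},
            ]
        }
    }
}

for step in transform_checkpoints([row]):
    print(step.step, step.type.value, step.tool or step.content)
# 1 user_message Where is my order?
# 2 tool_call get_order
# 3 tool_result get_order
```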
131
backend/app/safety.py
Normal file
@@ -0,0 +1,131 @@
"""Safety policy for destructive-action confirmation rules.

This module makes the confirmation rules explicit and auditable. Every tool
call passes through ``requires_confirmation`` before execution to decide
whether human-in-the-loop approval is needed.

Policy summary
--------------
- ``read`` actions: execute immediately, no confirmation required.
- ``write`` actions: require human approval via interrupt gate.
- OpenAPI-imported endpoints: use ``needs_interrupt`` from classification.
- If both the agent permission AND the endpoint classification agree that
  the action is read-only, it executes without confirmation.

Multi-intent semantics
----------------------
When a user message contains multiple intents (e.g. "cancel my order and
apply a refund"), the supervisor routes them sequentially. Each action is
evaluated independently:
- If a write action is blocked by an interrupt, subsequent actions in the
  same message are paused until the interrupt is resolved.
- Read actions that follow a blocked write are also paused (sequential,
  not best-effort) to preserve causal ordering.
- If an interrupt is rejected, the remaining actions are skipped and the
  agent informs the user.

MCP error taxonomy
------------------
Tool execution errors are classified into categories for retry decisions:

- ``transient``: network timeouts, rate limits, 5xx -- retryable up to 3 times.
- ``validation``: bad parameters, 4xx -- not retryable, report to user.
- ``auth``: 401/403 -- not retryable, escalate.
- ``unknown``: unclassified -- not retryable, log and escalate.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class ConfirmationPolicy:
    """Result of evaluating whether an action needs confirmation."""

    requires_confirmation: bool
    reason: str


def requires_confirmation(
    *,
    agent_permission: Literal["read", "write"],
    needs_interrupt: bool | None = None,
) -> ConfirmationPolicy:
    """Determine whether an action requires human confirmation.

    Parameters
    ----------
    agent_permission:
        The permission level of the agent executing the action.
    needs_interrupt:
        Override from OpenAPI classification. When ``None``, the decision
        is based solely on ``agent_permission``.
    """
    if needs_interrupt is not None:
        if needs_interrupt:
            return ConfirmationPolicy(
                requires_confirmation=True,
                reason="Endpoint classified as requiring human approval",
            )
        return ConfirmationPolicy(
            requires_confirmation=False,
            reason="Endpoint classified as safe (no interrupt needed)",
        )

    if agent_permission == "write":
        return ConfirmationPolicy(
            requires_confirmation=True,
            reason="Write-permission agent actions require confirmation",
        )

    return ConfirmationPolicy(
        requires_confirmation=False,
        reason="Read-only agent actions execute immediately",
    )


# --- MCP Error Taxonomy ---


MCP_ERROR_CATEGORY = Literal["transient", "validation", "auth", "unknown"]

_TRANSIENT_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
_AUTH_STATUS_CODES = frozenset({401, 403})
_MAX_RETRIES = 3


def classify_mcp_error(
    *,
    status_code: int | None = None,
    error_message: str = "",
) -> MCP_ERROR_CATEGORY:
    """Classify an MCP tool error for retry decisions."""
    if status_code is not None:
        if status_code in _TRANSIENT_STATUS_CODES:
            return "transient"
        if status_code in _AUTH_STATUS_CODES:
            return "auth"
        if 400 <= status_code < 500:
            return "validation"

    lower_msg = error_message.lower()
    if any(kw in lower_msg for kw in ("timeout", "timed out", "rate limit")):
        return "transient"
    if any(kw in lower_msg for kw in ("unauthorized", "forbidden")):
        return "auth"
    if any(kw in lower_msg for kw in ("invalid", "missing", "bad request")):
        return "validation"

    return "unknown"


def is_retryable(category: MCP_ERROR_CATEGORY) -> bool:
    """Return whether a given error category is retryable."""
    return category == "transient"


def max_retries() -> int:
    """Maximum retry attempts for transient errors."""
    return _MAX_RETRIES
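A policy sketch showing how the confirmation gate and the error taxonomy compose; all inputs are illustrative:

```python
from app.safety import classify_mcp_error, is_retryable, requires_confirmation

# Write-permission agent, no classifier override -> needs approval
policy = requires_confirmation(agent_permission="write")
assert policy.requires_confirmation

# An explicit OpenAPI classification overrides the agent permission
policy = requires_confirmation(agent_permission="write", needs_interrupt=False)
assert not policy.requires_confirmation

assert classify_mcp_error(status_code=503) == "transient"
assert classify_mcp_error(status_code=401) == "auth"
assert classify_mcp_error(error_message="rate limit exceeded") == "transient"
assert is_retryable("transient") and not is_retryable("validation")
```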
@@ -1,9 +1,18 @@
-"""Session TTL management with sliding window and interrupt extension."""
+"""Session TTL management with sliding window and interrupt extension.
+
+Provides both in-memory (SessionManager) and PostgreSQL-backed
+(PgSessionManager) implementations behind a common Protocol.
+"""
 
 from __future__ import annotations
 
 import time
 from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Protocol
+
+if TYPE_CHECKING:
+    from psycopg_pool import AsyncConnectionPool
 
 
 @dataclass(frozen=True)
@@ -13,8 +22,19 @@ class SessionState:
     has_pending_interrupt: bool
 
 
+class SessionManagerProtocol(Protocol):
+    """Protocol for session TTL management."""
+
+    def touch(self, thread_id: str) -> SessionState: ...
+    def is_expired(self, thread_id: str) -> bool: ...
+    def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
+    def resolve_interrupt(self, thread_id: str) -> SessionState: ...
+    def get_state(self, thread_id: str) -> SessionState | None: ...
+    def remove(self, thread_id: str) -> None: ...
+
+
 class SessionManager:
-    """Manages session TTL with sliding window and interrupt extensions.
+    """In-memory session manager for single-worker development.
 
     - Each message resets the TTL (sliding window).
     - A pending interrupt suspends expiration until resolved.
@@ -40,10 +60,8 @@ class SessionManager:
         state = self._sessions.get(thread_id)
         if state is None:
             return True
-
         if state.has_pending_interrupt:
             return False
-
         elapsed = time.time() - state.last_activity
         return elapsed > self._session_ttl
 
@@ -52,7 +70,6 @@ class SessionManager:
         existing = self._sessions.get(thread_id)
         if existing is None:
             return self.touch(thread_id)
-
         new_state = SessionState(
             thread_id=thread_id,
             last_activity=existing.last_activity,
@@ -76,3 +93,120 @@ class SessionManager:
 
     def remove(self, thread_id: str) -> None:
         self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
+
+
+# Alias for explicit naming
+InMemorySessionManager = SessionManager
+
+
+class PgSessionManager:
+    """PostgreSQL-backed session manager for multi-worker production.
+
+    Note: the sync wrapper methods drive their async counterparts via
+    asyncio.get_event_loop().run_until_complete, which raises RuntimeError
+    when called from inside an already-running event loop; call them from
+    synchronous contexts only.
+    """
+
+    def __init__(
+        self,
+        pool: AsyncConnectionPool,
+        session_ttl_seconds: int = 1800,
+    ) -> None:
+        self._pool = pool
+        self._session_ttl = session_ttl_seconds
+
+    def touch(self, thread_id: str) -> SessionState:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
+
+    async def _touch(self, thread_id: str) -> SessionState:
+        now = datetime.now(timezone.utc)
+        async with self._pool.connection() as conn:
+            await conn.execute(
+                """
+                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
+                VALUES (%(tid)s, %(now)s, FALSE)
+                ON CONFLICT (thread_id) DO UPDATE
+                SET last_activity = %(now)s
+                """,
+                {"tid": thread_id, "now": now},
+            )
+        return SessionState(
+            thread_id=thread_id,
+            last_activity=now.timestamp(),
+            has_pending_interrupt=False,
+        )
+
+    def is_expired(self, thread_id: str) -> bool:
+        state = self.get_state(thread_id)
+        if state is None:
+            return True
+        if state.has_pending_interrupt:
+            return False
+        elapsed = time.time() - state.last_activity
+        return elapsed > self._session_ttl
+
+    def extend_for_interrupt(self, thread_id: str) -> SessionState:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(
+            self._set_interrupt(thread_id, True)
+        )
+
+    def resolve_interrupt(self, thread_id: str) -> SessionState:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(
+            self._set_interrupt(thread_id, False)
+        )
+
+    async def _set_interrupt(
+        self, thread_id: str, has_interrupt: bool
+    ) -> SessionState:
+        now = datetime.now(timezone.utc)
+        async with self._pool.connection() as conn:
+            await conn.execute(
+                """
+                INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
+                VALUES (%(tid)s, %(now)s, %(interrupt)s)
+                ON CONFLICT (thread_id) DO UPDATE
+                SET last_activity = %(now)s,
+                    has_pending_interrupt = %(interrupt)s
+                """,
+                {"tid": thread_id, "now": now, "interrupt": has_interrupt},
+            )
+        return SessionState(
+            thread_id=thread_id,
+            last_activity=now.timestamp(),
+            has_pending_interrupt=has_interrupt,
+        )
+
+    def get_state(self, thread_id: str) -> SessionState | None:
+        import asyncio
+
+        return asyncio.get_event_loop().run_until_complete(
+            self._get_state(thread_id)
+        )
+
+    async def _get_state(self, thread_id: str) -> SessionState | None:
+        async with self._pool.connection() as conn:
+            cursor = await conn.execute(
+                "SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
+                {"tid": thread_id},
+            )
+            row = await cursor.fetchone()
+            if row is None:
+                return None
+            return SessionState(
+                thread_id=thread_id,
+                last_activity=row["last_activity"].timestamp(),
+                has_pending_interrupt=row["has_pending_interrupt"],
+            )
+
+    def remove(self, thread_id: str) -> None:
+        import asyncio
+
+        asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
+
+    async def _remove(self, thread_id: str) -> None:
+        async with self._pool.connection() as conn:
+            await conn.execute(
+                "DELETE FROM sessions WHERE thread_id = %(tid)s",
+                {"tid": thread_id},
+            )
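A sliding-window sketch for the in-memory manager with the TTL shortened for the demo; the `session_ttl_seconds` constructor argument is assumed to match the PostgreSQL variant's:

```python
import time

from app.session_manager import SessionManager

manager = SessionManager(session_ttl_seconds=1)  # constructor kwarg assumed
manager.touch("thread-1")
assert not manager.is_expired("thread-1")

manager.extend_for_interrupt("thread-1")
time.sleep(1.2)
assert not manager.is_expired("thread-1")  # pending interrupt suspends expiry

manager.resolve_interrupt("thread-1")
time.sleep(1.2)
assert manager.is_expired("thread-1")      # sliding window resumed
```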
3
backend/app/tools/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""Tools package for smart-support backend."""

from __future__ import annotations
72
backend/app/tools/error_handler.py
Normal file
@@ -0,0 +1,72 @@
"""Error classification and retry logic for tool calls."""

from __future__ import annotations

import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any

import httpx

if TYPE_CHECKING:
    from collections.abc import Callable


class ErrorCategory(Enum):
    """Categories for error classification to guide retry decisions."""

    RETRYABLE = "retryable"
    NON_RETRYABLE = "non_retryable"
    AUTH_FAILURE = "auth_failure"
    TIMEOUT = "timeout"
    NETWORK = "network"


def classify_error(exc: Exception) -> ErrorCategory:
    """Classify an exception into an ErrorCategory.

    Rules:
    - httpx.TimeoutException -> TIMEOUT
    - httpx.ConnectError -> NETWORK
    - httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
    - httpx.HTTPStatusError 429/500/502/503 -> RETRYABLE
    - anything else -> NON_RETRYABLE
    """
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        if code in (429, 500, 502, 503):
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    return ErrorCategory.NON_RETRYABLE


async def with_retry(
    fn: Callable[..., Any],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> Any:
    """Execute an async callable with exponential backoff for RETRYABLE errors.

    Only ErrorCategory.RETRYABLE errors trigger retries. All other error
    categories raise immediately after the first attempt.
    """
    last_exc: Exception | None = None
    for attempt in range(1, max_retries + 1):
        try:
            return await fn()
        except Exception as exc:
            category = classify_error(exc)
            if category != ErrorCategory.RETRYABLE:
                raise
            last_exc = exc
            if attempt < max_retries:
                delay = base_delay * (2 ** (attempt - 1))
                await asyncio.sleep(delay)

    raise last_exc  # type: ignore[misc]
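A retry sketch showing that only RETRYABLE categories are retried with backoff; the target URL is illustrative:

```python
import asyncio

import httpx

from app.tools.error_handler import with_retry


async def fetch_status() -> int:
    async with httpx.AsyncClient() as client:
        resp = await client.get("https://httpbin.org/status/503")
        resp.raise_for_status()  # 503 -> HTTPStatusError -> classified RETRYABLE
        return resp.status_code


async def main() -> None:
    try:
        await with_retry(fetch_status, max_retries=3, base_delay=0.5)
    except httpx.HTTPStatusError as exc:
        # A 401 would have raised immediately; 503 was retried three times
        print("gave up after retries:", exc.response.status_code)


asyncio.run(main())
```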
30
backend/app/ws_context.py
Normal file
@@ -0,0 +1,30 @@
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from app.analytics.event_recorder import AnalyticsRecorder
    from app.callbacks import TokenUsageCallbackHandler
    from app.conversation_tracker import ConversationTrackerProtocol
    from app.graph_context import GraphContext
    from app.interrupt_manager import InterruptManager
    from app.session_manager import SessionManager


@dataclass(frozen=True)
class WebSocketContext:
    """All dependencies required for WebSocket message processing.

    Replaces the previous 9-parameter function signature in dispatch_message.
    """

    graph_ctx: GraphContext
    session_manager: SessionManager
    callback_handler: TokenUsageCallbackHandler
    interrupt_manager: InterruptManager | None = None
    analytics_recorder: AnalyticsRecorder | None = None
    conversation_tracker: ConversationTrackerProtocol | None = None
    pool: Any = None
@@ -3,47 +3,98 @@
from __future__ import annotations

import json
import logging
import re
from typing import TYPE_CHECKING, Any
import time
from collections import defaultdict
from typing import TYPE_CHECKING

from langchain_core.messages import HumanMessage
from langgraph.types import Command

if TYPE_CHECKING:
    from fastapi import WebSocket
    from langgraph.graph.state import CompiledStateGraph

    from app.callbacks import TokenUsageCallbackHandler
    from app.graph_context import GraphContext
    from app.interrupt_manager import InterruptManager
    from app.session_manager import SessionManager
    from app.ws_context import WebSocketContext

logger = logging.getLogger(__name__)
import structlog

logger = structlog.get_logger()

MAX_MESSAGE_SIZE = 32_768  # 32 KB
MAX_CONTENT_LENGTH = 8_000  # characters
MAX_CONTENT_LENGTH = 10_000  # characters
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")

# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_MAX_TRACKED_THREADS = 10_000
_thread_timestamps: dict[str, list[float]] = defaultdict(list)


def _evict_stale_threads(cutoff: float) -> None:
    """Remove thread entries with no recent timestamps to prevent memory leak."""
    stale = [tid for tid, ts in _thread_timestamps.items() if not ts or ts[-1] < cutoff]
    for tid in stale:
        del _thread_timestamps[tid]

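A self-contained sketch of the sliding-window check these module-level structures support (the same algorithm as the check inside dispatch_message further down, minus the WebSocket plumbing; the function name allow() is illustrative, not from the source):

import time
from collections import defaultdict

_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_thread_timestamps: dict[str, list[float]] = defaultdict(list)


def allow(thread_id: str) -> bool:
    """Return True if this thread may send another message right now."""
    now = time.time()
    cutoff = now - _RATE_LIMIT_WINDOW
    # Keep only timestamps inside the window, then count them.
    recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
    if len(recent) >= _RATE_LIMIT_MAX:
        return False
    _thread_timestamps[thread_id] = [*recent, now]
    return True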
async def handle_user_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    content: str,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Process a user message through the graph and stream results back."""
    if session_manager.is_expired(thread_id):
    existing = session_manager.get_state(thread_id)
    if existing is not None and session_manager.is_expired(thread_id):
        msg = "Session expired. Please start a new conversation."
        await _send_json(ws, {"type": "error", "message": msg})
        return

    session_manager.touch(thread_id)

    classification = await ctx.classify_intent(content)
    if classification is not None:
        logger.info(
            "Intent classification for thread %s: ambiguous=%s, intents=%s",
            thread_id,
            classification.is_ambiguous,
            [i.agent_name for i in classification.intents],
        )

        if classification.is_ambiguous and classification.clarification_question:
            await _send_json(
                ws,
                {
                    "type": "clarification",
                    "thread_id": thread_id,
                    "message": classification.clarification_question,
                },
            )
            await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
            return

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
    input_msg = {"messages": [HumanMessage(content=content)]}

    if classification and len(classification.intents) > 1:
        agent_names = [i.agent_name for i in classification.intents]
        hint = (
            f"\n[System: This request involves multiple actions. "
            f"Execute in order: {', '.join(agent_names)}]"
        )
        input_msg = {"messages": [HumanMessage(content=content + hint)]}
    else:
        input_msg = {"messages": [HumanMessage(content=content)]}

    try:
        async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
        async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
            msg_chunk, metadata = chunk
            node = metadata.get("langgraph_node", "")

@@ -68,10 +119,18 @@ async def handle_user_message(
                },
            )

    state = await graph.aget_state(config)
    state = await ctx.graph.aget_state(config)
    if _has_interrupt(state):
        interrupt_data = _extract_interrupt(state)
        session_manager.extend_for_interrupt(thread_id)

        if interrupt_manager is not None:
            interrupt_manager.register(
                thread_id=thread_id,
                action=interrupt_data.get("action", "unknown"),
                params=interrupt_data.get("params", {}),
            )

        await _send_json(
            ws,
            {
@@ -91,20 +150,32 @@

async def handle_interrupt_response(
    ws: WebSocket,
    graph: CompiledStateGraph,
    ctx: GraphContext,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    thread_id: str,
    approved: bool,
    interrupt_manager: InterruptManager | None = None,
) -> None:
    """Resume graph execution after interrupt approval/rejection."""
    if interrupt_manager is not None:
        status = interrupt_manager.check_status(thread_id)
        if status is not None and status.is_expired:
            retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
            interrupt_manager.resolve(thread_id)
            session_manager.resolve_interrupt(thread_id)
            await _send_json(ws, retry_prompt)
            return

        interrupt_manager.resolve(thread_id)

    session_manager.resolve_interrupt(thread_id)
    session_manager.touch(thread_id)

    config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}

    try:
        async for chunk in graph.astream(
        async for chunk in ctx.graph.astream(
            Command(resume=approved),
            config=config,
            stream_mode="messages",
@@ -132,9 +203,7 @@ async def handle_interrupt_response(

async def dispatch_message(
    ws: WebSocket,
    graph: CompiledStateGraph,
    session_manager: SessionManager,
    callback_handler: TokenUsageCallbackHandler,
    ctx: WebSocketContext,
    raw_data: str,
) -> None:
    """Parse and route an incoming WebSocket message."""
@@ -144,10 +213,14 @@ async def dispatch_message(

    try:
        data = json.loads(raw_data)
    except json.JSONDecodeError:
    except (json.JSONDecodeError, ValueError):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
        return

    if not isinstance(data, dict):
        await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
        return

    msg_type = data.get("type")
    thread_id = data.get("thread_id", "")

@@ -161,24 +234,81 @@ async def dispatch_message(

    if msg_type == "message":
        content = data.get("content", "")
        if not content:
        if not content or not content.strip():
            await _send_json(ws, {"type": "error", "message": "Missing message content"})
            return
        if len(content) > MAX_CONTENT_LENGTH:
            await _send_json(ws, {"type": "error", "message": "Message content too long"})
            return
        await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)

        # Rate limiting check (per-thread, with bounded memory)
        now = time.time()
        cutoff = now - _RATE_LIMIT_WINDOW
        if len(_thread_timestamps) > _MAX_TRACKED_THREADS:
            _evict_stale_threads(cutoff)
        recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
        if len(recent) >= _RATE_LIMIT_MAX:
            await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
            return
        _thread_timestamps[thread_id] = [*recent, now]

        await handle_user_message(
            ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
            thread_id, content,
            interrupt_manager=ctx.interrupt_manager,
        )
        await _fire_and_forget_tracking(
            thread_id=thread_id,
            pool=ctx.pool,
            analytics_recorder=ctx.analytics_recorder,
            conversation_tracker=ctx.conversation_tracker,
            agent_name=None,
            tokens=0,
            cost=0.0,
        )

    elif msg_type == "interrupt_response":
        approved = data.get("approved", False)
        await handle_interrupt_response(
            ws, graph, session_manager, callback_handler, thread_id, approved
            ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
            thread_id, approved,
            interrupt_manager=ctx.interrupt_manager,
        )

    else:
        await _send_json(ws, {"type": "error", "message": "Unknown message type"})


async def _fire_and_forget_tracking(
    thread_id: str,
    pool: object,
    analytics_recorder: object | None,
    conversation_tracker: object | None,
    agent_name: str | None,
    tokens: int,
    cost: float,
) -> None:
    """Fire-and-forget analytics/tracking; failures must NOT break chat."""
    try:
        if conversation_tracker is not None and pool is not None:
            await conversation_tracker.ensure_conversation(pool, thread_id)
            await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
    except Exception:
        logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)

    try:
        if analytics_recorder is not None:
            await analytics_recorder.record(
                thread_id=thread_id,
                event_type="message",
                agent_name=agent_name,
                tokens_used=tokens,
                cost_usd=cost,
            )
    except Exception:
        logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)


def _has_interrupt(state: Any) -> bool:
    """Check if the graph state has a pending interrupt."""
    tasks = getattr(state, "tasks", ())
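A client-side sketch of the message protocol dispatch_message implements, using the E2E app factory defined later in this diff. With the default mock graph the "interrupt" branch never fires, but the payload shapes match the handlers above: send a "message", answer any "interrupt" with an "interrupt_response", and stop on "message_complete" or "error".

from starlette.testclient import TestClient

from tests.e2e.conftest import create_e2e_app  # test factory from later in this diff

app = create_e2e_app()
with TestClient(app) as client, client.websocket_connect("/ws") as ws:
    ws.send_json({"type": "message", "thread_id": "demo-1", "content": "Cancel order 1042"})
    while True:
        msg = ws.receive_json()
        if msg["type"] == "interrupt":
            # Human-in-the-loop approval gate for write operations
            ws.send_json({"type": "interrupt_response", "thread_id": "demo-1", "approved": True})
        elif msg["type"] in ("message_complete", "error"):
            break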
153
backend/fixtures/demo_data.py
Normal file
@@ -0,0 +1,153 @@
"""Seed script -- inserts sample conversations and analytics events for demo purposes.
|
||||
|
||||
Usage:
|
||||
cd backend
|
||||
python fixtures/demo_data.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import psycopg
|
||||
|
||||
DATABASE_URL = os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql://smart_support:dev_password@localhost:5432/smart_support",
|
||||
)
|
||||
|
||||
SAMPLE_CONVERSATIONS = [
|
||||
{
|
||||
"thread_id": "demo-thread-001",
|
||||
"agents_used": ["order_agent"],
|
||||
"turn_count": 3,
|
||||
"total_tokens": 1250,
|
||||
"total_cost_usd": 0.00375,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 5,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-002",
|
||||
"agents_used": ["order_agent", "refund_agent"],
|
||||
"turn_count": 6,
|
||||
"total_tokens": 3200,
|
||||
"total_cost_usd": 0.0096,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 30,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-003",
|
||||
"agents_used": ["general_agent"],
|
||||
"turn_count": 2,
|
||||
"total_tokens": 800,
|
||||
"total_cost_usd": 0.0024,
|
||||
"resolution_type": None,
|
||||
"minutes_ago": 60,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-004",
|
||||
"agents_used": ["order_agent", "general_agent"],
|
||||
"turn_count": 8,
|
||||
"total_tokens": 4500,
|
||||
"total_cost_usd": 0.0135,
|
||||
"resolution_type": "escalated",
|
||||
"minutes_ago": 120,
|
||||
},
|
||||
{
|
||||
"thread_id": "demo-thread-005",
|
||||
"agents_used": ["refund_agent"],
|
||||
"turn_count": 4,
|
||||
"total_tokens": 2100,
|
||||
"total_cost_usd": 0.0063,
|
||||
"resolution_type": "resolved",
|
||||
"minutes_ago": 240,
|
||||
},
|
||||
]
|
||||
|
||||
SAMPLE_EVENTS = [
|
||||
{"thread_id": "demo-thread-001", "event_type": "message", "agent_name": "order_agent", "tokens_used": 400, "cost_usd": 0.0012, "success": True},
|
||||
{"thread_id": "demo-thread-001", "event_type": "tool_call", "agent_name": "order_agent", "tool_name": "get_order_status", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "order_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
|
||||
{"thread_id": "demo-thread-002", "event_type": "tool_call", "agent_name": "refund_agent", "tool_name": "process_refund", "tokens_used": 0, "cost_usd": 0.0, "success": True},
|
||||
{"thread_id": "demo-thread-003", "event_type": "message", "agent_name": "general_agent", "tokens_used": 800, "cost_usd": 0.0024, "success": True},
|
||||
{"thread_id": "demo-thread-004", "event_type": "message", "agent_name": "order_agent", "tokens_used": 2000, "cost_usd": 0.006, "success": True},
|
||||
{"thread_id": "demo-thread-004", "event_type": "escalation", "agent_name": "general_agent", "tokens_used": 2500, "cost_usd": 0.0075, "success": False},
|
||||
{"thread_id": "demo-thread-005", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 2100, "cost_usd": 0.0063, "success": True},
|
||||
]
|
||||
|
||||
_INSERT_CONVERSATION = """
|
||||
INSERT INTO conversations
|
||||
(thread_id, started_at, last_activity, turn_count, agents_used,
|
||||
total_tokens, total_cost_usd, resolution_type, ended_at)
|
||||
VALUES
|
||||
(%(thread_id)s, %(started_at)s, %(last_activity)s, %(turn_count)s,
|
||||
%(agents_used)s, %(total_tokens)s, %(total_cost_usd)s,
|
||||
%(resolution_type)s, %(ended_at)s)
|
||||
ON CONFLICT (thread_id) DO NOTHING
|
||||
"""
|
||||
|
||||
_INSERT_EVENT = """
|
||||
INSERT INTO analytics_events
|
||||
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd, success)
|
||||
VALUES
|
||||
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
|
||||
%(tokens_used)s, %(cost_usd)s, %(success)s)
|
||||
"""
|
||||
|
||||
|
||||
async def seed() -> None:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
|
||||
print("Seeding conversations...")
|
||||
for conv in SAMPLE_CONVERSATIONS:
|
||||
started_at = now - timedelta(minutes=conv["minutes_ago"])
|
||||
last_activity = started_at + timedelta(minutes=conv["turn_count"] * 2)
|
||||
ended_at = last_activity if conv["resolution_type"] else None
|
||||
|
||||
await conn.execute(
|
||||
_INSERT_CONVERSATION,
|
||||
{
|
||||
"thread_id": conv["thread_id"],
|
||||
"started_at": started_at,
|
||||
"last_activity": last_activity,
|
||||
"turn_count": conv["turn_count"],
|
||||
"agents_used": conv["agents_used"],
|
||||
"total_tokens": conv["total_tokens"],
|
||||
"total_cost_usd": conv["total_cost_usd"],
|
||||
"resolution_type": conv["resolution_type"],
|
||||
"ended_at": ended_at,
|
||||
},
|
||||
)
|
||||
print(f" Inserted conversation {conv['thread_id']}")
|
||||
|
||||
print("Seeding analytics events...")
|
||||
for event in SAMPLE_EVENTS:
|
||||
await conn.execute(
|
||||
_INSERT_EVENT,
|
||||
{
|
||||
"thread_id": event["thread_id"],
|
||||
"event_type": event["event_type"],
|
||||
"agent_name": event.get("agent_name"),
|
||||
"tool_name": event.get("tool_name"),
|
||||
"tokens_used": event.get("tokens_used", 0),
|
||||
"cost_usd": event.get("cost_usd", 0.0),
|
||||
"success": event.get("success"),
|
||||
},
|
||||
)
|
||||
print(f" Inserted event {event['event_type']} for {event['thread_id']}")
|
||||
|
||||
await conn.commit()
|
||||
|
||||
print("Done. Demo data seeded successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(seed())
|
||||
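A quick verification sketch after seeding; the table names come from the INSERT statements above, and the connection string reuses the script's DATABASE_URL default:

import asyncio
import os

import psycopg

DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://smart_support:dev_password@localhost:5432/smart_support",
)


async def check() -> None:
    async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
        cur = await conn.execute("SELECT COUNT(*) FROM conversations")
        print("conversations:", (await cur.fetchone())[0])
        cur = await conn.execute("SELECT COUNT(*) FROM analytics_events")
        print("analytics_events:", (await cur.fetchone())[0])


asyncio.run(check())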
238
backend/fixtures/sample_openapi.yaml
Normal file
@@ -0,0 +1,238 @@
openapi: "3.0.3"
info:
  title: "E-Commerce API"
  description: "Sample e-commerce API for Smart Support demo."
  version: "1.0.0"

servers:
  - url: "https://api.example-shop.com/v1"
    description: "Production server"

paths:
  /orders/{order_id}:
    get:
      operationId: getOrder
      summary: "Get order details"
      description: "Retrieve the full details of a specific order."
      parameters:
        - name: order_id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: "Order details"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Order"

  /orders/{order_id}/cancel:
    post:
      operationId: cancelOrder
      summary: "Cancel an order"
      description: "Cancel an order that has not yet been shipped."
      parameters:
        - name: order_id
          in: path
          required: true
          schema:
            type: string
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                reason:
                  type: string
      responses:
        "200":
          description: "Order cancelled"
        "400":
          description: "Order cannot be cancelled (already shipped)"

  /orders/{order_id}/refund:
    post:
      operationId: refundOrder
      summary: "Request a refund"
      description: "Submit a refund request for a completed order."
      parameters:
        - name: order_id
          in: path
          required: true
          schema:
            type: string
      requestBody:
        required: true
        content:
          application/json:
            schema:
              type: object
              properties:
                amount:
                  type: number
                  description: "Refund amount in USD. Leave null for full refund."
                reason:
                  type: string
      responses:
        "200":
          description: "Refund submitted"
        "400":
          description: "Invalid refund request"

  /customers/{customer_id}:
    get:
      operationId: getCustomer
      summary: "Get customer profile"
      description: "Retrieve customer profile and account information."
      parameters:
        - name: customer_id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: "Customer profile"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Customer"

  /customers/{customer_id}/orders:
    get:
      operationId: listCustomerOrders
      summary: "List customer orders"
      description: "Get a paginated list of orders for a customer."
      parameters:
        - name: customer_id
          in: path
          required: true
          schema:
            type: string
        - name: page
          in: query
          schema:
            type: integer
            default: 1
        - name: per_page
          in: query
          schema:
            type: integer
            default: 20
      responses:
        "200":
          description: "List of orders"

  /products/{product_id}:
    get:
      operationId: getProduct
      summary: "Get product details"
      description: "Retrieve product information including inventory status."
      parameters:
        - name: product_id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: "Product details"

  /support/tickets:
    post:
      operationId: createSupportTicket
      summary: "Create support ticket"
      description: "Open a new support ticket for a customer issue."
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/CreateTicketRequest"
      responses:
        "201":
          description: "Ticket created"

  /support/tickets/{ticket_id}:
    get:
      operationId: getSupportTicket
      summary: "Get support ticket"
      description: "Retrieve a support ticket and its conversation history."
      parameters:
        - name: ticket_id
          in: path
          required: true
          schema:
            type: string
      responses:
        "200":
          description: "Ticket details"

components:
  schemas:
    Order:
      type: object
      properties:
        order_id:
          type: string
        customer_id:
          type: string
        status:
          type: string
          enum: [pending, processing, shipped, delivered, cancelled, refunded]
        items:
          type: array
          items:
            $ref: "#/components/schemas/OrderItem"
        total_usd:
          type: number
        created_at:
          type: string
          format: date-time

    OrderItem:
      type: object
      properties:
        product_id:
          type: string
        name:
          type: string
        quantity:
          type: integer
        unit_price_usd:
          type: number

    Customer:
      type: object
      properties:
        customer_id:
          type: string
        email:
          type: string
        name:
          type: string
        tier:
          type: string
          enum: [standard, premium, vip]
        created_at:
          type: string
          format: date-time

    CreateTicketRequest:
      type: object
      required: [customer_id, subject, description]
      properties:
        customer_id:
          type: string
        subject:
          type: string
        description:
          type: string
        priority:
          type: string
          enum: [low, medium, high, urgent]
          default: medium
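A sketch of how a generated tool might call one of these endpoints. The base URL comes from the spec's servers block; the host is a sample for the demo, not a live service, so expect failures outside the demo environment.

import httpx

BASE = "https://api.example-shop.com/v1"  # from the spec's servers block; demo only


def cancel_order(order_id: str, reason: str) -> None:
    # Mirrors POST /orders/{order_id}/cancel as declared above
    resp = httpx.post(f"{BASE}/orders/{order_id}/cancel", json={"reason": reason})
    if resp.status_code == 400:
        raise RuntimeError("Order cannot be cancelled (already shipped)")
    resp.raise_for_status()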
@@ -6,18 +6,23 @@ requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115,<1.0",
    "uvicorn[standard]>=0.34,<1.0",
    "langgraph>=0.4,<1.0",
    "langgraph-supervisor>=0.0.12,<1.0",
    "langgraph>=1.0,<2.0",
    "langgraph-supervisor>=0.0.30,<1.0",
    "langgraph-checkpoint-postgres>=3.0,<4.0",
    "langchain-core>=0.3,<1.0",
    "langchain-anthropic>=0.3,<2.0",
    "langchain-openai>=0.3,<1.0",
    "langchain>=1.0,<2.0",
    "langchain-core>=1.0,<2.0",
    "langchain-anthropic>=1.0,<2.0",
    "langchain-openai>=1.0,<2.0",
    "langchain-google-genai>=2.1,<3.0",
    "psycopg[binary,pool]>=3.2,<4.0",
    "pydantic>=2.10,<3.0",
    "pydantic-settings>=2.7,<3.0",
    "pyyaml>=6.0,<7.0",
    "python-dotenv>=1.0,<2.0",
    "httpx>=0.28,<1.0",
    "openapi-spec-validator>=0.7,<1.0",
    "alembic>=1.13,<2.0",
    "structlog>=24.0,<26.0",
]

[project.optional-dependencies]
@@ -27,6 +32,7 @@ dev = [
    "pytest-cov>=6.0,<7.0",
    "httpx>=0.28,<1.0",
    "ruff>=0.9,<1.0",
    "pytest-httpx>=0.35,<1.0",
]

[build-system]
42
backend/templates/e-commerce.yaml
Normal file
@@ -0,0 +1,42 @@
agents:
  - name: order_lookup
    description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery."
    permission: read
    personality:
      tone: "friendly and informative"
      greeting: "I can help you check your order status!"
      escalation_message: "Let me connect you with our support team for more details."
    tools:
      - get_order_status
      - get_tracking_info

  - name: order_actions
    description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order."
    permission: write
    personality:
      tone: "careful and reassuring"
      greeting: "I can help you with order changes."
      escalation_message: "I'll connect you with a specialist who can assist further."
    tools:
      - cancel_order

  - name: discount
    description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
    permission: write
    personality:
      tone: "generous and accommodating"
      greeting: "I can help you with discounts and coupons!"
      escalation_message: "Let me connect you with our promotions team."
    tools:
      - apply_discount
      - generate_coupon

  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond
31
backend/templates/fintech.yaml
Normal file
@@ -0,0 +1,31 @@
agents:
  - name: transaction_lookup
    description: "Looks up transaction history, balances, and payment details. Use for queries about transactions and account activity."
    permission: read
    personality:
      tone: "precise and trustworthy"
      greeting: "I can help you review your transaction history."
      escalation_message: "Let me connect you with our financial support team."
    tools:
      - get_transaction_history

  - name: dispute_handler
    description: "Files and manages transaction disputes. Use when the customer wants to dispute a charge or check dispute status."
    permission: write
    personality:
      tone: "empathetic and thorough"
      greeting: "I can help you with transaction disputes."
      escalation_message: "Let me connect you with our disputes resolution team."
    tools:
      - file_dispute
      - check_dispute_status

  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond
31
backend/templates/saas.yaml
Normal file
@@ -0,0 +1,31 @@
agents:
  - name: account_lookup
    description: "Looks up account status, subscription details, and billing history. Use for queries about account information."
    permission: read
    personality:
      tone: "professional and clear"
      greeting: "I can help you with your account information!"
      escalation_message: "Let me connect you with our account support team."
    tools:
      - get_account_status

  - name: subscription_management
    description: "Manages subscription changes like plan upgrades, downgrades, and cancellations. Use when the customer wants to change their subscription."
    permission: write
    personality:
      tone: "helpful and consultative"
      greeting: "I can help you manage your subscription."
      escalation_message: "Let me connect you with our billing specialist."
    tools:
      - change_plan
      - cancel_subscription

  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read
    personality:
      tone: "professional and helpful"
      greeting: "Hello! How can I help you today?"
      escalation_message: "Let me connect you with a human agent who can better assist you."
    tools:
      - fallback_respond
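A minimal loading sketch for these template files, assuming the loader (selected via TEMPLATE_NAME, and not shown in this diff) simply parses the YAML; the function name and path handling here are illustrative.

from pathlib import Path

import yaml


def load_template(name: str) -> list[dict]:
    """Parse backend/templates/<name>.yaml and return its agent definitions."""
    path = Path("backend/templates") / f"{name}.yaml"
    with path.open() as f:
        return yaml.safe_load(f)["agents"]


agents = load_template("saas")
# Every template ships a read-only fallback agent as the catch-all route.
assert any(a["name"] == "fallback" for a in agents)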
@@ -15,6 +15,16 @@ if TYPE_CHECKING:
    from pathlib import Path


@pytest.fixture(autouse=True)
def clear_rate_limit_state() -> None:
    """Clear module-level rate limit state between tests to prevent leakage."""
    import app.ws_handler as ws_handler

    ws_handler._thread_timestamps.clear()
    yield
    ws_handler._thread_timestamps.clear()


@pytest.fixture
def test_settings() -> Settings:
    return Settings(
230
backend/tests/e2e/conftest.py
Normal file
@@ -0,0 +1,230 @@
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.analytics.api import router as analytics_router
|
||||
from app.callbacks import TokenUsageCallbackHandler
|
||||
from app.graph_context import GraphContext
|
||||
from app.interrupt_manager import InterruptManager
|
||||
from app.openapi.review_api import _job_store, router as openapi_router
|
||||
from app.replay.api import router as replay_router
|
||||
from app.session_manager import SessionManager
|
||||
from app.ws_context import WebSocketContext
|
||||
from app.ws_handler import dispatch_message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graph helpers -- simulate LangGraph streaming behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AsyncIterHelper:
|
||||
"""Make a list behave as an async iterator."""
|
||||
|
||||
def __init__(self, items: list) -> None:
|
||||
self._items = list(items)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._items:
|
||||
raise StopAsyncIteration
|
||||
return self._items.pop(0)
|
||||
|
||||
|
||||
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = content
|
||||
c.tool_calls = []
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
|
||||
c = MagicMock()
|
||||
c.content = ""
|
||||
c.tool_calls = [{"name": name, "args": args}]
|
||||
return (c, {"langgraph_node": node})
|
||||
|
||||
|
||||
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
|
||||
s = MagicMock()
|
||||
if interrupt:
|
||||
obj = MagicMock()
|
||||
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
|
||||
t = MagicMock()
|
||||
t.interrupts = (obj,)
|
||||
s.tasks = (t,)
|
||||
else:
|
||||
s.tasks = ()
|
||||
return s
|
||||
|
||||
|
||||
def make_graph(
|
||||
chunks: list | None = None,
|
||||
state: Any = None,
|
||||
resume_chunks: list | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a mock LangGraph CompiledStateGraph."""
|
||||
g = MagicMock()
|
||||
|
||||
if state is None:
|
||||
state = make_state()
|
||||
|
||||
streams = [chunks or [], resume_chunks or []]
|
||||
idx = {"n": 0}
|
||||
|
||||
def astream_side_effect(*a, **kw):
|
||||
i = min(idx["n"], len(streams) - 1)
|
||||
idx["n"] += 1
|
||||
return AsyncIterHelper(list(streams[i]))
|
||||
|
||||
g.astream = MagicMock(side_effect=astream_side_effect)
|
||||
g.aget_state = AsyncMock(return_value=state)
|
||||
return g
|
||||
|
||||
|
||||
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
|
||||
"""Build a GraphContext wrapping a mock graph."""
|
||||
g = graph or make_graph()
|
||||
registry = MagicMock()
|
||||
registry.list_agents = MagicMock(return_value=())
|
||||
return GraphContext(graph=g, registry=registry, intent_classifier=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake database pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakeCursor:
|
||||
"""Minimal async cursor returning pre-configured rows."""
|
||||
|
||||
def __init__(self, rows: list[dict]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
async def fetchall(self) -> list[dict]:
|
||||
return self._rows
|
||||
|
||||
async def fetchone(self) -> tuple | dict | None:
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
"""Fake async connection that returns a FakeCursor."""
|
||||
|
||||
def __init__(self, rows: list[dict]) -> None:
|
||||
self._rows = rows
|
||||
|
||||
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
|
||||
return FakeCursor(self._rows)
|
||||
|
||||
|
||||
class FakePool:
|
||||
"""Minimal pool that yields a fake connection."""
|
||||
|
||||
def __init__(self, rows: list[dict] | None = None) -> None:
|
||||
self._rows = rows or []
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self):
|
||||
yield FakeConnection(self._rows)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_e2e_app(
|
||||
graph: MagicMock | None = None,
|
||||
pool: FakePool | None = None,
|
||||
session_ttl: int = 3600,
|
||||
interrupt_ttl: int = 1800,
|
||||
) -> FastAPI:
|
||||
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
|
||||
g = graph or make_graph()
|
||||
graph_ctx = make_graph_ctx(g)
|
||||
p = pool or FakePool()
|
||||
sm = SessionManager(session_ttl_seconds=session_ttl)
|
||||
im = InterruptManager(ttl_seconds=interrupt_ttl)
|
||||
|
||||
app = FastAPI(title="Smart Support E2E Test")
|
||||
app.include_router(openapi_router)
|
||||
app.include_router(replay_router)
|
||||
app.include_router(analytics_router)
|
||||
|
||||
app.state.graph_ctx = graph_ctx
|
||||
app.state.session_manager = sm
|
||||
app.state.interrupt_manager = im
|
||||
app.state.pool = p
|
||||
app.state.settings = MagicMock(llm_model="test-model")
|
||||
app.state.analytics_recorder = AsyncMock()
|
||||
app.state.conversation_tracker = AsyncMock()
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
def health_check() -> dict:
|
||||
return {"status": "ok", "version": "test"}
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(ws: WebSocket) -> None:
|
||||
await ws.accept()
|
||||
try:
|
||||
while True:
|
||||
raw_data = await ws.receive_text()
|
||||
ws_ctx = WebSocketContext(
|
||||
graph_ctx=app.state.graph_ctx,
|
||||
session_manager=app.state.session_manager,
|
||||
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
|
||||
interrupt_manager=app.state.interrupt_manager,
|
||||
analytics_recorder=app.state.analytics_recorder,
|
||||
conversation_tracker=app.state.conversation_tracker,
|
||||
pool=app.state.pool,
|
||||
)
|
||||
await dispatch_message(ws, ws_ctx, raw_data)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_graph():
|
||||
"""Default graph fixture -- returns tokens and message_complete."""
|
||||
return make_graph(
|
||||
chunks=[make_chunk("Order 1042 is "), make_chunk("shipped.")]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def e2e_app(e2e_graph):
|
||||
"""Default E2E app fixture."""
|
||||
return create_e2e_app(graph=e2e_graph)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def e2e_client(e2e_app):
|
||||
"""Async HTTP client for E2E tests."""
|
||||
transport = ASGITransport(app=e2e_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_openapi_job_store():
|
||||
"""Clear the in-memory job store between tests."""
|
||||
_job_store.clear()
|
||||
yield
|
||||
_job_store.clear()
|
||||
384
backend/tests/e2e/test_chat_flows.py
Normal file
@@ -0,0 +1,384 @@
"""E2E tests for critical chat user flows (flows 1-4).
|
||||
|
||||
Flow 1: Happy path -- query order, get answer
|
||||
Flow 2: Approval flow -- write operation, interrupt, approve, execute
|
||||
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
|
||||
Flow 4: Multi-turn context -- sequential messages in same session
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from tests.e2e.conftest import (
|
||||
create_e2e_app,
|
||||
make_chunk,
|
||||
make_graph,
|
||||
make_state,
|
||||
make_tool_chunk,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
class TestFlow1HappyPath:
|
||||
"""Flow 1: query order -> get answer with streaming tokens."""
|
||||
|
||||
def test_websocket_happy_path_order_query(self) -> None:
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_tool_chunk("get_order_status", {"order_id": "1042"}),
|
||||
make_chunk("Order 1042 has been shipped and is on its way."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-happy-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
|
||||
messages = []
|
||||
while True:
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
|
||||
tool_calls = [m for m in messages if m["type"] == "tool_call"]
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0]["tool"] == "get_order_status"
|
||||
assert tool_calls[0]["args"] == {"order_id": "1042"}
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "shipped" in tokens[0]["content"]
|
||||
|
||||
completes = [m for m in messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
assert completes[0]["thread_id"] == "e2e-happy-1"
|
||||
|
||||
def test_websocket_multiple_token_stream(self) -> None:
|
||||
"""Verify streaming returns multiple token chunks."""
|
||||
graph = make_graph(
|
||||
chunks=[
|
||||
make_chunk("Your order "),
|
||||
make_chunk("1042 "),
|
||||
make_chunk("was delivered "),
|
||||
make_chunk("yesterday."),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-stream-1",
|
||||
"content": "Where is my order?",
|
||||
})
|
||||
|
||||
messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in messages if m["type"] == "token"]
|
||||
assert len(tokens) == 4
|
||||
full_text = "".join(t["content"] for t in tokens)
|
||||
assert "1042" in full_text
|
||||
assert "delivered" in full_text
|
||||
|
||||
|
||||
class TestFlow2ApprovalFlow:
|
||||
"""Flow 2: write operation -> interrupt -> approve -> execute."""
|
||||
|
||||
def test_interrupt_approve_executes_action(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Send cancel request
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
|
||||
interrupts = [m for m in messages if m["type"] == "interrupt"]
|
||||
assert len(interrupts) == 1
|
||||
assert interrupts[0]["action"] == "cancel_order"
|
||||
assert interrupts[0]["thread_id"] == "e2e-approve-1"
|
||||
|
||||
# Step 2: Approve the interrupt
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-approve-1",
|
||||
"approved": True,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "cancelled" in tokens[0]["content"]
|
||||
assert tokens[0]["agent"] == "order_actions"
|
||||
|
||||
completes = [m for m in resume_messages if m["type"] == "message_complete"]
|
||||
assert len(completes) == 1
|
||||
|
||||
|
||||
class TestFlow3RejectionFlow:
|
||||
"""Flow 3: write operation -> interrupt -> reject -> no execution."""
|
||||
|
||||
def test_interrupt_reject_does_not_execute(self) -> None:
|
||||
interrupt_state = make_state(
|
||||
interrupt=True,
|
||||
data={"action": "cancel_order", "order_id": "1042"},
|
||||
)
|
||||
graph = make_graph(
|
||||
chunks=[],
|
||||
state=interrupt_state,
|
||||
resume_chunks=[
|
||||
make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
|
||||
],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Step 1: Trigger interrupt
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"content": "Cancel order 1042",
|
||||
})
|
||||
|
||||
messages = _collect_until_type(ws, "interrupt")
|
||||
assert any(m["type"] == "interrupt" for m in messages)
|
||||
|
||||
# Step 2: Reject
|
||||
ws.send_json({
|
||||
"type": "interrupt_response",
|
||||
"thread_id": "e2e-reject-1",
|
||||
"approved": False,
|
||||
})
|
||||
|
||||
resume_messages = _collect_until_complete(ws)
|
||||
|
||||
tokens = [m for m in resume_messages if m["type"] == "token"]
|
||||
assert len(tokens) == 1
|
||||
assert "remain active" in tokens[0]["content"]
|
||||
|
||||
# Verify graph.astream was called with resume=False
|
||||
resume_call = graph.astream.call_args_list[-1]
|
||||
command = resume_call[0][0]
|
||||
assert command.resume is False
|
||||
|
||||
|
||||
class TestFlow4MultiTurnContext:
|
||||
"""Flow 4: multi-turn conversation in the same session."""
|
||||
|
||||
def test_multi_turn_messages_share_session(self) -> None:
|
||||
"""Multiple messages in the same thread_id maintain session context."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Order 1042 status: shipped.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Turn 1: Query order
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "What is the status of order 1042?",
|
||||
})
|
||||
turn1 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn1)
|
||||
|
||||
# Turn 2: Follow-up in same thread
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "When will it arrive?",
|
||||
})
|
||||
turn2 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn2)
|
||||
|
||||
# Turn 3: Another follow-up
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-multi-1",
|
||||
"content": "Can you track it?",
|
||||
})
|
||||
turn3 = _collect_until_complete(ws)
|
||||
assert any(m["type"] == "message_complete" for m in turn3)
|
||||
|
||||
# Verify all turns used the same thread_id in graph calls
|
||||
for call in graph.astream.call_args_list:
|
||||
config = call[1].get("config", call[0][1] if len(call[0]) > 1 else {})
|
||||
assert config["configurable"]["thread_id"] == "e2e-multi-1"
|
||||
|
||||
def test_separate_threads_are_independent(self) -> None:
|
||||
"""Different thread_ids have independent sessions."""
|
||||
graph = make_graph(
|
||||
chunks=[make_chunk("Response.")],
|
||||
)
|
||||
app = create_e2e_app(graph=graph)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# Thread A
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-a",
|
||||
"content": "Hello from thread A",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Thread B
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-thread-b",
|
||||
"content": "Hello from thread B",
|
||||
})
|
||||
_collect_until_complete(ws)
|
||||
|
||||
# Both threads should exist as separate sessions
|
||||
sm = app.state.session_manager
|
||||
assert sm.get_state("e2e-thread-a") is not None
|
||||
assert sm.get_state("e2e-thread-b") is not None
|
||||
|
||||
|
||||
class TestChatEdgeCases:
|
||||
"""Edge cases and error handling for the chat WebSocket."""
|
||||
|
||||
def test_invalid_json_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("not valid json")
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "Invalid JSON" in msg["message"]
|
||||
|
||||
def test_missing_thread_id_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({"type": "message", "content": "hello"})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "thread_id" in msg["message"]
|
||||
|
||||
def test_empty_content_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-err-1",
|
||||
"content": "",
|
||||
})
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
|
||||
def test_expired_session_returns_error(self) -> None:
|
||||
graph = make_graph(chunks=[make_chunk("Response.")])
|
||||
app = create_e2e_app(graph=graph, session_ttl=0)
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
# First message creates the session (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello",
|
||||
})
|
||||
_collect_until_complete_or_error(ws)
|
||||
|
||||
# Second message finds the session expired (TTL=0)
|
||||
ws.send_json({
|
||||
"type": "message",
|
||||
"thread_id": "e2e-expired-1",
|
||||
"content": "hello again",
|
||||
})
|
||||
messages = _collect_until_complete_or_error(ws)
|
||||
errors = [m for m in messages if m["type"] == "error"]
|
||||
assert len(errors) >= 1
|
||||
assert "expired" in errors[0]["message"].lower()
|
||||
|
||||
def test_oversized_message_returns_error(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws") as ws:
|
||||
ws.send_text("x" * 40_000)
|
||||
msg = ws.receive_json()
|
||||
assert msg["type"] == "error"
|
||||
assert "too large" in msg["message"].lower()
|
||||
|
||||
def test_health_endpoint(self) -> None:
|
||||
app = create_e2e_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive WebSocket messages until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until a specific message type is received."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] == msg_type:
|
||||
break
|
||||
return messages
|
||||
|
||||
|
||||
def _collect_until_complete_or_error(ws, *, max_messages: int = 50) -> list[dict]:
|
||||
"""Receive until message_complete or error."""
|
||||
messages = []
|
||||
for _ in range(max_messages):
|
||||
msg = ws.receive_json()
|
||||
messages.append(msg)
|
||||
if msg["type"] in ("message_complete", "error"):
|
||||
break
|
||||
return messages
|
||||
201
backend/tests/e2e/test_openapi_import.py
Normal file
@@ -0,0 +1,201 @@
"""E2E tests for OpenAPI import flow (flow 5).
|
||||
|
||||
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
|
||||
review classifications -> approve -> tool generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.openapi.models import ClassificationResult, EndpointInfo
|
||||
from app.openapi.review_api import _job_store
|
||||
from tests.e2e.conftest import create_e2e_app
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
def _fake_endpoint(
|
||||
path: str = "/orders/{id}",
|
||||
method: str = "GET",
|
||||
operation_id: str = "getOrder",
|
||||
summary: str = "Get order details",
|
||||
) -> EndpointInfo:
|
||||
return EndpointInfo(
|
||||
path=path,
|
||||
method=method,
|
||||
operation_id=operation_id,
|
||||
summary=summary,
|
||||
description="",
|
||||
parameters=(),
|
||||
request_body_schema=None,
|
||||
response_schema=None,
|
||||
)
|
||||
|
||||
|
||||
def _fake_classification(
|
||||
endpoint: EndpointInfo | None = None,
|
||||
access_type: str = "read",
|
||||
needs_interrupt: bool = False,
|
||||
agent_group: str = "order_lookup",
|
||||
) -> ClassificationResult:
|
||||
return ClassificationResult(
|
||||
endpoint=endpoint or _fake_endpoint(),
|
||||
access_type=access_type,
|
||||
customer_params=["order_id"],
|
||||
agent_group=agent_group,
|
||||
confidence=0.95,
|
||||
needs_interrupt=needs_interrupt,
|
||||
)
|
||||
|
||||
|
||||
class TestFlow5OpenAPIImport:
|
||||
"""Flow 5: full OpenAPI import lifecycle."""
|
||||
|
||||
def test_import_job_lifecycle(self) -> None:
|
||||
"""Start import -> check status -> review classifications -> approve."""
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
# Step 1: Start import job
|
||||
resp = client.post(
|
||||
"/api/v1/openapi/import",
|
||||
json={"url": "https://api.example.com/openapi.json"},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
body = resp.json()
|
||||
assert body["status"] == "pending"
|
||||
job_id = body["job_id"]
|
||||
|
||||
# Step 2: Check job status (still pending since background task hasn't run)
|
||||
resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["job_id"] == job_id
|
||||
|
||||
def test_import_job_with_classifications(self) -> None:
|
||||
"""Simulate completed import and review classified endpoints."""
|
||||
app = create_e2e_app()
|
||||
|
||||
# Seed a completed job directly
|
||||
ep_read = _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order")
|
||||
ep_write = _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order")
|
||||
|
||||
clf_read = _fake_classification(ep_read, "read", False, "order_lookup")
|
||||
clf_write = _fake_classification(ep_write, "write", True, "order_actions")
|
||||
|
||||
job_id = "test-job-001"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://api.example.com/openapi.json",
|
||||
"total_endpoints": 2,
|
||||
"classified_count": 2,
|
||||
"error_message": None,
|
||||
"classifications": [clf_read, clf_write],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# Step 1: Get classifications
|
||||
resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
|
||||
assert resp.status_code == 200
|
||||
classifications = resp.json()
|
||||
assert len(classifications) == 2
|
||||
|
||||
# Verify read endpoint
|
||||
read_clf = classifications[0]
|
||||
assert read_clf["access_type"] == "read"
|
||||
assert read_clf["needs_interrupt"] is False
|
||||
assert read_clf["endpoint"]["path"] == "/orders/{id}"
|
||||
|
||||
# Verify write endpoint
|
||||
write_clf = classifications[1]
|
||||
assert write_clf["access_type"] == "write"
|
||||
assert write_clf["needs_interrupt"] is True
|
||||
assert write_clf["endpoint"]["path"] == "/orders/{id}/cancel"
|
||||
|
||||
# Step 2: Update a classification
|
||||
resp = client.put(
|
||||
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
|
||||
json={
|
||||
"access_type": "write",
|
||||
"needs_interrupt": True,
|
||||
"agent_group": "order_actions",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
updated = resp.json()
|
||||
assert updated["access_type"] == "write"
|
||||
assert updated["needs_interrupt"] is True
|
||||
assert updated["agent_group"] == "order_actions"
|
||||
|
||||
# Step 3: Approve the job
|
||||
resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "approved"
|
||||
|
||||
def test_import_nonexistent_job_returns_404(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/openapi/jobs/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_import_invalid_url_returns_422(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
def test_classification_index_out_of_range(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
job_id = "test-job-range"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://example.com/spec.json",
|
||||
"total_endpoints": 1,
|
||||
"classified_count": 1,
|
||||
"error_message": None,
|
||||
"classifications": [_fake_classification()],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.put(
|
||||
f"/api/v1/openapi/jobs/{job_id}/classifications/99",
|
||||
json={
|
||||
"access_type": "read",
|
||||
"needs_interrupt": False,
|
||||
"agent_group": "order_lookup",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_classification_invalid_agent_group(self) -> None:
|
||||
app = create_e2e_app()
|
||||
|
||||
job_id = "test-job-invalid"
|
||||
_job_store[job_id] = {
|
||||
"job_id": job_id,
|
||||
"status": "done",
|
||||
"spec_url": "https://example.com/spec.json",
|
||||
"total_endpoints": 1,
|
||||
"classified_count": 1,
|
||||
"error_message": None,
|
||||
"classifications": [_fake_classification()],
|
||||
}
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.put(
|
||||
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
|
||||
json={
|
||||
"access_type": "read",
|
||||
"needs_interrupt": False,
|
||||
"agent_group": "invalid group!", # spaces and special chars
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
230
backend/tests/e2e/test_replay_analytics.py
Normal file
@@ -0,0 +1,230 @@
"""E2E tests for replay and analytics flows (flow 6).
|
||||
|
||||
Flow 6: list conversations -> select one -> step-by-step replay.
|
||||
Also tests the analytics dashboard endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from tests.e2e.conftest import FakePool, create_e2e_app
|
||||
|
||||
pytestmark = pytest.mark.e2e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom pool that returns specific data per query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ReplayPool(FakePool):
|
||||
"""Pool that returns different data depending on the SQL query."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversations: list[dict] | None = None,
|
||||
checkpoints: list[dict] | None = None,
|
||||
analytics_rows: list[dict] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._conversations = conversations or []
|
||||
self._checkpoints = checkpoints or []
|
||||
self._analytics = analytics_rows or []
|
||||
|
||||
class _Conn:
|
||||
def __init__(self, convos, checkpoints, analytics):
|
||||
self._convos = convos
|
||||
self._checkpoints = checkpoints
|
||||
self._analytics = analytics
|
||||
|
||||
async def execute(self, query: str, params=None):
|
||||
from tests.e2e.conftest import FakeCursor
|
||||
|
||||
if "COUNT" in query and "conversations" in query:
|
||||
return FakeCursor([(len(self._convos),)])
|
||||
if "conversations" in query and "SELECT" in query:
|
||||
# Respect LIMIT/OFFSET from params if provided
|
||||
rows = self._convos
|
||||
if params:
|
||||
offset = params.get("offset", 0)
|
||||
limit = params.get("limit", len(rows))
|
||||
rows = rows[offset : offset + limit]
|
||||
return FakeCursor(rows)
|
||||
if "checkpoints" in query:
|
||||
return FakeCursor(self._checkpoints)
|
||||
# Analytics queries
|
||||
return FakeCursor(self._analytics)
|
||||
|
||||
def connection(self):
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
conn = self._Conn(self._conversations, self._checkpoints, self._analytics)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _ctx():
|
||||
yield conn
|
||||
|
||||
return _ctx()
|
||||
|
||||
|
||||
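# A minimal usage sketch for ReplayPool, assuming create_e2e_app accepts the
# same `pool` keyword the tests below pass:
#
#     pool = ReplayPool(conversations=[{"thread_id": "t1", ...}])
#     app = create_e2e_app(pool=pool)
#     # Requests served by `app` now see the canned conversation rows.
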
class TestFlow6ReplayConversation:
    """Flow 6: list conversations -> select one -> step replay."""

    def test_list_conversations(self) -> None:
        now = datetime.now(tz=timezone.utc).isoformat()
        conversations = [
            {
                "thread_id": "conv-001",
                "created_at": now,
                "last_activity": now,
                "status": "active",
                "total_tokens": 150,
                "total_cost_usd": 0.003,
            },
            {
                "thread_id": "conv-002",
                "created_at": now,
                "last_activity": now,
                "status": "completed",
                "total_tokens": 300,
                "total_cost_usd": 0.006,
            },
        ]
        pool = ReplayPool(conversations=conversations)
        app = create_e2e_app(pool=pool)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            data = body["data"]
            assert len(data["conversations"]) == 2
            assert data["conversations"][0]["thread_id"] == "conv-001"
            assert data["conversations"][1]["thread_id"] == "conv-002"
            assert data["total"] == 2

    def test_list_conversations_pagination(self) -> None:
        conversations = [
            {
                "thread_id": f"conv-{i:03d}",
                "created_at": "2026-04-01T00:00:00Z",
                "last_activity": "2026-04-01T00:00:00Z",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.001,
            }
            for i in range(5)
        ]
        pool = ReplayPool(conversations=conversations)
        app = create_e2e_app(pool=pool)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            data = body["data"]
            assert data["total"] == 5
            assert data["page"] == 1
            assert data["per_page"] == 2
            assert len(data["conversations"]) == 2

    def test_replay_thread_not_found(self) -> None:
        pool = ReplayPool(checkpoints=[])
        app = create_e2e_app(pool=pool)

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
            assert resp.status_code == 404

    def test_replay_invalid_thread_id_format(self) -> None:
        app = create_e2e_app()

        with TestClient(app) as client:
            # Thread ID with special chars fails regex validation
            resp = client.get("/api/v1/replay/invalid%20thread%21%40")
            assert resp.status_code == 400


class TestAnalyticsDashboard:
    """Analytics endpoint tests."""

    def test_analytics_invalid_range_format(self) -> None:
        app = create_e2e_app()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "invalid"})
            assert resp.status_code == 400

    def test_analytics_range_too_large(self) -> None:
        app = create_e2e_app()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "999d"})
            assert resp.status_code == 400

    def test_analytics_range_zero_rejected(self) -> None:
        app = create_e2e_app()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics", params={"range": "0d"})
            assert resp.status_code == 400


class TestFullUserJourney:
    """End-to-end journey: chat -> then check replay list shows the conversation."""

    def test_chat_then_check_conversations_endpoint(self) -> None:
        """After chatting via WebSocket, the conversations endpoint is reachable."""
        from tests.e2e.conftest import make_chunk, make_graph

        graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
        now = datetime.now(tz=timezone.utc).isoformat()
        pool = ReplayPool(
            conversations=[
                {
                    "thread_id": "e2e-journey-1",
                    "created_at": now,
                    "last_activity": now,
                    "status": "active",
                    "total_tokens": 50,
                    "total_cost_usd": 0.001,
                },
            ],
        )
        app = create_e2e_app(graph=graph, pool=pool)

        with TestClient(app) as client:
            # Step 1: Chat via WebSocket
            with client.websocket_connect("/ws") as ws:
                ws.send_json({
                    "type": "message",
                    "thread_id": "e2e-journey-1",
                    "content": "Where is my order?",
                })
                messages = []
                for _ in range(20):
                    msg = ws.receive_json()
                    messages.append(msg)
                    if msg["type"] in ("message_complete", "error"):
                        break
                assert any(m["type"] == "message_complete" for m in messages)

            # Step 2: Check conversations endpoint
            resp = client.get("/api/v1/conversations")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert any(
                c["thread_id"] == "e2e-journey-1"
                for c in body["data"]["conversations"]
            )

            # Step 3: Health check still works
            resp = client.get("/api/v1/health")
            assert resp.status_code == 200
183
backend/tests/integration/test_analytics_api.py
Normal file
@@ -0,0 +1,183 @@
"""Integration tests for the /api/v1/analytics endpoint.

Tests the full API layer (routing, parameter validation, serialization,
error handling) with a mocked database pool.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

from app.analytics.models import AnalyticsResult, InterruptStats

pytestmark = pytest.mark.integration

_SAMPLE_RESULT = AnalyticsResult(
    range="7d",
    total_conversations=42,
    resolution_rate=0.85,
    escalation_rate=0.05,
    avg_turns_per_conversation=3.2,
    avg_cost_per_conversation_usd=0.012,
    agent_usage=(),
    interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)


def _build_app():
    """Build a minimal FastAPI app with the analytics router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope

    test_app = FastAPI()
    test_app.include_router(analytics_router)

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        return JSONResponse(
            status_code=500,
            content=envelope(None, success=False, error="Internal server error"),
        )

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    # No admin_api_key set -> auth is skipped
    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()

    return test_app

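# The envelope helper wired into the handlers above is assumed to produce the
# three-key shape these tests assert on (a sketch; the real implementation
# lives in app.api_utils):
#
#     def envelope(data, *, success=True, error=None):
#         return {"success": success, "data": data, "error": error}
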
class TestAnalyticsValidRange:
    """Test analytics endpoint with valid range parameters."""

    async def test_valid_range_7d_returns_envelope(self) -> None:
        """GET /api/v1/analytics?range=7d returns success envelope with data."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "7d"})

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["error"] is None
        assert body["data"]["total_conversations"] == 42
        assert body["data"]["resolution_rate"] == 0.85

    async def test_default_range_returns_success(self) -> None:
        """GET /api/v1/analytics with no range param defaults to 7d."""
        test_app = _build_app()
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=_SAMPLE_RESULT,
        ) as mock_get:
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics")

        assert resp.status_code == 200
        # Verify the default range of 7 days was passed, whether positionally
        # or via the range_days keyword (None means it was not passed at all
        # and the callee's own default of 7 applies).
        mock_get.assert_called_once()
        args, kwargs = mock_get.call_args
        range_days = kwargs.get("range_days", args[1] if len(args) > 1 else None)
        assert range_days in (7, None)

    async def test_large_range_365d_works(self) -> None:
        """GET /api/v1/analytics?range=365d is accepted (max boundary)."""
        test_app = _build_app()
        result = AnalyticsResult(
            range="365d",
            total_conversations=1000,
            resolution_rate=0.9,
            escalation_rate=0.02,
            avg_turns_per_conversation=4.0,
            avg_cost_per_conversation_usd=0.01,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        with patch(
            "app.analytics.api.get_analytics",
            new_callable=AsyncMock,
            return_value=result,
        ):
            async with AsyncClient(
                transport=ASGITransport(app=test_app), base_url="http://test"
            ) as client:
                resp = await client.get("/api/v1/analytics", params={"range": "365d"})

        assert resp.status_code == 200
        assert resp.json()["success"] is True


class TestAnalyticsInvalidRange:
    """Test analytics endpoint with invalid range parameters."""

    async def test_invalid_range_format_returns_400(self) -> None:
        """GET /api/v1/analytics?range=abc returns 400 error envelope."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "abc"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert "Invalid range format" in body["error"]

    async def test_zero_day_range_returns_400(self) -> None:
        """GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "0d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]

    async def test_range_exceeding_max_returns_400(self) -> None:
        """GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "999d"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "between 1 and 365" in body["error"]
128
backend/tests/integration/test_error_responses.py
Normal file
@@ -0,0 +1,128 @@
"""Integration tests for global error handling and envelope format consistency.

Tests that all error responses from the FastAPI app conform to the
standard envelope: {"success": false, "data": null, "error": "..."}.
"""

from __future__ import annotations

from unittest.mock import MagicMock

import pytest
from httpx import ASGITransport, AsyncClient

pytestmark = pytest.mark.integration


def _build_app():
    """Build the actual FastAPI app with exception handlers but mocked state."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.analytics.api import router as analytics_router
    from app.api_utils import envelope
    from app.replay.api import router as replay_router

    test_app = FastAPI()
    test_app.include_router(analytics_router)
    test_app.include_router(replay_router)

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    @test_app.exception_handler(Exception)
    async def _catch_all(request, exc):
        return JSONResponse(
            status_code=500,
            content=envelope(None, success=False, error="Internal server error"),
        )

    @test_app.get("/api/v1/health")
    def health_check():
        return {"status": "ok", "version": "0.6.0"}

    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = MagicMock()

    return test_app


class TestEnvelopeFormat:
    """Tests that error responses consistently follow envelope format."""

    async def test_http_400_produces_envelope(self) -> None:
        """A 400 error returns standard envelope with success=false."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "invalid"})

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert isinstance(body["error"], str)
        assert len(body["error"]) > 0

    async def test_validation_error_produces_422_envelope(self) -> None:
        """Invalid query param type returns 422 with envelope format."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            # page must be >= 1; passing 0 triggers validation error
            resp = await client.get("/api/v1/conversations", params={"page": 0})

        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert isinstance(body["error"], str)

    async def test_all_error_fields_present(self) -> None:
        """Error envelope contains exactly success, data, and error keys."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/analytics", params={"range": "bad"})

        body = resp.json()
        assert set(body.keys()) == {"success", "data", "error"}

    async def test_health_endpoint_returns_200(self) -> None:
        """Health check returns 200 with status ok."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/health")

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ok"
        assert "version" in body

    async def test_unknown_endpoint_returns_404(self) -> None:
        """Requesting a non-existent path returns 404."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/nonexistent-path")

        # FastAPI returns 404 for unknown routes; may or may not be wrapped
        assert resp.status_code == 404
203
backend/tests/integration/test_import_pipeline.py
Normal file
@@ -0,0 +1,203 @@
"""Integration tests for the OpenAPI import pipeline orchestrator.

Tests the full pipeline: fetch -> validate -> parse -> classify.
Uses mocked HTTP and a mocked LLM classifier.

RED phase: written before implementation.
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.openapi.models import ImportJob

pytestmark = pytest.mark.integration

_VALID_SPEC_JSON = """{
  "openapi": "3.0.0",
  "info": {"title": "Test API", "version": "1.0.0"},
  "paths": {
    "/orders": {
      "get": {
        "operationId": "list_orders",
        "summary": "List orders",
        "description": "Returns all orders",
        "responses": {"200": {"description": "OK"}}
      }
    },
    "/orders/{id}": {
      "delete": {
        "operationId": "delete_order",
        "summary": "Delete order",
        "description": "Deletes an order",
        "parameters": [
          {"name": "id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {"204": {"description": "Deleted"}}
      }
    }
  }
}"""

_INVALID_SPEC_JSON = '{"not": "a valid openapi spec"}'

_PUBLIC_IP = "93.184.216.34"


@pytest.fixture
def mock_classifier():
    """A deterministic heuristic classifier standing in for the LLM."""
    from app.openapi.classifier import HeuristicClassifier

    return HeuristicClassifier()


@pytest.fixture
def orchestrator(mock_classifier):
    """Create an ImportOrchestrator with the mock classifier."""
    from app.openapi.importer import ImportOrchestrator

    return ImportOrchestrator(classifier=mock_classifier)

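# How the orchestrator is driven throughout the tests below, shown once as a
# sketch (the stage names are those asserted later: fetching, validating,
# parsing, classifying, and failed on error):
#
#     stages: list[str] = []
#     job = await orchestrator.start_import(
#         url="https://example.com/api/spec.json",
#         job_id="job-1",
#         on_progress=lambda stage, job: stages.append(stage),
#     )
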
class TestImportOrchestratorSuccess:
    """Tests for successful import pipeline flows."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_full_pipeline_succeeds(self, orchestrator, httpx_mock) -> None:
        """Full pipeline with valid spec and mocked HTTP succeeds."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-1",
            on_progress=None,
        )
        assert isinstance(job, ImportJob)
        assert job.status == "done"
        assert job.job_id == "test-job-1"
        assert job.total_endpoints == 2
        assert job.classified_count == 2
        assert job.error_message is None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_progress_callback_called_at_stages(self, orchestrator, httpx_mock) -> None:
        """on_progress callback is called at each pipeline stage."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        stages_seen: list[str] = []

        def on_progress(stage: str, job: ImportJob) -> None:
            stages_seen.append(stage)

        await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-2",
            on_progress=on_progress,
        )
        assert "fetching" in stages_seen
        assert "validating" in stages_seen
        assert "parsing" in stages_seen
        assert "classifying" in stages_seen

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_none_progress_callback_does_not_raise(
        self, orchestrator, httpx_mock
    ) -> None:
        """Passing None as on_progress does not raise."""
        httpx_mock.add_response(
            url="https://example.com/api/spec.json",
            text=_VALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/spec.json",
            job_id="test-job-3",
            on_progress=None,
        )
        assert job.status == "done"


class TestImportOrchestratorFailures:
    """Tests for error handling in the import pipeline."""

    async def test_fetch_failure_sets_failed_status(self, orchestrator) -> None:
        """When HTTP fetch fails, job status is 'failed'."""
        with patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]):
            job = await orchestrator.start_import(
                url="http://internal.corp/spec.json",
                job_id="test-job-fail-1",
                on_progress=None,
            )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_validation_failure_sets_failed_status(
        self, orchestrator, httpx_mock
    ) -> None:
        """When spec validation fails, job status is 'failed'."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-2",
            on_progress=None,
        )
        assert job.status == "failed"
        assert job.error_message is not None

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_error_message_is_descriptive(self, orchestrator, httpx_mock) -> None:
        """Error message contains useful context."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        job = await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-3",
            on_progress=None,
        )
        assert len(job.error_message) > 0

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_failed_status_progress_called_with_failed(
        self, orchestrator, httpx_mock
    ) -> None:
        """When pipeline fails, on_progress is called with 'failed' stage."""
        httpx_mock.add_response(
            url="https://example.com/api/bad.json",
            text=_INVALID_SPEC_JSON,
            headers={"content-type": "application/json"},
        )
        stages_seen: list[str] = []

        def on_progress(stage: str, job: ImportJob) -> None:
            stages_seen.append(stage)

        await orchestrator.start_import(
            url="https://example.com/api/bad.json",
            job_id="test-job-fail-4",
            on_progress=on_progress,
        )
        assert "failed" in stages_seen


@pytest.fixture
def _mock_public_dns():
    with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
        yield
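
# Why _mock_public_dns matters: the importer resolves the spec host before
# fetching and refuses private addresses. A sketch of the guard these tests
# exercise (an assumption about app.openapi.ssrf, not its actual code):
#
#     import ipaddress
#
#     def is_blocked(ip: str) -> bool:
#         return ipaddress.ip_address(ip).is_private  # 10.0.0.1 -> blocked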
164
backend/tests/integration/test_openapi_api.py
Normal file
@@ -0,0 +1,164 @@
"""Integration tests for /api/v1/openapi/ endpoints.

Tests the full API layer for the OpenAPI import review workflow,
including job creation, status retrieval, classification updates,
and approval triggering.
"""

from __future__ import annotations

from unittest.mock import MagicMock

import pytest
from httpx import ASGITransport, AsyncClient

pytestmark = pytest.mark.integration


def _build_app():
    """Build a minimal FastAPI app with the openapi router and mocked deps."""
    from fastapi import FastAPI, HTTPException
    from fastapi.exceptions import RequestValidationError
    from fastapi.responses import JSONResponse

    from app.api_utils import envelope
    from app.openapi.review_api import router as openapi_router

    test_app = FastAPI()
    test_app.include_router(openapi_router)

    @test_app.exception_handler(HTTPException)
    async def _http_exc(request, exc):
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    test_app.state.settings = MagicMock(admin_api_key="")

    return test_app


@pytest.fixture(autouse=True)
def _clear_job_store():
    """Clear the in-memory job store between tests."""
    from app.openapi.review_api import _job_store

    _job_store.clear()
    yield
    _job_store.clear()

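# Job records stored in _job_store follow the shape constructed throughout
# these tests (statuses observed in this suite: "pending", "done", "failed"):
#
#     {
#         "job_id": "...",
#         "status": "done",
#         "spec_url": "https://example.com/spec.json",
#         "total_endpoints": 5,
#         "classified_count": 5,
#         "error_message": None,
#         "classifications": [],
#     }
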
class TestImportEndpoint:
    """Tests for POST /api/v1/openapi/import."""

    async def test_import_returns_202_with_job_id(self) -> None:
        """Starting an import returns 202 with a job_id."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.post(
                "/api/v1/openapi/import",
                json={"url": "https://example.com/api/spec.json"},
            )

        assert resp.status_code == 202
        body = resp.json()
        assert "job_id" in body
        assert body["status"] == "pending"
        assert body["spec_url"] == "https://example.com/api/spec.json"

    async def test_import_invalid_url_returns_422(self) -> None:
        """POST with invalid URL (no http/https) returns 422."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.post(
                "/api/v1/openapi/import",
                json={"url": "ftp://example.com/spec.json"},
            )

        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False


class TestJobStatusEndpoint:
    """Tests for GET /api/v1/openapi/jobs/{job_id}."""

    async def test_get_existing_job_returns_status(self) -> None:
        """Retrieving an existing job returns its status."""
        from app.openapi.review_api import _job_store

        _job_store["test-job-1"] = {
            "job_id": "test-job-1",
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 5,
            "classified_count": 5,
            "error_message": None,
            "classifications": [],
        }

        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/openapi/jobs/test-job-1")

        assert resp.status_code == 200
        body = resp.json()
        assert body["job_id"] == "test-job-1"
        assert body["status"] == "done"
        assert body["total_endpoints"] == 5

    async def test_get_unknown_job_returns_404(self) -> None:
        """Retrieving a non-existent job returns 404 error envelope."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")

        assert resp.status_code == 404
        body = resp.json()
        assert body["success"] is False
        assert "not found" in body["error"].lower()


class TestApproveEndpoint:
    """Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""

    async def test_approve_with_no_classifications_returns_400(self) -> None:
        """Approving a job with no classifications returns 400."""
        from app.openapi.review_api import _job_store

        _job_store["empty-job"] = {
            "job_id": "empty-job",
            "status": "done",
            "spec_url": "https://example.com/spec.json",
            "total_endpoints": 0,
            "classified_count": 0,
            "error_message": None,
            "classifications": [],
        }

        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")

        assert resp.status_code == 400
        body = resp.json()
        assert body["success"] is False
        assert "no classifications" in body["error"].lower()
512
backend/tests/integration/test_phase2_checkpoints.py
Normal file
@@ -0,0 +1,512 @@
"""Phase 2 checkpoint acceptance tests.

Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md:
1. "查询订单 1042" (look up order 1042) -> routes to the order_lookup agent
2. "取消订单 1042 并给我一个 10% 折扣" (cancel order 1042 and give me a 10%
   discount) -> sequential multi-agent execution
3. Ambiguous message -> fallback asks for clarification
4. Interrupt > 30 min TTL -> auto-cancel + retry prompt
5. Agent escalation -> webhook POST succeeds (or logs after retries)
6. E-commerce template -> 4 pre-configured agents work
7. pytest --cov >= 80% (verified separately)
"""

from __future__ import annotations

import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message

TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class AsyncIterHelper:
    def __init__(self, items: list) -> None:
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)


class FakeWS:
    def __init__(self) -> None:
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)


def _chunk(content: str, node: str) -> tuple:
    c = MagicMock()
    c.content = content
    c.tool_calls = []
    return (c, {"langgraph_node": node})


def _tool_chunk(name: str, args: dict, node: str) -> tuple:
    c = MagicMock()
    c.content = ""
    c.tool_calls = [{"name": name, "args": args}]
    return (c, {"langgraph_node": node})


def _state(*, interrupt: bool = False, data: dict | None = None):
    s = MagicMock()
    if interrupt:
        obj = MagicMock()
        obj.value = data or {"action": "cancel_order", "order_id": "1042"}
        t = MagicMock()
        t.interrupts = (obj,)
        s.tasks = (t,)
    else:
        s.tasks = ()
    return s


def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig:
    return AgentConfig(name=name, description=desc, permission=perm, tools=["fallback_respond"])

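# The helpers above let each checkpoint assemble a fake streaming graph in a
# few lines; this is the pattern every test below follows:
#
#     graph = MagicMock()
#     graph.astream = MagicMock(return_value=AsyncIterHelper([
#         _chunk("Order 1042 is shipped.", "order_lookup"),
#     ]))
#     graph.aget_state = AsyncMock(return_value=_state())
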
# ---------------------------------------------------------------------------
# Checkpoint 1: "查询订单 1042" -> routes to the order_lookup agent
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint1OrderQueryRouting:
    """Verify intent classifier routes order queries to order_lookup."""

    @pytest.mark.asyncio
    async def test_order_query_classified_to_order_lookup(self) -> None:
        expected = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"),
            ),
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (
            _agent("order_lookup", "Looks up order status and tracking"),
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        result = await classifier.classify("查询订单 1042", agents)
        assert len(result.intents) == 1
        assert result.intents[0].agent_name == "order_lookup"
        assert result.intents[0].confidence >= 0.9
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_order_query_streams_from_order_lookup_agent(self) -> None:
        """Full dispatch: classify -> route -> stream from order_lookup."""
        graph = MagicMock()
        # Classifier returns order_lookup
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())

        # Graph streams the order_lookup response
        graph.astream = MagicMock(return_value=AsyncIterHelper([
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ]))
        graph.aget_state = AsyncMock(return_value=_state())

        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
        await dispatch_message(ws, ws_ctx, raw)

        tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
        assert any(m["tool"] == "get_order_status" for m in tool_msgs)

        token_msgs = [m for m in ws.sent if m["type"] == "token"]
        assert any(m["agent"] == "order_lookup" for m in token_msgs)


# ---------------------------------------------------------------------------
# Checkpoint 2: Multi-intent -> sequential execution
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint2MultiIntentSequential:
    """Verify multi-intent messages are classified and a hint is injected for sequential execution."""

    @pytest.mark.asyncio
    async def test_multi_intent_classification(self) -> None:
        expected = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (
            _agent("order_actions", "Modifies orders", "write"),
            _agent("discount", "Applies discounts", "write"),
            _agent("fallback", "Handles unclear requests"),
        )

        result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents)
        assert len(result.intents) == 2
        assert result.intents[0].agent_name == "order_actions"
        assert result.intents[1].agent_name == "discount"
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_multi_intent_injects_routing_hint(self) -> None:
        """When multi-intent is detected, a [System: ...] hint is appended to the message."""
        graph = MagicMock()
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())

        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({
            "type": "message",
            "thread_id": "t1",
            "content": "取消订单 1042 并给我一个 10% 折扣",
        })
        await dispatch_message(ws, ws_ctx, raw)

        # Verify the graph was called with the routing hint in the message
        call_args = graph.astream.call_args
        input_msg = call_args[0][0]
        msg_content = input_msg["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content


# ---------------------------------------------------------------------------
# Checkpoint 3: Ambiguous message -> clarification
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint3AmbiguousClarification:
    """Verify ambiguous messages trigger a clarification prompt."""

    @pytest.mark.asyncio
    async def test_ambiguous_intent_returns_clarification(self) -> None:
        expected = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
            is_ambiguous=False,  # low confidence will trigger the ambiguity threshold
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback"))

        result = await classifier.classify("嗯...", agents)
        assert result.is_ambiguous
        assert result.clarification_question is not None

    @pytest.mark.asyncio
    async def test_ambiguous_sends_clarification_via_websocket(self) -> None:
        graph = MagicMock()
        mock_classifier = AsyncMock()
        mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question=(
                "Could you please provide more details about what you need help with?"
            ),
        ))
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=_state())

        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
        )
        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
        await dispatch_message(ws, ws_ctx, raw)

        clarifications = [m for m in ws.sent if m["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "more details" in clarifications[0]["message"]

        # Should NOT call graph.astream since we returned early
        graph.astream.assert_not_called()


# ---------------------------------------------------------------------------
# Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint4InterruptTTLAutoCancel:
    """Verify interrupt TTL expiration triggers auto-cancel and a retry prompt."""

    @pytest.mark.asyncio
    async def test_30min_expired_interrupt_auto_cancels(self) -> None:
        st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        graph = MagicMock()
        graph.astream = MagicMock(return_value=AsyncIterHelper([]))
        graph.aget_state = AsyncMock(return_value=st)

        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)

        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager(ttl_seconds=1800)  # 30 minutes
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        # Trigger interrupt
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
        await dispatch_message(ws, ws_ctx, raw)

        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1

        # Simulate 31 minutes passing
        record = im._interrupts["t1"]
        ws.sent.clear()

        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 1860  # 31 min

            raw = json.dumps({
                "type": "interrupt_response",
                "thread_id": "t1",
                "approved": True,
            })
            await dispatch_message(ws, ws_ctx, raw)

        # Should get a retry prompt, NOT resume the graph
        expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
        assert len(expired_msgs) == 1
        assert "30 minutes" in expired_msgs[0]["message"]
        assert expired_msgs[0]["action"] == "cancel_order"
        assert expired_msgs[0]["thread_id"] == "t1"

    def test_cleanup_expired_returns_records(self) -> None:
        im = InterruptManager(ttl_seconds=1800)
        im.register("t1", "cancel_order", {"order_id": "1042"})
        im.register("t2", "apply_discount", {"order_id": "1043"})

        with patch("app.interrupt_manager.time") as mock_time:
            record = im._interrupts["t1"]
            mock_time.time.return_value = record.created_at + 1860
            expired = im.cleanup_expired()

        assert len(expired) == 2
        actions = {r.action for r in expired}
        assert "cancel_order" in actions
        assert "apply_discount" in actions

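# The expiry rule exercised above, as a sketch (the actual check lives in
# app.interrupt_manager): with ttl_seconds=1800, advancing the clock 1860s
# past record.created_at makes the interrupt count as expired.
#
#     expired = time.time() - record.created_at > ttl_seconds  # 1860 > 1800
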
# ---------------------------------------------------------------------------
# Checkpoint 5: Agent escalation -> webhook POST
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint5WebhookEscalation:
    """Verify webhook escalation sends POST and retries on failure."""

    @pytest.mark.asyncio
    async def test_webhook_post_success(self) -> None:
        mock_response = AsyncMock()
        mock_response.status_code = 200

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
            escalator = WebhookEscalator(url="https://support.example.com/escalate")
            payload = EscalationPayload(
                thread_id="t1",
                reason="Agent cannot resolve customer issue",
                conversation_summary="Customer asked about refund policy",
                metadata={"customer_id": "C-123"},
            )
            result = await escalator.escalate(payload)

        assert result.success
        assert result.status_code == 200
        assert result.attempts == 1

        # Verify POST was called with the correct payload
        call_args = mock_client.post.call_args
        assert call_args[0][0] == "https://support.example.com/escalate"
        posted_data = call_args[1]["json"]
        assert posted_data["thread_id"] == "t1"
        assert posted_data["reason"] == "Agent cannot resolve customer issue"

    @pytest.mark.asyncio
    async def test_webhook_retries_then_logs(self) -> None:
        fail_response = AsyncMock()
        fail_response.status_code = 503

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=fail_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(
                url="https://support.example.com/escalate",
                max_retries=3,
            )
            payload = EscalationPayload(
                thread_id="t1",
                reason="Escalation needed",
                conversation_summary="Summary",
            )
            result = await escalator.escalate(payload)

        assert not result.success
        assert result.attempts == 3
        assert "503" in result.error

    @pytest.mark.asyncio
    async def test_noop_escalator_when_disabled(self) -> None:
        escalator = NoOpEscalator()
        payload = EscalationPayload(
            thread_id="t1",
            reason="Test",
            conversation_summary="Test",
        )
        result = await escalator.escalate(payload)
        assert not result.success
        assert "disabled" in result.error.lower()

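# The webhook body asserted above serializes the EscalationPayload fields;
# based on those assertions it looks like this (metadata is optional):
#
#     {
#         "thread_id": "t1",
#         "reason": "Agent cannot resolve customer issue",
#         "conversation_summary": "Customer asked about refund policy",
#         "metadata": {"customer_id": "C-123"},
#     }
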
# ---------------------------------------------------------------------------
# Checkpoint 6: E-commerce template -> pre-configured agents
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestCheckpoint6EcommerceTemplate:
    """Verify e-commerce template loads with correct agents."""

    def test_ecommerce_template_loads_4_agents(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        assert len(registry) == 4

    def test_ecommerce_template_has_correct_agents(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        agents = registry.list_agents()
        names = {a.name for a in agents}
        assert names == {"order_lookup", "order_actions", "discount", "fallback"}

    def test_ecommerce_order_lookup_is_read(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        agent = registry.get_agent("order_lookup")
        assert agent.permission == "read"
        assert "get_order_status" in agent.tools
        assert "get_tracking_info" in agent.tools

    def test_ecommerce_order_actions_is_write(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        agent = registry.get_agent("order_actions")
        assert agent.permission == "write"
        assert "cancel_order" in agent.tools

    def test_ecommerce_discount_is_write(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        agent = registry.get_agent("discount")
        assert agent.permission == "write"
        assert "apply_discount" in agent.tools
        assert "generate_coupon" in agent.tools

    def test_ecommerce_fallback_is_read(self) -> None:
        registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
        agent = registry.get_agent("fallback")
        assert agent.permission == "read"

    def test_all_three_templates_available(self) -> None:
        templates = AgentRegistry.list_templates(TEMPLATES_DIR)
        assert "e-commerce" in templates
        assert "saas" in templates
        assert "fintech" in templates
213
backend/tests/integration/test_replay_api.py
Normal file
@@ -0,0 +1,213 @@
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.

Tests the full API layer with a mocked database pool, verifying routing,
serialization, pagination, and error handling in envelope format.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest
from httpx import ASGITransport, AsyncClient

pytestmark = pytest.mark.integration


def _make_fake_cursor(rows, *, fetchone_value=None):
    """Build a fake async cursor returning the given rows on fetchall."""
    cursor = AsyncMock()
    cursor.fetchall = AsyncMock(return_value=rows)
    if fetchone_value is not None:
        cursor.fetchone = AsyncMock(return_value=fetchone_value)
    return cursor


class _FakeConnection:
    """Fake async connection that returns pre-configured cursors in order."""

    def __init__(self, cursors: list) -> None:
        self._cursors = list(cursors)
        self._idx = 0

    async def execute(self, sql, params=None):
        cursor = self._cursors[self._idx]
        self._idx += 1
        return cursor

    async def __aenter__(self):
        return self

    async def __aexit__(self, *args):
        pass


class _FakePool:
    """Fake connection pool that yields a fake connection."""

    def __init__(self, conn: _FakeConnection) -> None:
        self._conn = conn

    def connection(self):
        return self._conn

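# These fakes mirror the slice of the connection-pool surface the replay API
# is assumed to use, i.e. the calls the tests below drive:
#
#     async with pool.connection() as conn:
#         cursor = await conn.execute(sql, params)
#         rows = await cursor.fetchall()
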
def _build_app(pool=None):
|
||||
"""Build a minimal FastAPI app with the replay router and mocked deps."""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api_utils import envelope
|
||||
from app.replay.api import router as replay_router
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(replay_router)
|
||||
|
||||
@test_app.exception_handler(HTTPException)
|
||||
async def _http_exc(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=envelope(None, success=False, error=exc.detail),
|
||||
)
|
||||
|
||||
    @test_app.exception_handler(RequestValidationError)
    async def _validation_exc(request, exc):
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    test_app.state.settings = MagicMock(admin_api_key="")
    test_app.state.pool = pool or MagicMock()

    return test_app


class TestListConversations:
    """Tests for GET /api/v1/conversations endpoint."""

    async def test_returns_paginated_envelope(self) -> None:
        """Conversations list returns envelope with pagination metadata."""
        count_cursor = _make_fake_cursor([], fetchone_value=(3,))
        rows = [
            {"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
             "status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
            {"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
             "status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
        ]
        list_cursor = _make_fake_cursor(rows)
        conn = _FakeConnection([count_cursor, list_cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations")

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["total"] == 3
        assert len(body["data"]["conversations"]) == 2
        assert body["data"]["page"] == 1
        assert body["data"]["per_page"] == 20

    async def test_custom_page_and_per_page(self) -> None:
        """Custom page/per_page params are reflected in the response."""
        count_cursor = _make_fake_cursor([], fetchone_value=(50,))
        list_cursor = _make_fake_cursor([])
        conn = _FakeConnection([count_cursor, list_cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})

        assert resp.status_code == 200
        body = resp.json()
        assert body["data"]["page"] == 3
        assert body["data"]["per_page"] == 10

    async def test_invalid_page_returns_422(self) -> None:
        """page=0 violates the ge=1 constraint and returns a 422 error envelope."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/conversations", params={"page": 0})

        assert resp.status_code == 422
        body = resp.json()
        assert body["success"] is False


class TestReplayEndpoint:
    """Tests for GET /api/v1/replay/{thread_id} endpoint."""

    async def test_valid_thread_returns_timeline(self) -> None:
        """Replay with valid thread_id returns steps in envelope format."""
        checkpoint_rows = [
            {
                "thread_id": "abc123",
                "checkpoint_id": "cp1",
                "checkpoint": {
                    "channel_values": {
                        "messages": [
                            {"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
                            {"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
                        ]
                    }
                },
                "metadata": {},
            }
        ]
        cursor = _make_fake_cursor(checkpoint_rows)
        conn = _FakeConnection([cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/abc123")

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["thread_id"] == "abc123"
        assert body["data"]["total_steps"] == 2
        assert len(body["data"]["steps"]) == 2
        assert body["data"]["steps"][0]["type"] == "user_message"
        assert body["data"]["steps"][1]["type"] == "agent_response"

    async def test_invalid_thread_id_format_returns_400(self) -> None:
        """Thread IDs with path traversal characters are rejected with 400."""
        test_app = _build_app()
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/../../etc/passwd")

        # FastAPI may return 400 from our handler, 404 from routing, or 422 from validation
        assert resp.status_code in (400, 404, 422)

    async def test_nonexistent_thread_returns_404(self) -> None:
        """Replay with a thread_id that has no checkpoints returns 404."""
        cursor = _make_fake_cursor([])
        conn = _FakeConnection([cursor])
        pool = _FakePool(conn)
        test_app = _build_app(pool)

        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            resp = await client.get("/api/v1/replay/nonexistent-thread")

        assert resp.status_code == 404
        body = resp.json()
        assert body["success"] is False
        assert "not found" in body["error"].lower()
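For readers tracing these assertions: the envelope(...) helper and the success/data/error keys asserted throughout suggest a response wrapper along the following lines. This is a hedged sketch of the shape the tests pin down, not the repository's actual implementation.

# Hypothetical sketch of the response envelope these tests assert on;
# the real helper in the app package may differ.
from typing import Any


def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    # Every API response carries the same three keys, so clients can
    # branch on body["success"] without inspecting status codes.
    return {"success": success, "data": data, "error": error}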
371
backend/tests/integration/test_routing.py
Normal file
@@ -0,0 +1,371 @@
"""Integration tests for multi-agent routing flow.

Tests the full pipeline: intent classification -> supervisor routing ->
agent execution -> response streaming, exercising cross-module integration
with a mocked LLM.

Required by Phase 2 test plan:
- Unit: intent classification accuracy
- Unit: multi-intent sequential execution
- Integration: complete multi-agent routing flow
"""

from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class AsyncIterHelper:
    def __init__(self, items: list) -> None:
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)


class FakeWS:
    def __init__(self) -> None:
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)


def _chunk(content: str, node: str) -> tuple:
    c = MagicMock()
    c.content = content
    c.tool_calls = []
    return (c, {"langgraph_node": node})


def _tool_chunk(name: str, args: dict, node: str) -> tuple:
    c = MagicMock()
    c.content = ""
    c.tool_calls = [{"name": name, "args": args}]
    return (c, {"langgraph_node": node})


def _state(*, interrupt: bool = False, data: dict | None = None):
    s = MagicMock()
    if interrupt:
        obj = MagicMock()
        obj.value = data or {}
        t = MagicMock()
        t.interrupts = (obj,)
        s.tasks = (t,)
    else:
        s.tasks = ()
    return s


AGENTS = (
    AgentConfig(
        name="order_lookup", description="Looks up orders",
        permission="read", tools=["get_order_status", "get_tracking_info"],
    ),
    AgentConfig(
        name="order_actions", description="Modifies orders",
        permission="write", tools=["cancel_order"],
    ),
    AgentConfig(
        name="discount", description="Applies discounts",
        permission="write", tools=["apply_discount", "generate_coupon"],
    ),
    AgentConfig(
        name="fallback", description="Handles unclear requests",
        permission="read", tools=["fallback_respond"],
    ),
)


def _make_classifier(result: ClassificationResult) -> AsyncMock:
    """Create a mock classifier returning the given result."""
    classifier = AsyncMock()
    classifier.classify = AsyncMock(return_value=result)
    return classifier


def _make_graph_and_ctx(
    classifier_result: ClassificationResult | None,
    chunks: list,
    state=None,
) -> tuple[MagicMock, GraphContext]:
    """Build a graph mock and GraphContext with an optional intent classifier."""
    graph = MagicMock()
    graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
    graph.aget_state = AsyncMock(return_value=state or _state())

    if classifier_result is not None:
        classifier = _make_classifier(classifier_result)
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=AGENTS)
        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=classifier,
        )
    else:
        mock_registry = MagicMock()
        mock_registry.list_agents = MagicMock(return_value=())
        graph_ctx = GraphContext(
            graph=graph, registry=mock_registry, intent_classifier=None,
        )

    return graph, graph_ctx


async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
    sm = SessionManager()
    sm.touch(thread_id)
    im = InterruptManager()
    cb = TokenUsageCallbackHandler()
    ws = FakeWS()
    ws_ctx = WebSocketContext(
        graph_ctx=graph_ctx, session_manager=sm,
        callback_handler=cb, interrupt_manager=im,
    )
    raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
    await dispatch_message(ws, ws_ctx, raw)
    return ws.sent


# ---------------------------------------------------------------------------
# Single-intent routing to each agent
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestSingleIntentRouting:
    """Verify single-intent messages route to the correct agent."""

    @pytest.mark.asyncio
    async def test_routes_to_order_lookup(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(
                agent_name="order_lookup", confidence=0.95, reasoning="status query",
            ),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
            _chunk("Order 1042 is shipped.", "order_lookup"),
        ])

        msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert len(tools) == 1
        assert tools[0]["tool"] == "get_order_status"
        assert tools[0]["agent"] == "order_lookup"

        tokens = [m for m in msgs if m["type"] == "token"]
        assert any("shipped" in t["content"] for t in tokens)

    @pytest.mark.asyncio
    async def test_routes_to_order_actions(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(
            result,
            [_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
            state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
        )

        msgs = await _dispatch(graph_ctx, "Cancel order 1042")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "cancel_order"
        assert tools[0]["agent"] == "order_actions"

        interrupts = [m for m in msgs if m["type"] == "interrupt"]
        assert len(interrupts) == 1

    @pytest.mark.asyncio
    async def test_routes_to_discount(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
            _chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
        ])

        msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")

        tools = [m for m in msgs if m["type"] == "tool_call"]
        assert tools[0]["tool"] == "generate_coupon"
        assert tools[0]["agent"] == "discount"

    @pytest.mark.asyncio
    async def test_routes_to_fallback(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _chunk("I can help with order inquiries.", "fallback"),
        ])

        msgs = await _dispatch(graph_ctx, "What can you do?")

        tokens = [m for m in msgs if m["type"] == "token"]
        assert tokens[0]["agent"] == "fallback"


# ---------------------------------------------------------------------------
# Multi-intent routing
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestMultiIntentRouting:
    """Verify multi-intent messages trigger the sequential execution hint."""

    @pytest.mark.asyncio
    async def test_two_intents_inject_routing_hint(self) -> None:
        result = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [
            _tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
            _tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
        ])

        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({
            "type": "message",
            "thread_id": "t1",
            "content": "取消订单 1042 并给我一个 10% 折扣",
        })
        await dispatch_message(ws, ws_ctx, raw)

        # Verify the routing hint was injected
        call_args = graph.astream.call_args[0][0]
        msg_content = call_args["messages"][0].content
        assert "[System:" in msg_content
        assert "order_actions" in msg_content
        assert "discount" in msg_content

        # Both tool calls should appear
        tools = [m for m in ws.sent if m["type"] == "tool_call"]
        tool_names = {t["tool"] for t in tools}
        assert "cancel_order" in tool_names
        assert "apply_discount" in tool_names

    @pytest.mark.asyncio
    async def test_single_intent_no_routing_hint(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])

        sm = SessionManager()
        sm.touch("t1")
        im = InterruptManager()
        cb = TokenUsageCallbackHandler()
        ws = FakeWS()
        ws_ctx = WebSocketContext(
            graph_ctx=graph_ctx, session_manager=sm,
            callback_handler=cb, interrupt_manager=im,
        )

        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
        await dispatch_message(ws, ws_ctx, raw)

        msg_content = graph.astream.call_args[0][0]["messages"][0].content
        assert "[System:" not in msg_content


# ---------------------------------------------------------------------------
# Ambiguity routing
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestAmbiguityRouting:
    """Verify ambiguous intents produce clarification, not agent calls."""

    @pytest.mark.asyncio
    async def test_ambiguous_skips_graph_returns_clarification(self) -> None:
        result = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you please clarify what you need?",
        )
        graph, graph_ctx = _make_graph_and_ctx(result, [])

        msgs = await _dispatch(graph_ctx, "嗯...")

        clarifications = [m for m in msgs if m["type"] == "clarification"]
        assert len(clarifications) == 1
        assert "clarify" in clarifications[0]["message"]

        # The graph should NOT have been called
        graph.astream.assert_not_called()

    @pytest.mark.asyncio
    async def test_low_confidence_triggers_ambiguity(self) -> None:
        """LLMIntentClassifier applies its threshold -- low confidence means ambiguous."""
        raw_result = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),),
            is_ambiguous=False,
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=raw_result)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        result = await classifier.classify("hmm", AGENTS)

        assert result.is_ambiguous
        assert result.clarification_question is not None


# ---------------------------------------------------------------------------
# No classifier fallback
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestNoClassifierFallback:
    """Verify the system works without an intent classifier (falls back to the supervisor prompt)."""

    @pytest.mark.asyncio
    async def test_no_classifier_routes_via_supervisor(self) -> None:
        graph, graph_ctx = _make_graph_and_ctx(
            classifier_result=None,
            chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
        )

        msgs = await _dispatch(graph_ctx, "What is order 1042 status?")

        tokens = [m for m in msgs if m["type"] == "token"]
        assert len(tokens) == 1
        completes = [m for m in msgs if m["type"] == "message_complete"]
        assert len(completes) == 1
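The "[System:" assertions above pin down a routing-hint contract without showing the producer. A minimal sketch of what such an injection step could look like, assuming a hint is only prepended when the classifier returns two or more intents; the function name and hint wording here are illustrative, not the actual app.ws_handler code.

# Illustrative only -- the real hint text and injection site live in
# app.ws_handler and may differ.
def inject_routing_hint(content: str, agent_names: list[str]) -> str:
    # Single-intent messages pass through untouched (second test above).
    if len(agent_names) < 2:
        return content
    hint = f"[System: handle these intents in order: {', '.join(agent_names)}]"
    return f"{hint}\n{content}"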
159
backend/tests/integration/test_session_interrupt_lifecycle.py
Normal file
@@ -0,0 +1,159 @@
"""Integration tests for SessionManager + InterruptManager lifecycle.

These tests exercise the in-memory managers together, verifying the full
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
registration/resolution, and expired-interrupt cleanup.

No database required -- both managers are in-memory.
"""

from __future__ import annotations

import time
from unittest.mock import patch

import pytest

from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager

pytestmark = pytest.mark.integration


class TestSessionInterruptLifecycle:
    """Tests for the combined session + interrupt lifecycle."""

    def test_create_session_register_interrupt_check_status(self) -> None:
        """Full lifecycle: create session, register interrupt, verify both states."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)

        # Create a session
        state = sm.touch("thread-1")
        assert state.thread_id == "thread-1"
        assert not state.has_pending_interrupt
        assert not sm.is_expired("thread-1")

        # Register an interrupt
        im.register("thread-1", "cancel_order", {"order_id": "1042"})
        sm.extend_for_interrupt("thread-1")

        assert im.has_pending("thread-1")
        session_state = sm.get_state("thread-1")
        assert session_state is not None
        assert session_state.has_pending_interrupt

        # Session should not expire while the interrupt is pending
        assert not sm.is_expired("thread-1")

    def test_interrupt_expiry_after_ttl(self) -> None:
        """Interrupt expires when the TTL elapses, even if the session is alive."""
        im = InterruptManager(ttl_seconds=5)

        record = im.register("thread-2", "refund", {"amount": 50})
        assert im.has_pending("thread-2")

        # Simulate time passing beyond the TTL
        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 10
            assert not im.has_pending("thread-2")

            status = im.check_status("thread-2")
            assert status is not None
            assert status.is_expired
            assert status.remaining_seconds == 0.0

    def test_interrupt_resolve_flow(self) -> None:
        """Resolving an interrupt removes it from pending and resets the session."""
        sm = SessionManager(session_ttl_seconds=3600)
        im = InterruptManager(ttl_seconds=300)

        sm.touch("thread-3")
        im.register("thread-3", "delete_account", {"user_id": "u1"})
        sm.extend_for_interrupt("thread-3")

        # Verify pending state
        assert im.has_pending("thread-3")
        assert sm.get_state("thread-3").has_pending_interrupt

        # Resolve
        im.resolve("thread-3")
        sm.resolve_interrupt("thread-3")

        assert not im.has_pending("thread-3")
        session_state = sm.get_state("thread-3")
        assert session_state is not None
        assert not session_state.has_pending_interrupt

    def test_cleanup_expired_removes_old_interrupts(self) -> None:
        """cleanup_expired removes only expired interrupts, keeping active ones."""
        im = InterruptManager(ttl_seconds=10)

        # Register two interrupts at different times
        old_record = im.register("thread-old", "action_old", {})
        im.register("thread-new", "action_new", {})

        # Simulate a time where only the old one has expired
        with patch("app.interrupt_manager.time") as mock_time:
            # Move the old record's creation into the past
            im._interrupts["thread-old"] = old_record.__class__(
                interrupt_id=old_record.interrupt_id,
                thread_id=old_record.thread_id,
                action=old_record.action,
                params=old_record.params,
                created_at=time.time() - 20,
                ttl_seconds=old_record.ttl_seconds,
            )
            mock_time.time.return_value = time.time()

            expired = im.cleanup_expired()
            assert len(expired) == 1
            assert expired[0].thread_id == "thread-old"

        # The new one should still be pending
        assert im.has_pending("thread-new")
        assert not im.has_pending("thread-old")

    def test_session_ttl_sliding_window(self) -> None:
        """Touching a session resets the sliding window TTL."""
        sm = SessionManager(session_ttl_seconds=3600)

        state1 = sm.touch("thread-5")
        first_activity = state1.last_activity

        time.sleep(0.01)
        state2 = sm.touch("thread-5")
        second_activity = state2.last_activity

        assert second_activity > first_activity
        assert not sm.is_expired("thread-5")

    def test_session_expires_after_ttl_without_activity(self) -> None:
        """Session expires when the TTL passes without a touch or interrupt."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-6")

        # TTL is 0, so the session is immediately expired
        assert sm.is_expired("thread-6")

    def test_pending_interrupt_prevents_session_expiry(self) -> None:
        """A session with a pending interrupt does not expire even with TTL=0."""
        sm = SessionManager(session_ttl_seconds=0)
        sm.touch("thread-7")
        sm.extend_for_interrupt("thread-7")

        # Even with TTL=0, the session should not expire because of the pending interrupt
        assert not sm.is_expired("thread-7")

    def test_retry_prompt_for_expired_interrupt(self) -> None:
        """InterruptManager generates a retry prompt for expired interrupts."""
        im = InterruptManager(ttl_seconds=300)
        record = im.register("thread-8", "cancel_order", {"order_id": "1042"})

        prompt = im.generate_retry_prompt(record)

        assert prompt["type"] == "interrupt_expired"
        assert prompt["thread_id"] == "thread-8"
        assert "cancel_order" in prompt["action"]
        assert "cancel_order" in prompt["message"]
        assert "expired" in prompt["message"].lower()
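Taken together, these tests fix a sliding-window rule: activity resets the clock, and a pending interrupt pins the session alive regardless of TTL. A compact sketch of that rule, with names mirroring the tests; the real SessionManager may implement it differently.

# Sketch of the expiry rule the tests above enforce; an assumption, not
# the repository's SessionManager.
import time
from dataclasses import dataclass, field


@dataclass
class _SessionState:
    thread_id: str
    last_activity: float = field(default_factory=time.time)
    has_pending_interrupt: bool = False


def _is_expired(state: _SessionState, ttl_seconds: float) -> bool:
    if state.has_pending_interrupt:
        return False  # a pending interrupt pins the session alive (TTL=0 test)
    # ">=" so TTL=0 expires immediately, matching the tests above
    return time.time() - state.last_activity >= ttl_seconds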
360
backend/tests/integration/test_websocket.py
Normal file
@@ -0,0 +1,360 @@
"""Integration tests for WebSocket message flow.

These tests exercise dispatch_message end-to-end with a mocked LangGraph
graph, verifying streaming, interrupt approval/rejection, session TTL,
and interrupt TTL expiration through the full message handling pipeline.
"""

from __future__ import annotations

import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class AsyncIterHelper:
    """Make a list behave as an async iterator."""

    def __init__(self, items: list) -> None:
        self._items = list(items)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)


class FakeWS:
    """Fake WebSocket that records sent messages."""

    def __init__(self) -> None:
        self.sent: list[dict] = []

    async def send_json(self, data: dict) -> None:
        self.sent.append(data)


def _chunk(content: str, node: str = "order_lookup") -> tuple:
    c = MagicMock()
    c.content = content
    c.tool_calls = []
    return (c, {"langgraph_node": node})


def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
    c = MagicMock()
    c.content = ""
    c.tool_calls = [{"name": name, "args": args}]
    return (c, {"langgraph_node": node})


def _state(*, interrupt: bool = False, data: dict | None = None) -> Any:
    s = MagicMock()
    if interrupt:
        obj = MagicMock()
        obj.value = data or {"action": "cancel_order", "order_id": "1042"}
        t = MagicMock()
        t.interrupts = (obj,)
        s.tasks = (t,)
    else:
        s.tasks = ()
    return s


def _graph(
    chunks: list | None = None,
    st: Any = None,
    resume_chunks: list | None = None,
) -> MagicMock:
    g = MagicMock()

    if st is None:
        st = _state()

    streams = [chunks or [], resume_chunks or []]
    idx = {"n": 0}

    def make_stream(*a, **kw):
        i = min(idx["n"], len(streams) - 1)
        idx["n"] += 1
        return AsyncIterHelper(list(streams[i]))

    g.astream = MagicMock(side_effect=make_stream)
    g.aget_state = AsyncMock(return_value=st)
    return g


def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
    g = graph or _graph()
    registry = MagicMock()
    registry.list_agents = MagicMock(return_value=())
    return GraphContext(graph=g, registry=registry, intent_classifier=None)


def _setup(
    graph=None,
    session_ttl: int = 1800,
    interrupt_ttl: int = 1800,
    thread_id: str = "t1",
    touch: bool = True,
):
    """Create test dependencies. Pre-touches session by default."""
    g = graph or _graph()
    graph_ctx = _make_graph_ctx(g)
    sm = SessionManager(session_ttl_seconds=session_ttl)
    im = InterruptManager(ttl_seconds=interrupt_ttl)
    cb = TokenUsageCallbackHandler()
    ws = FakeWS()
    ws_ctx = WebSocketContext(
        graph_ctx=graph_ctx, session_manager=sm,
        callback_handler=cb, interrupt_manager=im,
    )
    if touch:
        sm.touch(thread_id)
    return g, sm, im, cb, ws, ws_ctx


async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
    raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
    await dispatch_message(ws, ws_ctx, raw)


async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
    raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
    await dispatch_message(ws, ws_ctx, raw)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.integration
class TestWebSocketHappyPath:
    @pytest.mark.asyncio
    async def test_send_message_receives_tokens_and_complete(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup(
            graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
        )
        await _send(ws, ws_ctx, content="What is the status of order 1042?")

        tokens = [m for m in ws.sent if m["type"] == "token"]
        assert len(tokens) == 2
        assert tokens[0]["content"] == "Order 1042 is "
        assert tokens[0]["agent"] == "order_lookup"
        assert tokens[1]["content"] == "shipped."

        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1
        assert completes[0]["thread_id"] == "t1"

    @pytest.mark.asyncio
    async def test_tool_call_streamed(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup(
            graph=_graph(chunks=[
                _tool_chunk("get_order_status", {"order_id": "1042"}),
                _chunk("Order shipped."),
            ])
        )
        await _send(ws, ws_ctx, content="Check order 1042")

        tools = [m for m in ws.sent if m["type"] == "tool_call"]
        assert len(tools) == 1
        assert tools[0]["tool"] == "get_order_status"
        assert tools[0]["args"] == {"order_id": "1042"}

    @pytest.mark.asyncio
    async def test_multiple_messages_same_session(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        for i in range(3):
            await _send(ws, ws_ctx, content=f"msg {i}")

        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 3


@pytest.mark.integration
class TestWebSocketInterruptApproval:
    @pytest.mark.asyncio
    async def test_interrupt_then_approve(self) -> None:
        st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        resume = [_chunk("Order 1042 cancelled.", "order_actions")]
        g = _graph(chunks=[], st=st_int, resume_chunks=resume)
        g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)

        # Send message -> triggers interrupt
        await _send(ws, ws_ctx, content="Cancel order 1042")

        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1
        assert interrupts[0]["action"] == "cancel_order"
        assert interrupts[0]["thread_id"] == "t1"
        assert im.has_pending("t1")

        # Approve
        ws.sent.clear()
        await _respond(ws, ws_ctx, approved=True)

        tokens = [m for m in ws.sent if m["type"] == "token"]
        assert len(tokens) == 1
        assert "cancelled" in tokens[0]["content"]

        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1
        assert not im.has_pending("t1")

    @pytest.mark.asyncio
    async def test_interrupt_then_reject(self) -> None:
        st_int = _state(interrupt=True)
        resume = [_chunk("Order remains active.", "order_actions")]
        g = _graph(chunks=[], st=st_int, resume_chunks=resume)
        g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)

        await _send(ws, ws_ctx, content="Cancel order 1042")
        ws.sent.clear()

        await _respond(ws, ws_ctx, approved=False)

        tokens = [m for m in ws.sent if m["type"] == "token"]
        assert "remains active" in tokens[0]["content"]


@pytest.mark.integration
class TestWebSocketSessionTTL:
    @pytest.mark.asyncio
    async def test_expired_session_returns_error(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
        # Session was touched in _setup, but TTL is 0 so it's already expired
        await _send(ws, ws_ctx, content="hello")
        assert ws.sent[0]["type"] == "error"
        assert "expired" in ws.sent[0]["message"].lower()

    @pytest.mark.asyncio
    async def test_new_session_not_expired(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
        await _send(ws, ws_ctx, content="hello")
        completes = [m for m in ws.sent if m["type"] == "message_complete"]
        assert len(completes) == 1

    @pytest.mark.asyncio
    async def test_sliding_window_resets_on_message(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)

        await _send(ws, ws_ctx, content="hello")
        first_activity = sm.get_state("t1").last_activity

        time.sleep(0.01)
        await _send(ws, ws_ctx, content="hello again")
        second_activity = sm.get_state("t1").last_activity

        assert second_activity > first_activity

    @pytest.mark.asyncio
    async def test_interrupt_extends_session_ttl(self) -> None:
        st_int = _state(interrupt=True)
        g = _graph(chunks=[], st=st_int)
        g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)

        await _send(ws, ws_ctx, content="cancel order")

        state = sm.get_state("t1")
        assert state is not None
        assert state.has_pending_interrupt
        assert not sm.is_expired("t1")


@pytest.mark.integration
class TestWebSocketValidation:
    @pytest.mark.asyncio
    async def test_invalid_json(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        await dispatch_message(ws, ws_ctx, "not json")
        assert ws.sent[0]["type"] == "error"
        assert "Invalid JSON" in ws.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_missing_thread_id(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        raw = json.dumps({"type": "message", "content": "hi"})
        await dispatch_message(ws, ws_ctx, raw)
        assert ws.sent[0]["type"] == "error"
        assert "thread_id" in ws.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_invalid_thread_id_format(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
        await dispatch_message(ws, ws_ctx, raw)
        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_missing_content(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        raw = json.dumps({"type": "message", "thread_id": "t1"})
        await dispatch_message(ws, ws_ctx, raw)
        assert ws.sent[0]["type"] == "error"

    @pytest.mark.asyncio
    async def test_unknown_message_type(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        raw = json.dumps({"type": "foobar", "thread_id": "t1"})
        await dispatch_message(ws, ws_ctx, raw)
        assert ws.sent[0]["type"] == "error"
        assert "Unknown" in ws.sent[0]["message"]

    @pytest.mark.asyncio
    async def test_message_too_large(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        await dispatch_message(ws, ws_ctx, "x" * 40_000)
        assert ws.sent[0]["type"] == "error"
        assert "too large" in ws.sent[0]["message"].lower()

    @pytest.mark.asyncio
    async def test_content_too_long(self) -> None:
        g, sm, im, cb, ws, ws_ctx = _setup()
        raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
        await dispatch_message(ws, ws_ctx, raw)
        assert ws.sent[0]["type"] == "error"
        assert "too long" in ws.sent[0]["message"].lower()


@pytest.mark.integration
class TestWebSocketInterruptTTL:
    @pytest.mark.asyncio
    async def test_expired_interrupt_sends_retry_prompt(self) -> None:
        st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
        g = _graph(chunks=[], st=st_int)
        g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)

        # Trigger interrupt
        await _send(ws, ws_ctx, content="Cancel order 1042")

        interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
        assert len(interrupts) == 1

        # Expire the interrupt
        record = im._interrupts["t1"]
        ws.sent.clear()

        with patch("app.interrupt_manager.time") as mock_time:
            mock_time.time.return_value = record.created_at + 10
            await _respond(ws, ws_ctx, approved=True)

        assert ws.sent[0]["type"] == "interrupt_expired"
        assert "cancel_order" in ws.sent[0]["message"]
        assert ws.sent[0]["thread_id"] == "t1"
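The two-stream _graph mock above models a resume: the first astream call serves the original message and parks at the interrupt, the second serves the post-approval continuation. In LangGraph terms the handler presumably resumes with something like the following; Command(resume=...) is the documented LangGraph idiom, but the handler's exact call is an assumption here.

# Hedged sketch of the resume step the mock's second stream stands in for.
from langgraph.types import Command


async def _resume(graph, ws, thread_id: str, approved: bool) -> None:
    config = {"configurable": {"thread_id": thread_id}}
    # Resuming re-enters the graph at the parked interrupt node; this is
    # the call the mock answers with resume_chunks.
    async for chunk, meta in graph.astream(
        Command(resume={"approved": approved}), config, stream_mode="messages"
    ):
        if chunk.content:
            await ws.send_json(
                {"type": "token", "content": chunk.content,
                 "agent": meta.get("langgraph_node")}
            )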
1
backend/tests/unit/analytics/__init__.py
Normal file
@@ -0,0 +1 @@
"""Unit tests for app.analytics module."""
149
backend/tests/unit/analytics/test_api.py
Normal file
@@ -0,0 +1,149 @@
"""Unit tests for app.analytics.api."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

pytestmark = pytest.mark.unit


def _build_app() -> FastAPI:
    from app.analytics.api import router

    app = FastAPI()
    app.include_router(router)
    return app


def _make_mock_pool() -> MagicMock:
    mock_conn = AsyncMock()
    mock_ctx = AsyncMock()
    mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_ctx.__aexit__ = AsyncMock(return_value=None)
    mock_pool = MagicMock()
    mock_pool.connection.return_value = mock_ctx
    return mock_pool


def _make_analytics_result() -> object:
    from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

    return AnalyticsResult(
        range="7d",
        total_conversations=50,
        resolution_rate=0.8,
        escalation_rate=0.1,
        avg_turns_per_conversation=3.5,
        avg_cost_per_conversation_usd=0.02,
        agent_usage=(AgentUsage(agent="order_agent", count=30, percentage=60.0),),
        interrupt_stats=InterruptStats(total=5, approved=4, rejected=1, expired=0),
    )


def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics", **patch_kwargs: object) -> object:
    """Helper: patch get_analytics, make request, return (response, mock)."""
    analytics_result = _make_analytics_result()
    with (
        patch("app.analytics.api.get_analytics", return_value=analytics_result) as mock_ga,
        TestClient(app) as client,
    ):
        resp = client.get(path)
    return resp, mock_ga


class TestAnalyticsEndpoint:
    def test_returns_200_with_default_range(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["error"] is None
        assert body["data"]["range"] == "7d"

    def test_returns_correct_analytics_structure(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert "total_conversations" in data
        assert "resolution_rate" in data
        assert "escalation_rate" in data
        assert "avg_turns_per_conversation" in data
        assert "avg_cost_per_conversation_usd" in data
        assert "agent_usage" in data
        assert "interrupt_stats" in data

    def test_custom_range_7d(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")

        assert resp.status_code == 200
        mock_ga.assert_called_once()
        call_kwargs = mock_ga.call_args
        assert call_kwargs[1]["range_days"] == 7 or call_kwargs[0][1] == 7

    def test_custom_range_30d(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")

        assert resp.status_code == 200
        call_kwargs = mock_ga.call_args
        assert call_kwargs[1].get("range_days") == 30 or (
            len(call_kwargs[0]) > 1 and call_kwargs[0][1] == 30
        )

    def test_invalid_range_format_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics?range=invalid")

        assert resp.status_code == 400

    def test_range_without_d_suffix_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()

        with TestClient(app) as client:
            resp = client.get("/api/v1/analytics?range=7")

        assert resp.status_code == 400

    def test_agent_usage_in_response(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert len(data["agent_usage"]) == 1
        assert data["agent_usage"][0]["agent"] == "order_agent"

    def test_interrupt_stats_in_response(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        data = resp.json()["data"]
        assert data["interrupt_stats"]["total"] == 5
        assert data["interrupt_stats"]["approved"] == 4

    def test_envelope_format(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool()
        resp, _ = _get_analytics(app)

        body = resp.json()
        assert "success" in body
        assert "data" in body
        assert "error" in body
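The two 400-cases pin down the range grammar: an integer plus a mandatory "d" suffix. A sketch of the validation they imply; the real handler in app.analytics.api may express it differently.

# Hypothetical range parsing consistent with the 400-tests above.
import re

from fastapi import HTTPException


def parse_range(value: str) -> int:
    m = re.fullmatch(r"(\d+)d", value)  # "7d" -> 7; "7" and "invalid" are rejected
    if m is None:
        raise HTTPException(status_code=400, detail=f"invalid range: {value!r}")
    return int(m.group(1))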
155
backend/tests/unit/analytics/test_event_recorder.py
Normal file
@@ -0,0 +1,155 @@
"""Unit tests for app.analytics.event_recorder."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

pytestmark = pytest.mark.unit


class TestAnalyticsRecorderProtocol:
    def test_postgres_recorder_implements_protocol(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        mock_pool = MagicMock()
        recorder = PostgresAnalyticsRecorder(pool=mock_pool)
        # Runtime check: has record method
        assert hasattr(recorder, "record")
        assert callable(recorder.record)

    def test_noop_recorder_implements_protocol(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        recorder = NoOpAnalyticsRecorder()
        assert hasattr(recorder, "record")
        assert callable(recorder.record)


class TestNoOpAnalyticsRecorder:
    @pytest.mark.asyncio
    async def test_record_does_nothing(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        recorder = NoOpAnalyticsRecorder()
        # Should not raise
        await recorder.record(
            thread_id="t1",
            event_type="tool_call",
            agent_name="order_agent",
            tool_name="get_order",
            tokens_used=50,
            cost_usd=0.001,
        )

    @pytest.mark.asyncio
    async def test_record_with_all_params(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        recorder = NoOpAnalyticsRecorder()
        await recorder.record(
            thread_id="t1",
            event_type="agent_response",
            agent_name="fallback",
            tool_name=None,
            tokens_used=100,
            cost_usd=0.002,
            duration_ms=150,
            success=True,
            error_message=None,
            metadata={"extra": "data"},
        )

    @pytest.mark.asyncio
    async def test_record_minimal_params(self) -> None:
        from app.analytics.event_recorder import NoOpAnalyticsRecorder

        recorder = NoOpAnalyticsRecorder()
        # Only required params
        await recorder.record(thread_id="t1", event_type="conversation_start")


class TestPostgresAnalyticsRecorder:
    @pytest.mark.asyncio
    async def test_record_executes_insert(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        recorder = PostgresAnalyticsRecorder(pool=mock_pool)
        await recorder.record(
            thread_id="t1",
            event_type="tool_call",
            agent_name="order_agent",
            tokens_used=50,
            cost_usd=0.001,
        )
        mock_conn.execute.assert_awaited_once()
        call_args = mock_conn.execute.call_args
        sql = call_args[0][0]
        assert "INSERT INTO analytics_events" in sql

    @pytest.mark.asyncio
    async def test_record_passes_correct_params(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        recorder = PostgresAnalyticsRecorder(pool=mock_pool)
        await recorder.record(
            thread_id="thread-xyz",
            event_type="agent_response",
            agent_name="discount_agent",
            tool_name="apply_discount",
            tokens_used=75,
            cost_usd=0.002,
            duration_ms=300,
            success=True,
            error_message=None,
            metadata={"promo": "10PCT"},
        )
        call_args = mock_conn.execute.call_args
        params = call_args[0][1]
        assert params["thread_id"] == "thread-xyz"
        assert params["event_type"] == "agent_response"
        assert params["agent_name"] == "discount_agent"
        assert params["tokens_used"] == 75

    @pytest.mark.asyncio
    async def test_record_stores_metadata_as_dict(self) -> None:
        from app.analytics.event_recorder import PostgresAnalyticsRecorder

        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        recorder = PostgresAnalyticsRecorder(pool=mock_pool)
        await recorder.record(
            thread_id="t1",
            event_type="tool_call",
            metadata={"key": "val"},
        )
        call_args = mock_conn.execute.call_args
        params = call_args[0][1]
        # PostgresAnalyticsRecorder wraps metadata with psycopg Json() adapter.
        # Unwrap to compare the inner dict.
        from psycopg.types.json import Json

        meta = params["metadata"]
        if isinstance(meta, Json):
            meta = meta.obj
        assert meta == {"key": "val"}
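The unwrap at the end exists because psycopg 3 does not adapt plain dicts by default; values bound for a jsonb column go through the Json() wrapper. A one-liner showing the adaptation the recorder presumably performs -- the helper name is illustrative, the recorder's actual code is not shown here.

# psycopg 3 idiom the metadata test accounts for; helper name is hypothetical.
from psycopg.types.json import Json


def _adapt_metadata(metadata: dict | None) -> Json | None:
    return Json(metadata) if metadata is not None else None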
106
backend/tests/unit/analytics/test_models.py
Normal file
@@ -0,0 +1,106 @@
"""Unit tests for app.analytics.models."""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit


class TestAgentUsage:
    def test_agent_usage_construction(self) -> None:
        from app.analytics.models import AgentUsage

        au = AgentUsage(agent="order_agent", count=10, percentage=50.0)
        assert au.agent == "order_agent"
        assert au.count == 10
        assert au.percentage == 50.0

    def test_agent_usage_is_frozen(self) -> None:
        from app.analytics.models import AgentUsage

        au = AgentUsage(agent="a", count=1, percentage=100.0)
        with pytest.raises((AttributeError, TypeError)):
            au.count = 2  # type: ignore[misc]


class TestInterruptStats:
    def test_interrupt_stats_defaults(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats()
        assert stats.total == 0
        assert stats.approved == 0
        assert stats.rejected == 0
        assert stats.expired == 0

    def test_interrupt_stats_custom_values(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats(total=10, approved=7, rejected=2, expired=1)
        assert stats.total == 10
        assert stats.approved == 7
        assert stats.rejected == 2
        assert stats.expired == 1

    def test_interrupt_stats_is_frozen(self) -> None:
        from app.analytics.models import InterruptStats

        stats = InterruptStats()
        with pytest.raises((AttributeError, TypeError)):
            stats.total = 5  # type: ignore[misc]


class TestAnalyticsResult:
    def test_analytics_result_construction(self) -> None:
        from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats

        result = AnalyticsResult(
            range="7d",
            total_conversations=100,
            resolution_rate=0.85,
            escalation_rate=0.05,
            avg_turns_per_conversation=4.2,
            avg_cost_per_conversation_usd=0.03,
            agent_usage=(AgentUsage(agent="order_agent", count=60, percentage=60.0),),
            interrupt_stats=InterruptStats(total=5, approved=4, rejected=1, expired=0),
        )
        assert result.range == "7d"
        assert result.total_conversations == 100
        assert result.resolution_rate == 0.85
        assert result.escalation_rate == 0.05
        assert result.avg_turns_per_conversation == 4.2
        assert result.avg_cost_per_conversation_usd == 0.03
        assert len(result.agent_usage) == 1
        assert result.interrupt_stats.total == 5

    def test_analytics_result_is_frozen(self) -> None:
        from app.analytics.models import AnalyticsResult, InterruptStats

        result = AnalyticsResult(
            range="7d",
            total_conversations=0,
            resolution_rate=0.0,
            escalation_rate=0.0,
            avg_turns_per_conversation=0.0,
            avg_cost_per_conversation_usd=0.0,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        with pytest.raises((AttributeError, TypeError)):
            result.range = "30d"  # type: ignore[misc]

    def test_analytics_result_empty_agent_usage(self) -> None:
        from app.analytics.models import AnalyticsResult, InterruptStats

        result = AnalyticsResult(
            range="7d",
            total_conversations=0,
            resolution_rate=0.0,
            escalation_rate=0.0,
            avg_turns_per_conversation=0.0,
            avg_cost_per_conversation_usd=0.0,
            agent_usage=(),
            interrupt_stats=InterruptStats(),
        )
        assert result.agent_usage == ()
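The is_frozen tests accept either AttributeError or TypeError, which covers both frozen dataclasses (FrozenInstanceError is an AttributeError subclass) and NamedTuple-style models (plain AttributeError). A frozen-dataclass sketch consistent with the defaults tested above; whether the real models use dataclasses is an assumption.

# One way app.analytics.models could satisfy these tests; illustrative only.
from dataclasses import dataclass


@dataclass(frozen=True)
class InterruptStats:
    total: int = 0
    approved: int = 0
    rejected: int = 0
    expired: int = 0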
249
backend/tests/unit/analytics/test_queries.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Unit tests for app.analytics.queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def _make_pool_with_fetchone(result: dict | None) -> MagicMock:
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchone = AsyncMock(return_value=result)
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.connection.return_value = mock_ctx
|
||||
return mock_pool
|
||||
|
||||
|
||||
def _make_pool_with_fetchall(result: list[dict]) -> MagicMock:
|
||||
mock_cursor = AsyncMock()
|
||||
mock_cursor.fetchall = AsyncMock(return_value=result)
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute = AsyncMock(return_value=mock_cursor)
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.connection.return_value = mock_ctx
|
||||
return mock_pool
|
||||
|
||||
|
||||
class TestResolutionRate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_float(self) -> None:
|
||||
from app.analytics.queries import resolution_rate
|
||||
|
||||
pool = _make_pool_with_fetchone({"rate": 0.85})
|
||||
result = await resolution_rate(pool, range_days=7)
|
||||
assert isinstance(result, float)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_state_returns_zero(self) -> None:
|
||||
from app.analytics.queries import resolution_rate
|
||||
|
||||
pool = _make_pool_with_fetchone(None)
|
||||
result = await resolution_rate(pool, range_days=7)
|
||||
assert result == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_correct_value(self) -> None:
|
||||
from app.analytics.queries import resolution_rate
|
||||
|
||||
pool = _make_pool_with_fetchone({"rate": 0.75})
|
||||
result = await resolution_rate(pool, range_days=7)
|
||||
assert result == 0.75
|
||||
|
||||
|
||||
class TestAgentUsageQuery:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_tuple(self) -> None:
|
||||
from app.analytics.queries import agent_usage
|
||||
|
||||
pool = _make_pool_with_fetchall([])
|
||||
result = await agent_usage(pool, range_days=7)
|
||||
assert isinstance(result, tuple)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_state_returns_empty_tuple(self) -> None:
|
||||
from app.analytics.queries import agent_usage
|
||||
|
||||
        pool = _make_pool_with_fetchall([])
        result = await agent_usage(pool, range_days=7)
        assert result == ()

    @pytest.mark.asyncio
    async def test_maps_rows_to_agent_usage_objects(self) -> None:
        from app.analytics.models import AgentUsage
        from app.analytics.queries import agent_usage

        pool = _make_pool_with_fetchall([
            {"agent": "order_agent", "count": 10, "percentage": 66.7},
            {"agent": "discount_agent", "count": 5, "percentage": 33.3},
        ])
        result = await agent_usage(pool, range_days=7)
        assert len(result) == 2
        assert isinstance(result[0], AgentUsage)
        assert result[0].agent == "order_agent"
        assert result[0].count == 10


class TestEscalationRate:
    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        from app.analytics.queries import escalation_rate

        pool = _make_pool_with_fetchone({"rate": 0.05})
        result = await escalation_rate(pool, range_days=7)
        assert isinstance(result, float)

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import escalation_rate

        pool = _make_pool_with_fetchone(None)
        result = await escalation_rate(pool, range_days=7)
        assert result == 0.0


class TestCostPerConversation:
    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        from app.analytics.queries import cost_per_conversation

        pool = _make_pool_with_fetchone({"avg_cost": 0.03})
        result = await cost_per_conversation(pool, range_days=7)
        assert isinstance(result, float)

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import cost_per_conversation

        pool = _make_pool_with_fetchone(None)
        result = await cost_per_conversation(pool, range_days=7)
        assert result == 0.0


class TestInterruptStatsQuery:
    @pytest.mark.asyncio
    async def test_returns_interrupt_stats(self) -> None:
        from app.analytics.models import InterruptStats
        from app.analytics.queries import interrupt_stats

        pool = _make_pool_with_fetchone(
            {"total": 10, "approved": 7, "rejected": 2, "expired": 1}
        )
        result = await interrupt_stats(pool, range_days=7)
        assert isinstance(result, InterruptStats)
        assert result.total == 10
        assert result.approved == 7

    @pytest.mark.asyncio
    async def test_zero_state_returns_zeros(self) -> None:
        from app.analytics.models import InterruptStats
        from app.analytics.queries import interrupt_stats

        pool = _make_pool_with_fetchone(None)
        result = await interrupt_stats(pool, range_days=7)
        assert isinstance(result, InterruptStats)
        assert result.total == 0
        assert result.approved == 0
        assert result.rejected == 0
        assert result.expired == 0


class TestTotalConversations:
    @pytest.mark.asyncio
    async def test_returns_count(self) -> None:
        from app.analytics.queries import _total_conversations

        pool = _make_pool_with_fetchone({"total": 42})
        result = await _total_conversations(pool, range_days=7)
        assert result == 42

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import _total_conversations

        pool = _make_pool_with_fetchone(None)
        result = await _total_conversations(pool, range_days=7)
        assert result == 0


class TestAvgTurns:
    @pytest.mark.asyncio
    async def test_returns_float(self) -> None:
        from app.analytics.queries import _avg_turns

        pool = _make_pool_with_fetchone({"avg_turns": 3.5})
        result = await _avg_turns(pool, range_days=7)
        assert result == 3.5

    @pytest.mark.asyncio
    async def test_zero_state_returns_zero(self) -> None:
        from app.analytics.queries import _avg_turns

        pool = _make_pool_with_fetchone(None)
        result = await _avg_turns(pool, range_days=7)
        assert result == 0.0


class TestGetAnalytics:
    @pytest.mark.asyncio
    async def test_returns_analytics_result(self) -> None:
        from unittest.mock import patch

        from app.analytics.models import AnalyticsResult, InterruptStats
        from app.analytics.queries import get_analytics

        mock_pool = MagicMock()

        with (
            patch("app.analytics.queries.resolution_rate", return_value=0.85),
            patch("app.analytics.queries.escalation_rate", return_value=0.05),
            patch("app.analytics.queries.cost_per_conversation", return_value=0.03),
            patch("app.analytics.queries.agent_usage", return_value=()),
            patch(
                "app.analytics.queries.interrupt_stats",
                return_value=InterruptStats(),
            ),
            patch("app.analytics.queries._total_conversations", return_value=100),
            patch("app.analytics.queries._avg_turns", return_value=4.2),
        ):
            result = await get_analytics(mock_pool, range_days=7)

        assert isinstance(result, AnalyticsResult)
        assert result.range == "7d"
        assert result.total_conversations == 100
        assert result.resolution_rate == 0.85

    @pytest.mark.asyncio
    async def test_zero_state_returns_zeros(self) -> None:
        from unittest.mock import patch

        from app.analytics.models import AnalyticsResult, InterruptStats
        from app.analytics.queries import get_analytics

        mock_pool = MagicMock()

        with (
            patch("app.analytics.queries.resolution_rate", return_value=0.0),
            patch("app.analytics.queries.escalation_rate", return_value=0.0),
            patch("app.analytics.queries.cost_per_conversation", return_value=0.0),
            patch("app.analytics.queries.agent_usage", return_value=()),
            patch("app.analytics.queries.interrupt_stats", return_value=InterruptStats()),
            patch("app.analytics.queries._total_conversations", return_value=0),
            patch("app.analytics.queries._avg_turns", return_value=0.0),
        ):
            result = await get_analytics(mock_pool, range_days=7)

        assert isinstance(result, AnalyticsResult)
        assert result.total_conversations == 0
        assert result.resolution_rate == 0.0
        assert result.agent_usage == ()
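Taken together, these query tests fix a simple contract: every component query returns its own zero value on an empty database, so `get_analytics` is pure aggregation. A minimal sketch of that shape, with one-line stub helpers standing in for the real SQL queries in `app.analytics.queries` (the dataclass fields and the concurrent fan-out are assumptions, not the module's actual code):

```python
# Sketch only: the aggregation shape TestGetAnalytics implies.
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class InterruptStats:
    total: int = 0
    approved: int = 0
    rejected: int = 0
    expired: int = 0


@dataclass(frozen=True)
class AnalyticsResult:
    range: str
    total_conversations: int
    avg_turns: float
    resolution_rate: float
    escalation_rate: float
    cost_per_conversation: float
    agent_usage: tuple = ()
    interrupts: InterruptStats = InterruptStats()


# Stubs for the component queries; the real ones run SQL against the pool
# and each returns a zero value when no rows match.
async def resolution_rate(pool, range_days: int) -> float: return 0.0
async def escalation_rate(pool, range_days: int) -> float: return 0.0
async def cost_per_conversation(pool, range_days: int) -> float: return 0.0
async def agent_usage(pool, range_days: int) -> tuple: return ()
async def interrupt_stats(pool, range_days: int) -> InterruptStats: return InterruptStats()
async def _total_conversations(pool, range_days: int) -> int: return 0
async def _avg_turns(pool, range_days: int) -> float: return 0.0


async def get_analytics(pool, range_days: int) -> AnalyticsResult:
    # Fan the component queries out concurrently; because every helper
    # already handles the empty-table case, no zero-state branch is
    # needed at this level.
    res, esc, cost, usage, interrupts, total, turns = await asyncio.gather(
        resolution_rate(pool, range_days=range_days),
        escalation_rate(pool, range_days=range_days),
        cost_per_conversation(pool, range_days=range_days),
        agent_usage(pool, range_days=range_days),
        interrupt_stats(pool, range_days=range_days),
        _total_conversations(pool, range_days=range_days),
        _avg_turns(pool, range_days=range_days),
    )
    return AnalyticsResult(
        range=f"{range_days}d",
        total_conversations=total,
        avg_turns=turns,
        resolution_rate=res,
        escalation_rate=esc,
        cost_per_conversation=cost,
        agent_usage=usage,
        interrupts=interrupts,
    )
```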
0
backend/tests/unit/openapi/__init__.py
Normal file
249
backend/tests/unit/openapi/test_classifier.py
Normal file
@@ -0,0 +1,249 @@
"""Tests for OpenAPI endpoint classifier module.

RED phase: written before implementation.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.openapi.models import EndpointInfo, ParameterInfo

pytestmark = pytest.mark.unit


def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "",
    parameters: tuple[ParameterInfo, ...] = (),
) -> EndpointInfo:
    return EndpointInfo(
        path=path,
        method=method,
        operation_id=operation_id,
        summary=summary,
        description=description,
        parameters=parameters,
    )


_ORDER_PARAM = ParameterInfo(
    name="order_id", location="path", required=True, schema_type="string"
)
_CUSTOMER_PARAM = ParameterInfo(
    name="customer_id", location="query", required=False, schema_type="string"
)


class TestHeuristicClassifier:
    """Tests for the rule-based HeuristicClassifier."""

    async def test_get_classified_as_read(self) -> None:
        """GET endpoints are classified as read access."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="GET")
        results = await clf.classify((ep,))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_post_classified_as_write(self) -> None:
        """POST endpoints are classified as write access."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="POST")
        results = await clf.classify((ep,))
        assert results[0].access_type == "write"

    async def test_post_needs_interrupt(self) -> None:
        """POST endpoints require interrupt/approval."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="POST")
        results = await clf.classify((ep,))
        assert results[0].needs_interrupt is True

    async def test_put_classified_as_write(self) -> None:
        """PUT endpoints are classified as write access."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="PUT")
        results = await clf.classify((ep,))
        assert results[0].access_type == "write"

    async def test_delete_classified_as_write_with_interrupt(self) -> None:
        """DELETE endpoints are classified as write and require interrupt."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="DELETE")
        results = await clf.classify((ep,))
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_get_does_not_need_interrupt(self) -> None:
        """GET endpoints do not require interrupt."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="GET")
        results = await clf.classify((ep,))
        assert results[0].needs_interrupt is False

    async def test_empty_endpoints_returns_empty_tuple(self) -> None:
        """Empty input yields empty output."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        results = await clf.classify(())
        assert results == ()

    async def test_customer_params_detected_order_id(self) -> None:
        """Parameters named order_id are recognized as customer params."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="GET", parameters=(_ORDER_PARAM,))
        results = await clf.classify((ep,))
        assert "order_id" in results[0].customer_params

    async def test_customer_params_detected_customer_id(self) -> None:
        """Parameters named customer_id are recognized as customer params."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="GET", parameters=(_CUSTOMER_PARAM,))
        results = await clf.classify((ep,))
        assert "customer_id" in results[0].customer_params

    async def test_result_is_tuple(self) -> None:
        """classify returns a tuple (immutable)."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint()
        results = await clf.classify((ep,))
        assert isinstance(results, tuple)

    async def test_classification_has_confidence(self) -> None:
        """Heuristic results have a confidence value between 0 and 1."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint()
        results = await clf.classify((ep,))
        assert 0.0 <= results[0].confidence <= 1.0

    async def test_patch_classified_as_write(self) -> None:
        """PATCH endpoints are classified as write access."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        ep = _make_endpoint(method="PATCH")
        results = await clf.classify((ep,))
        assert results[0].access_type == "write"


class TestLLMClassifier:
    """Tests for the LLM-backed classifier."""

    def _make_mock_llm(self, classifications: list[dict]) -> MagicMock:
        """Create a mock LLM that returns structured classification data."""
        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = str(classifications)
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)
        return mock_llm

    async def test_llm_classifier_classifies_endpoints(self) -> None:
        """LLM classifier returns ClassificationResult for each endpoint."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = (
            '[{"access_type": "read", "agent_group": "support",'
            ' "needs_interrupt": false, "customer_params": [], "confidence": 0.9}]'
        )
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_failure_falls_back_to_heuristic(self) -> None:
        """When LLM raises an exception, falls back to heuristic classifier."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="GET")

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Falls back to heuristic: GET = read
        assert len(results) == 1
        assert results[0].access_type == "read"

    async def test_llm_invalid_json_falls_back_to_heuristic(self) -> None:
        """When LLM returns unparseable output, falls back to heuristic."""
        from app.openapi.classifier import LLMClassifier

        ep = _make_endpoint(method="DELETE")

        mock_llm = MagicMock()
        mock_response = MagicMock()
        mock_response.content = "this is not valid json at all"
        mock_llm.ainvoke = AsyncMock(return_value=mock_response)

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify((ep,))
        # Fallback: DELETE = write with interrupt
        assert results[0].access_type == "write"
        assert results[0].needs_interrupt is True

    async def test_llm_empty_endpoints_returns_empty(self) -> None:
        """Empty input yields empty output without calling LLM."""
        from app.openapi.classifier import LLMClassifier

        mock_llm = MagicMock()
        mock_llm.ainvoke = AsyncMock()

        clf = LLMClassifier(llm=mock_llm)
        results = await clf.classify(())
        assert results == ()
        mock_llm.ainvoke.assert_not_called()


class TestClassifierProtocol:
    """Verify both classifiers conform to ClassifierProtocol."""

    def test_heuristic_has_classify_method(self) -> None:
        """HeuristicClassifier exposes classify method."""
        from app.openapi.classifier import HeuristicClassifier

        clf = HeuristicClassifier()
        assert hasattr(clf, "classify")
        assert callable(clf.classify)

    def test_llm_has_classify_method(self) -> None:
        """LLMClassifier exposes classify method."""
        from app.openapi.classifier import LLMClassifier

        mock_llm = MagicMock()
        clf = LLMClassifier(llm=mock_llm)
        assert hasattr(clf, "classify")
        assert callable(clf.classify)
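The heuristic rules pinned down above can be read off directly: the HTTP method alone decides access type and interrupt, and a short name list decides customer params. A sketch under exactly those assumptions; the real `app.openapi.classifier` presumably returns the shared `ClassificationResult` model from `app.openapi.models`, and the fixed confidence value here is invented:

```python
# Hypothetical HeuristicClassifier consistent with the cases above.
from dataclasses import dataclass

_WRITE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
_CUSTOMER_PARAM_NAMES = {"order_id", "customer_id"}


@dataclass(frozen=True)
class Classification:  # stand-in for app.openapi.models.ClassificationResult
    access_type: str
    customer_params: tuple
    agent_group: str
    confidence: float
    needs_interrupt: bool


class HeuristicClassifier:
    async def classify(self, endpoints: tuple) -> tuple:
        results = []
        for ep in endpoints:  # each ep is an EndpointInfo as built above
            is_write = ep.method.upper() in _WRITE_METHODS
            results.append(
                Classification(
                    access_type="write" if is_write else "read",
                    customer_params=tuple(
                        p.name for p in ep.parameters
                        if p.name in _CUSTOMER_PARAM_NAMES
                    ),
                    agent_group="write_agent" if is_write else "read_agent",
                    confidence=0.6,  # fixed mid-range value: rules, not evidence
                    needs_interrupt=is_write,  # every mutation needs approval
                )
            )
        return tuple(results)
```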
120
backend/tests/unit/openapi/test_fetcher.py
Normal file
@@ -0,0 +1,120 @@
"""Tests for OpenAPI spec fetcher module.

RED phase: written before implementation.
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.openapi.ssrf import SSRFError

pytestmark = pytest.mark.unit

_SAMPLE_JSON = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0"}, "paths": {}}'
_SAMPLE_YAML = "openapi: '3.0.0'\ninfo:\n title: Test\n version: '1.0'\npaths: {}\n"
_PUBLIC_IP = "93.184.216.34"


@pytest.fixture
def _mock_public_dns():
    with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
        yield


class TestFetchSpec:
    """Tests for fetch_spec function."""

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_json_spec_succeeds(self, httpx_mock) -> None:
        """Fetch a JSON spec and return parsed dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.json",
            text=_SAMPLE_JSON,
            headers={"content-type": "application/json"},
        )
        result = await fetch_spec("https://example.com/spec.json")
        assert isinstance(result, dict)
        assert result["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_spec_succeeds(self, httpx_mock) -> None:
        """Fetch a YAML spec and return parsed dict."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/spec.yaml",
            text=_SAMPLE_YAML,
            headers={"content-type": "application/x-yaml"},
        )
        result = await fetch_spec("https://example.com/spec.yaml")
        assert isinstance(result, dict)
        assert result["openapi"] == "3.0.0"

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_fetch_yaml_by_url_extension(self, httpx_mock) -> None:
        """Auto-detect YAML format from .yaml URL extension."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/api.yaml",
            text=_SAMPLE_YAML,
            headers={"content-type": "text/plain"},
        )
        result = await fetch_spec("https://example.com/api.yaml")
        assert isinstance(result, dict)
        assert result["openapi"] == "3.0.0"

    async def test_ssrf_blocked_url_raises(self) -> None:
        """SSRF-blocked URL raises SSRFError."""
        from app.openapi.fetcher import fetch_spec

        with (
            patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
            pytest.raises(SSRFError),
        ):
            await fetch_spec("http://internal.corp/spec.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_oversized_response_raises(self, httpx_mock) -> None:
        """Response exceeding 10MB raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        big_content = "x" * (10 * 1024 * 1024 + 1)
        httpx_mock.add_response(
            url="https://example.com/huge.json",
            text=big_content,
            headers={"content-type": "application/json"},
        )
        with pytest.raises(ValueError, match="too large"):
            await fetch_spec("https://example.com/huge.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_json_raises(self, httpx_mock) -> None:
        """Non-parseable JSON raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.json",
            text="not valid json {{{",
            headers={"content-type": "application/json"},
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Dd]ecode"):
            await fetch_spec("https://example.com/bad.json")

    @pytest.mark.usefixtures("_mock_public_dns")
    async def test_invalid_yaml_raises(self, httpx_mock) -> None:
        """Non-parseable YAML raises ValueError."""
        from app.openapi.fetcher import fetch_spec

        httpx_mock.add_response(
            url="https://example.com/bad.yaml",
            text=": invalid: yaml: {\n",
            headers={"content-type": "application/x-yaml"},
        )
        with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Yy]AML"):
            await fetch_spec("https://example.com/bad.yaml")
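The fetcher tests describe a pipeline: SSRF check before any request, a 10MB size cap, then JSON or YAML parsing chosen by content type or URL extension, with all parse failures normalized to `ValueError`. A self-contained sketch of that flow; the inline `_check_url_safe` guard is hypothetical and stands in for whatever `app.openapi.ssrf` actually does behind its patchable `resolve_hostname`:

```python
# Sketch of the fetch_spec behaviour the tests above describe.
import ipaddress
import json
import socket
from urllib.parse import urlparse

import httpx
import yaml

_MAX_BYTES = 10 * 1024 * 1024


class SSRFError(Exception):
    """Raised when a URL resolves to a private or loopback address."""


def _check_url_safe(url: str) -> None:
    # Resolve the hostname and reject private/loopback/link-local targets.
    host = urlparse(url).hostname or ""
    for info in socket.getaddrinfo(host, None):
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            raise SSRFError(f"{url} resolves to a blocked address")


async def fetch_spec(url: str) -> dict:
    _check_url_safe(url)  # SSRF guard runs before any network I/O
    async with httpx.AsyncClient(follow_redirects=True) as client:
        response = await client.get(url)
        response.raise_for_status()
    text = response.text
    if len(text.encode()) > _MAX_BYTES:
        raise ValueError("spec too large (over 10MB)")
    content_type = response.headers.get("content-type", "")
    is_yaml = "yaml" in content_type or url.endswith((".yaml", ".yml"))
    try:
        parsed = yaml.safe_load(text) if is_yaml else json.loads(text)
    except (yaml.YAMLError, json.JSONDecodeError) as exc:
        raise ValueError(f"could not parse spec: {exc}") from exc
    if not isinstance(parsed, dict):
        raise ValueError("parsed spec is not a mapping")
    return parsed
```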
258
backend/tests/unit/openapi/test_generator.py
Normal file
@@ -0,0 +1,258 @@
"""Tests for OpenAPI tool generator module.

RED phase: written before implementation.
"""

from __future__ import annotations

import pytest

from app.openapi.models import ClassificationResult, EndpointInfo, ParameterInfo

pytestmark = pytest.mark.unit

_BASE_URL = "https://api.example.com"


def _make_endpoint(
    path: str = "/items",
    method: str = "GET",
    operation_id: str = "list_items",
    summary: str = "List items",
    description: str = "Returns all items",
    parameters: tuple[ParameterInfo, ...] = (),
    request_body_schema: dict | None = None,
) -> EndpointInfo:
    return EndpointInfo(
        path=path,
        method=method,
        operation_id=operation_id,
        summary=summary,
        description=description,
        parameters=parameters,
        request_body_schema=request_body_schema,
    )


def _make_classification(
    endpoint: EndpointInfo,
    access_type: str = "read",
    needs_interrupt: bool = False,
    agent_group: str = "read_agent",
) -> ClassificationResult:
    return ClassificationResult(
        endpoint=endpoint,
        access_type=access_type,
        customer_params=(),
        agent_group=agent_group,
        confidence=0.9,
        needs_interrupt=needs_interrupt,
    )


_PATH_PARAM = ParameterInfo(
    name="item_id", location="path", required=True, schema_type="string"
)
_QUERY_PARAM = ParameterInfo(
    name="filter", location="query", required=False, schema_type="string"
)


class TestGenerateToolCode:
    """Tests for generate_tool_code function."""

    def test_generate_tool_for_get_endpoint(self) -> None:
        """Generated tool for GET endpoint is a GeneratedTool with non-empty code."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(method="GET")
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        assert tool.function_name == "list_items"
        assert tool.code != ""
        assert "@tool" in tool.code

    def test_generate_tool_contains_function_name(self) -> None:
        """Generated code contains the function name."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(operation_id="get_order", method="GET")
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        assert "get_order" in tool.code

    def test_generate_tool_contains_base_url(self) -> None:
        """Generated code contains the base URL."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint()
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        assert _BASE_URL in tool.code

    def test_generate_tool_contains_http_method(self) -> None:
        """Generated code uses the correct HTTP method."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(method="POST")
        clf = _make_classification(ep, access_type="write")
        tool = generate_tool_code(clf, _BASE_URL)

        assert "post" in tool.code.lower()

    def test_generate_tool_for_post_with_body(self) -> None:
        """Generated tool for POST includes body parameter."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(
            method="POST",
            request_body_schema={"type": "object", "properties": {"name": {"type": "string"}}},
        )
        clf = _make_classification(ep, access_type="write")
        tool = generate_tool_code(clf, _BASE_URL)

        assert tool.code != ""
        assert "POST" in tool.code or "post" in tool.code

    def test_generate_tool_with_path_params(self) -> None:
        """Generated tool includes path parameter in function signature."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(
            path="/items/{item_id}",
            operation_id="get_item",
            parameters=(_PATH_PARAM,),
        )
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        assert "item_id" in tool.code

    def test_write_tool_includes_interrupt_marker(self) -> None:
        """Write tools that need interrupt include a marker comment."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(method="DELETE", operation_id="delete_item")
        clf = _make_classification(ep, access_type="write", needs_interrupt=True)
        tool = generate_tool_code(clf, _BASE_URL)

        assert "interrupt" in tool.code.lower() or "approval" in tool.code.lower()

    def test_generated_code_is_executable(self) -> None:
        """Generated code can be exec'd without syntax errors."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint(
            path="/items/{item_id}",
            operation_id="fetch_item",
            parameters=(_PATH_PARAM,),
        )
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        # Must be valid Python syntax
        compile(tool.code, "<generated>", "exec")

    def test_generated_tool_code_exec_imports(self) -> None:
        """Generated code exec'd with required imports does not raise."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint()
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        namespace: dict = {}
        try:
            import httpx
            from langchain_core.tools import tool as lc_tool

            namespace = {"httpx": httpx, "tool": lc_tool}
            exec(tool.code, namespace)  # noqa: S102
        except ImportError:
            pytest.skip("langchain_core not available for exec test")

    def test_returns_generated_tool_instance(self) -> None:
        """generate_tool_code returns a GeneratedTool instance."""
        from app.openapi.generator import generate_tool_code
        from app.openapi.models import GeneratedTool

        ep = _make_endpoint()
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        assert isinstance(tool, GeneratedTool)

    def test_generated_tool_is_frozen(self) -> None:
        """GeneratedTool instance is immutable."""
        from app.openapi.generator import generate_tool_code

        ep = _make_endpoint()
        clf = _make_classification(ep)
        tool = generate_tool_code(clf, _BASE_URL)

        with pytest.raises((AttributeError, TypeError)):
            tool.code = "new code"  # type: ignore[misc]


class TestGenerateAgentYaml:
    """Tests for generate_agent_yaml function."""

    def test_generate_yaml_is_valid_string(self) -> None:
        """generate_agent_yaml returns a non-empty string."""
        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        assert isinstance(result, str)
        assert len(result) > 0

    def test_generated_yaml_is_parseable(self) -> None:
        """Output can be parsed as YAML."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        parsed = yaml.safe_load(result)
        assert isinstance(parsed, dict)

    def test_generated_yaml_contains_agents_key(self) -> None:
        """Generated YAML has an 'agents' key matching AgentConfig format."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint()
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        parsed = yaml.safe_load(result)
        assert "agents" in parsed

    def test_generated_yaml_contains_tool_name(self) -> None:
        """Generated YAML references the tool function name."""
        from app.openapi.generator import generate_agent_yaml

        ep = _make_endpoint(operation_id="list_orders")
        clf = _make_classification(ep)
        result = generate_agent_yaml((clf,), _BASE_URL)

        assert "list_orders" in result

    def test_empty_classifications_returns_empty_agents(self) -> None:
        """No classifications yields YAML with empty agents list."""
        import yaml

        from app.openapi.generator import generate_agent_yaml

        result = generate_agent_yaml((), _BASE_URL)
        parsed = yaml.safe_load(result)
        assert parsed.get("agents") == [] or parsed.get("agents") is None
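What these tests demand of the emitted code is narrow: a `@tool`-decorated function named after the operation, the base URL and HTTP method embedded, path parameters in the signature, an interrupt marker comment on approval-gated writes, and valid syntax. One possible template meeting those constraints; the real template in `app.openapi.generator` is unknown, and `GeneratedTool` here is a stand-in for the frozen model in `app.openapi.models`:

```python
# Hypothetical generate_tool_code template.
from dataclasses import dataclass


@dataclass(frozen=True)
class GeneratedTool:
    function_name: str
    code: str


def generate_tool_code(clf, base_url: str) -> GeneratedTool:
    ep = clf.endpoint
    params = ", ".join(f"{p.name}: str" for p in ep.parameters)
    interrupt_note = (
        "    # interrupt: requires human approval before execution\n"
        if clf.needs_interrupt
        else ""
    )
    # ep.path keeps its literal {braces}, so the emitted url line is an
    # f-string over the generated function's own parameters.
    code = (
        "@tool\n"
        f"def {ep.operation_id}({params}) -> dict:\n"
        f'    """{ep.summary}"""\n'
        f"{interrupt_note}"
        f'    url = f"{base_url}{ep.path}"\n'
        f'    response = httpx.request("{ep.method}", url)\n'
        "    response.raise_for_status()\n"
        "    return response.json()\n"
    )
    return GeneratedTool(function_name=ep.operation_id, code=code)
```

For `get_item` on `/items/{item_id}` this renders a function whose signature carries `item_id: str` and whose url f-string interpolates it, which is exactly what the signature and `compile()` tests check.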
290
backend/tests/unit/openapi/test_parser.py
Normal file
@@ -0,0 +1,290 @@
"""Tests for OpenAPI endpoint parser module.

RED phase: written before implementation.
"""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit

_MINIMAL_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Test API", "version": "1.0.0"},
    "paths": {},
}

_GET_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Orders API", "version": "1.0.0"},
    "paths": {
        "/orders/{order_id}": {
            "get": {
                "operationId": "get_order",
                "summary": "Get an order",
                "description": "Retrieves a single order by ID",
                "parameters": [
                    {
                        "name": "order_id",
                        "in": "path",
                        "required": True,
                        "schema": {"type": "string"},
                        "description": "The order identifier",
                    }
                ],
                "responses": {
                    "200": {
                        "description": "Order found",
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {"id": {"type": "string"}},
                                }
                            }
                        },
                    }
                },
            }
        }
    },
}

_POST_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Orders API", "version": "1.0.0"},
    "paths": {
        "/orders": {
            "post": {
                "operationId": "create_order",
                "summary": "Create an order",
                "description": "Creates a new order",
                "requestBody": {
                    "required": True,
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "item": {"type": "string"},
                                    "quantity": {"type": "integer"},
                                },
                            }
                        }
                    },
                },
                "responses": {"201": {"description": "Created"}},
            }
        }
    },
}

_MULTI_PARAM_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Items API", "version": "1.0.0"},
    "paths": {
        "/items/{item_id}": {
            "get": {
                "operationId": "get_item",
                "summary": "Get item",
                "description": "",
                "parameters": [
                    {
                        "name": "item_id",
                        "in": "path",
                        "required": True,
                        "schema": {"type": "integer"},
                    },
                    {
                        "name": "include_details",
                        "in": "query",
                        "required": False,
                        "schema": {"type": "boolean"},
                    },
                ],
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}

_REF_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Ref API", "version": "1.0.0"},
    "components": {
        "schemas": {
            "Item": {
                "type": "object",
                "properties": {"id": {"type": "string"}},
            }
        }
    },
    "paths": {
        "/items": {
            "get": {
                "operationId": "list_items",
                "summary": "List items",
                "description": "",
                "responses": {
                    "200": {
                        "description": "OK",
                        "content": {
                            "application/json": {
                                "schema": {"$ref": "#/components/schemas/Item"}
                            }
                        },
                    }
                },
            }
        }
    },
}

_MULTI_ENDPOINT_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Multi API", "version": "1.0.0"},
    "paths": {
        "/users": {
            "get": {
                "operationId": "list_users",
                "summary": "List users",
                "description": "",
                "responses": {"200": {"description": "OK"}},
            },
            "post": {
                "operationId": "create_user",
                "summary": "Create user",
                "description": "",
                "responses": {"201": {"description": "Created"}},
            },
        },
        "/users/{id}": {
            "delete": {
                "operationId": "delete_user",
                "summary": "Delete user",
                "description": "",
                "parameters": [
                    {"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}
                ],
                "responses": {"204": {"description": "Deleted"}},
            }
        },
    },
}


class TestParseEndpoints:
    """Tests for parse_endpoints function."""

    def test_empty_paths_returns_empty_tuple(self) -> None:
        """Spec with no paths yields no endpoints."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_MINIMAL_SPEC)
        assert result == ()

    def test_parse_get_endpoint(self) -> None:
        """Parse a GET endpoint with path parameter."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_GET_SPEC)
        assert len(result) == 1
        ep = result[0]
        assert ep.path == "/orders/{order_id}"
        assert ep.method == "GET"
        assert ep.operation_id == "get_order"
        assert ep.summary == "Get an order"

    def test_parse_get_endpoint_parameters(self) -> None:
        """Path parameters are extracted correctly."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_GET_SPEC)
        ep = result[0]
        assert len(ep.parameters) == 1
        param = ep.parameters[0]
        assert param.name == "order_id"
        assert param.location == "path"
        assert param.required is True
        assert param.schema_type == "string"

    def test_parse_post_with_request_body(self) -> None:
        """POST endpoint with request body is extracted."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_POST_SPEC)
        assert len(result) == 1
        ep = result[0]
        assert ep.method == "POST"
        assert ep.request_body_schema is not None
        assert ep.request_body_schema["type"] == "object"

    def test_parse_path_and_query_params(self) -> None:
        """Both path and query parameters are extracted."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_MULTI_PARAM_SPEC)
        ep = result[0]
        locations = {p.location for p in ep.parameters}
        assert "path" in locations
        assert "query" in locations

    def test_autogenerate_operation_id_when_missing(self) -> None:
        """Auto-generate operation_id when not provided in spec."""
        from app.openapi.parser import parse_endpoints

        spec = {
            "openapi": "3.0.0",
            "info": {"title": "Test", "version": "1.0"},
            "paths": {
                "/things/{id}": {
                    "get": {
                        "summary": "Get thing",
                        "description": "",
                        "responses": {"200": {"description": "OK"}},
                    }
                }
            },
        }
        result = parse_endpoints(spec)
        ep = result[0]
        assert ep.operation_id != ""
        assert len(ep.operation_id) > 0

    def test_multiple_endpoints_extracted(self) -> None:
        """Multiple path+method combinations are all extracted."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_MULTI_ENDPOINT_SPEC)
        assert len(result) == 3
        methods = {ep.method for ep in result}
        assert "GET" in methods
        assert "POST" in methods
        assert "DELETE" in methods

    def test_ref_in_response_schema_resolved(self) -> None:
        """$ref in response schema is resolved to the target schema."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_REF_SPEC)
        ep = result[0]
        assert ep.response_schema is not None
        # Resolved ref should contain the properties
        assert "properties" in ep.response_schema or "$ref" not in ep.response_schema

    def test_result_is_tuple(self) -> None:
        """parse_endpoints returns a tuple (immutable)."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_GET_SPEC)
        assert isinstance(result, tuple)

    def test_endpoint_info_is_frozen(self) -> None:
        """EndpointInfo instances are frozen/immutable."""
        from app.openapi.parser import parse_endpoints

        result = parse_endpoints(_GET_SPEC)
        ep = result[0]
        with pytest.raises((AttributeError, TypeError)):
            ep.method = "POST"  # type: ignore[misc]
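The parser tests imply a straightforward walk over `paths` × HTTP methods with three wrinkles: parameter normalization, local `$ref` resolution, and an auto-generated `operation_id` fallback. A sketch under those assumptions, with stand-in frozen models (the real ones live in `app.openapi.models`):

```python
# Hypothetical parse_endpoints satisfying the cases above.
from __future__ import annotations

from dataclasses import dataclass

_HTTP_METHODS = ("get", "post", "put", "patch", "delete")


@dataclass(frozen=True)
class ParameterInfo:
    name: str
    location: str
    required: bool
    schema_type: str


@dataclass(frozen=True)
class EndpointInfo:
    path: str
    method: str
    operation_id: str
    summary: str
    description: str
    parameters: tuple = ()
    request_body_schema: dict | None = None
    response_schema: dict | None = None


def _resolve_ref(schema: dict | None, spec: dict) -> dict | None:
    if schema and "$ref" in schema:
        # "#/components/schemas/Item" -> spec["components"]["schemas"]["Item"]
        node: dict = spec
        for key in schema["$ref"].lstrip("#/").split("/"):
            node = node[key]
        return node
    return schema


def parse_endpoints(spec: dict) -> tuple:
    endpoints = []
    for path, item in spec.get("paths", {}).items():
        for method in _HTTP_METHODS:
            op = item.get(method)
            if op is None:
                continue
            body = op.get("requestBody", {}).get("content", {})
            body_schema = body.get("application/json", {}).get("schema")
            ok = op.get("responses", {}).get("200", {}).get("content", {})
            resp_schema = ok.get("application/json", {}).get("schema")
            endpoints.append(
                EndpointInfo(
                    path=path,
                    method=method.upper(),
                    # Fallback id like "get_things_{id}" when the spec
                    # omits operationId.
                    operation_id=op.get("operationId")
                    or f"{method}_{path.strip('/').replace('/', '_')}",
                    summary=op.get("summary", ""),
                    description=op.get("description", ""),
                    parameters=tuple(
                        ParameterInfo(
                            name=p["name"],
                            location=p["in"],
                            required=p.get("required", False),
                            schema_type=p.get("schema", {}).get("type", "string"),
                        )
                        for p in op.get("parameters", [])
                    ),
                    request_body_schema=_resolve_ref(body_schema, spec),
                    response_schema=_resolve_ref(resp_schema, spec),
                )
            )
    return tuple(endpoints)
```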
219
backend/tests/unit/openapi/test_review_api.py
Normal file
@@ -0,0 +1,219 @@
"""Tests for OpenAPI review API endpoints.

RED phase: written before implementation.
"""

from __future__ import annotations

import pytest
from fastapi.testclient import TestClient

pytestmark = pytest.mark.unit

_SAMPLE_URL = "https://example.com/api/spec.json"


@pytest.fixture
def client():
    """Create TestClient for the review API app."""
    from fastapi import FastAPI

    from app.openapi.review_api import router

    app = FastAPI()
    app.include_router(router)
    return TestClient(app)


@pytest.fixture
def job_id(client):
    """Create a job and return its ID."""
    response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
    assert response.status_code == 202
    return response.json()["job_id"]


@pytest.fixture
def job_with_classifications(client, job_id):
    """Return job_id for a job that has mock classifications injected."""
    from app.openapi.models import ClassificationResult, EndpointInfo
    from app.openapi.review_api import _job_store

    ep = EndpointInfo(
        path="/orders",
        method="GET",
        operation_id="list_orders",
        summary="List orders",
        description="",
    )
    clf = ClassificationResult(
        endpoint=ep,
        access_type="read",
        customer_params=(),
        agent_group="read_agent",
        confidence=0.9,
        needs_interrupt=False,
    )
    # Inject classifications directly into the store
    job = _job_store[job_id]
    _job_store[job_id] = {**job, "classifications": [clf]}
    return job_id


class TestImportEndpoint:
    """Tests for POST /api/v1/openapi/import."""

    def test_post_import_returns_job_id(self, client) -> None:
        """POST /import returns 202 with a job_id."""
        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        assert response.status_code == 202
        data = response.json()
        assert "job_id" in data
        assert len(data["job_id"]) > 0

    def test_post_import_empty_url_returns_422(self, client) -> None:
        """POST /import with empty URL returns 422 validation error."""
        response = client.post("/api/v1/openapi/import", json={"url": ""})
        assert response.status_code == 422

    def test_post_import_missing_url_returns_422(self, client) -> None:
        """POST /import with missing URL field returns 422."""
        response = client.post("/api/v1/openapi/import", json={})
        assert response.status_code == 422

    def test_post_import_invalid_scheme_returns_422(self, client) -> None:
        """POST /import with non-http URL returns 422."""
        response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
        assert response.status_code == 422

    def test_post_import_returns_pending_status(self, client) -> None:
        """Newly created job has pending status."""
        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        data = response.json()
        assert data["status"] == "pending"

    def test_post_import_returns_spec_url(self, client) -> None:
        """Response includes the original spec URL."""
        response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
        data = response.json()
        assert data["spec_url"] == _SAMPLE_URL


class TestGetJobEndpoint:
    """Tests for GET /api/v1/openapi/jobs/{job_id}."""

    def test_get_job_returns_status(self, client, job_id) -> None:
        """GET /jobs/{id} returns job status."""
        response = client.get(f"/api/v1/openapi/jobs/{job_id}")
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert "job_id" in data

    def test_get_unknown_job_returns_404(self, client) -> None:
        """GET /jobs/nonexistent returns 404."""
        response = client.get("/api/v1/openapi/jobs/nonexistent-id")
        assert response.status_code == 404

    def test_get_job_includes_spec_url(self, client, job_id) -> None:
        """Job response includes the spec URL."""
        response = client.get(f"/api/v1/openapi/jobs/{job_id}")
        data = response.json()
        assert data["spec_url"] == _SAMPLE_URL


class TestGetClassificationsEndpoint:
    """Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""

    def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
        """GET /classifications returns a list."""
        response = client.get(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
        )
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 1

    def test_get_classifications_unknown_job_returns_404(self, client) -> None:
        """GET /classifications for unknown job returns 404."""
        response = client.get("/api/v1/openapi/jobs/unknown/classifications")
        assert response.status_code == 404

    def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
        """Each classification item has access_type and endpoint fields."""
        response = client.get(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
        )
        item = response.json()[0]
        assert "access_type" in item
        assert "endpoint" in item
        assert "needs_interrupt" in item


class TestUpdateClassificationEndpoint:
    """Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""

    def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
        """PUT /classifications/0 updates the classification."""
        response = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
        )
        assert response.status_code == 200

    def test_update_unknown_job_returns_404(self, client) -> None:
        """PUT /classifications/0 for unknown job returns 404."""
        response = client.put(
            "/api/v1/openapi/jobs/unknown/classifications/0",
            json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
        )
        assert response.status_code == 404

    def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
        """PUT /classifications/0 with invalid access_type returns 422."""
        response = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
        )
        assert response.status_code == 422

    def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
        """PUT /classifications/0 with invalid agent_group returns 422."""
        response = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
            json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
        )
        assert response.status_code == 422

    def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
        """PUT /classifications/999 returns 404 for out-of-range index."""
        response = client.put(
            f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
            json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
        )
        assert response.status_code == 404


class TestApproveEndpoint:
    """Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""

    def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
        """POST /approve transitions job to approved status."""
        response = client.post(
            f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
        )
        assert response.status_code == 200

    def test_approve_unknown_job_returns_404(self, client) -> None:
        """POST /approve for unknown job returns 404."""
        response = client.post("/api/v1/openapi/jobs/unknown/approve")
        assert response.status_code == 404

    def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
        """POST /approve returns updated job status."""
        response = client.post(
            f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
        )
        data = response.json()
        assert "status" in data
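The import-flow tests suggest an in-memory `_job_store` keyed by a generated UUID, with pydantic doing the URL validation (an `HttpUrl` field already rejects empty strings and non-http schemes with a 422). A sketch of the first two routes; everything beyond the field names asserted in the tests is an assumption:

```python
# Hypothetical review_api router skeleton.
import uuid

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, HttpUrl

router = APIRouter(prefix="/api/v1/openapi")
_job_store: dict[str, dict] = {}  # job_id -> job record; process-local


class ImportRequest(BaseModel):
    url: HttpUrl  # pydantic limits this to http/https and non-empty


@router.post("/import", status_code=202)
async def import_spec(req: ImportRequest) -> dict:
    job_id = str(uuid.uuid4())
    _job_store[job_id] = {
        "job_id": job_id,
        "spec_url": str(req.url),
        "status": "pending",
        "classifications": [],
    }
    return _job_store[job_id]


@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
    if job_id not in _job_store:
        raise HTTPException(status_code=404, detail="job not found")
    return _job_store[job_id]
```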
93
backend/tests/unit/openapi/test_validator.py
Normal file
@@ -0,0 +1,93 @@
"""Tests for OpenAPI spec validator module.

RED phase: written before implementation.
"""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit

_VALID_SPEC = {
    "openapi": "3.0.0",
    "info": {"title": "Test API", "version": "1.0.0"},
    "paths": {
        "/items": {
            "get": {
                "summary": "List items",
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


class TestValidateSpec:
    """Tests for validate_spec function."""

    def test_valid_minimal_spec_passes(self) -> None:
        """A valid minimal spec returns empty error list."""
        from app.openapi.validator import validate_spec

        errors = validate_spec(_VALID_SPEC)
        assert errors == []

    def test_missing_openapi_key_returns_error(self) -> None:
        """Missing 'openapi' field returns an error."""
        from app.openapi.validator import validate_spec

        spec = {k: v for k, v in _VALID_SPEC.items() if k != "openapi"}
        errors = validate_spec(spec)
        assert len(errors) > 0
        assert any("openapi" in e.lower() for e in errors)

    def test_missing_info_returns_error(self) -> None:
        """Missing 'info' field returns an error."""
        from app.openapi.validator import validate_spec

        spec = {k: v for k, v in _VALID_SPEC.items() if k != "info"}
        errors = validate_spec(spec)
        assert len(errors) > 0
        assert any("info" in e.lower() for e in errors)

    def test_missing_paths_returns_error(self) -> None:
        """Missing 'paths' field returns an error."""
        from app.openapi.validator import validate_spec

        spec = {k: v for k, v in _VALID_SPEC.items() if k != "paths"}
        errors = validate_spec(spec)
        assert len(errors) > 0
        assert any("paths" in e.lower() for e in errors)

    def test_non_dict_input_returns_error(self) -> None:
        """Non-dict input returns an error without raising."""
        from app.openapi.validator import validate_spec

        errors = validate_spec("not a dict")  # type: ignore[arg-type]
        assert len(errors) > 0

    def test_empty_dict_returns_multiple_errors(self) -> None:
        """Empty dict returns errors for all required fields."""
        from app.openapi.validator import validate_spec

        errors = validate_spec({})
        # Should have at least one error for each required field
        assert len(errors) >= 3

    def test_invalid_openapi_version_returns_error(self) -> None:
        """Unsupported openapi version string returns an error."""
        from app.openapi.validator import validate_spec

        spec = {**_VALID_SPEC, "openapi": "1.0.0"}
        errors = validate_spec(spec)
        assert len(errors) > 0

    def test_errors_are_descriptive_strings(self) -> None:
        """All returned errors are non-empty strings."""
        from app.openapi.validator import validate_spec

        errors = validate_spec({})
        for e in errors:
            assert isinstance(e, str)
            assert len(e) > 0
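As specified above, `validate_spec` accumulates descriptive error strings rather than raising. A sketch that satisfies every case; treating anything outside the 3.x family as unsupported is an assumption:

```python
# Hypothetical validate_spec implementation.
def validate_spec(spec) -> list[str]:
    errors: list[str] = []
    if not isinstance(spec, dict):
        return ["spec must be a JSON/YAML mapping"]
    if "openapi" not in spec:
        errors.append("missing required 'openapi' version field")
    elif not str(spec["openapi"]).startswith("3."):
        errors.append(f"unsupported openapi version: {spec['openapi']}")
    if "info" not in spec:
        errors.append("missing required 'info' object")
    if "paths" not in spec:
        errors.append("missing required 'paths' object")
    return errors
```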
1
backend/tests/unit/replay/__init__.py
Normal file
@@ -0,0 +1 @@
"""Unit tests for app.replay module."""
217
backend/tests/unit/replay/test_api.py
Normal file
@@ -0,0 +1,217 @@
"""Unit tests for app.replay.api."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

from app.api_utils import envelope

pytestmark = pytest.mark.unit


def _build_app() -> FastAPI:
    from app.replay.api import router

    app = FastAPI()
    app.include_router(router)

    @app.exception_handler(HTTPException)
    async def _http_exc(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    return app


def _make_mock_pool(
    fetchall_result: list[dict],
    *,
    count: int | None = None,
) -> MagicMock:
    """Build a mock pool that returns the given rows from fetchall.

    When *count* is provided, the first execute() call returns a cursor
    whose fetchone() yields ``(count,)`` (for the COUNT query) and the
    second call returns the rows via fetchall(). When *count* is None
    (the default), a single cursor backed by *fetchall_result* is used
    for all calls.
    """
    if count is not None:
        count_cursor = AsyncMock()
        count_cursor.fetchone = AsyncMock(return_value=(count,))

        rows_cursor = AsyncMock()
        rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)

        mock_conn = AsyncMock()
        mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
    else:
        mock_cursor = AsyncMock()
        mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
        mock_cursor.fetchone = AsyncMock(return_value=None)

        mock_conn = AsyncMock()
        mock_conn.execute = AsyncMock(return_value=mock_cursor)

    mock_ctx = AsyncMock()
    mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
    mock_ctx.__aexit__ = AsyncMock(return_value=None)

    mock_pool = MagicMock()
    mock_pool.connection.return_value = mock_ctx
    return mock_pool


class TestListConversations:
    def test_returns_200_with_empty_list(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        data = body["data"]
        assert isinstance(data["conversations"], list)
        assert data["total"] == 0
        assert data["page"] == 1
        assert body["error"] is None

    def test_returns_conversations_list(self) -> None:
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "t1",
                "created_at": "2026-01-01T00:00:00",
                "last_activity": "2026-01-01T00:01:00",
                "status": "active",
                "total_tokens": 100,
                "total_cost_usd": 0.01,
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows, count=1)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        body = resp.json()
        assert resp.status_code == 200
        data = body["data"]
        assert len(data["conversations"]) == 1
        assert data["conversations"][0]["thread_id"] == "t1"
        assert data["total"] == 1

    def test_pagination_defaults(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations")
        assert resp.status_code == 200

    def test_pagination_custom_params(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?page=2&per_page=10")
        assert resp.status_code == 200

    def test_per_page_max_capped_at_100(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([], count=0)

        with TestClient(app) as client:
            resp = client.get("/api/v1/conversations?per_page=200")
        # FastAPI Query(le=100) rejects values > 100
        assert resp.status_code == 422


class TestGetReplay:
    def test_thread_not_found_returns_404(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/nonexistent-thread")
        assert resp.status_code == 404

    def test_returns_replay_page_for_existing_thread(self) -> None:
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "thread-123",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {
                        "messages": [{"type": "human", "content": "Hello"}]
                    }
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows)

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/thread-123")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["data"]["thread_id"] == "thread-123"
        assert "steps" in body["data"]
        assert "total_steps" in body["data"]
        assert "page" in body["data"]
        assert "per_page" in body["data"]

    def test_replay_pagination_params(self) -> None:
        app = _build_app()
        mock_rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp-001",
                "checkpoint": {
                    "channel_values": {"messages": [{"type": "human", "content": "Hi"}]}
                },
                "metadata": {},
            }
        ]
        app.state.pool = _make_mock_pool(mock_rows)

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
        assert resp.status_code == 200

    def test_error_response_has_envelope(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/missing")
        assert resp.status_code == 404
        body = resp.json()
        assert body["success"] is False
        assert body["data"] is None
        assert body["error"] is not None

    def test_invalid_thread_id_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id%20with%20spaces")
        assert resp.status_code == 400

    def test_thread_id_special_chars_returns_400(self) -> None:
        app = _build_app()
        app.state.pool = _make_mock_pool([])

        with TestClient(app) as client:
            resp = client.get("/api/v1/replay/id;DROP TABLE")
        assert resp.status_code == 400
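Every assertion above leans on the `{success, data, error}` envelope shape, which suggests `app.api_utils.envelope` is little more than a dict constructor along these lines (a guess at its signature, based only on the call sites in `_build_app` and the asserted response bodies):

```python
# Hypothetical envelope helper matching envelope(None, success=False, error=...).
def envelope(data, *, success: bool = True, error: str | None = None) -> dict:
    return {"success": success, "data": data, "error": error}
```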
134
backend/tests/unit/replay/test_models.py
Normal file
@@ -0,0 +1,134 @@
"""Unit tests for app.replay.models."""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit


class TestStepType:
    def test_all_step_types_exist(self) -> None:
        from app.replay.models import StepType

        assert StepType.user_message
        assert StepType.supervisor_routing
        assert StepType.tool_call
        assert StepType.tool_result
        assert StepType.agent_response
        assert StepType.interrupt

    def test_step_type_values(self) -> None:
        from app.replay.models import StepType

        assert StepType.user_message.value == "user_message"
        assert StepType.tool_call.value == "tool_call"
        assert StepType.agent_response.value == "agent_response"


class TestReplayStep:
    def test_minimal_replay_step(self) -> None:
        from app.replay.models import ReplayStep, StepType

        step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
        assert step.step == 1
        assert step.type == StepType.user_message
        assert step.timestamp == "2026-01-01T00:00:00Z"
        assert step.content == ""
        assert step.agent is None
        assert step.tool is None
        assert step.params is None
        assert step.result is None
        assert step.reasoning is None
        assert step.tokens is None
        assert step.duration_ms is None

    def test_full_replay_step(self) -> None:
        from app.replay.models import ReplayStep, StepType

        step = ReplayStep(
            step=2,
            type=StepType.tool_call,
            timestamp="2026-01-01T00:00:01Z",
            content="calling get_order",
            agent="order_agent",
            tool="get_order_status",
            params={"order_id": "ORD-123"},
            result={"status": "shipped"},
            reasoning="user asked about order",
            tokens=50,
            duration_ms=200,
        )
        assert step.step == 2
        assert step.agent == "order_agent"
        assert step.tool == "get_order_status"
        assert step.params == {"order_id": "ORD-123"}
        assert step.tokens == 50

    def test_replay_step_is_frozen(self) -> None:
        from app.replay.models import ReplayStep, StepType

        step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
        with pytest.raises((AttributeError, TypeError)):
            step.step = 99  # type: ignore[misc]

    def test_replay_step_params_is_immutable_copy(self) -> None:
        from app.replay.models import ReplayStep, StepType

        params = {"key": "value"}
        step = ReplayStep(
            step=1,
            type=StepType.tool_call,
            timestamp="2026-01-01T00:00:00Z",
            params=params,
        )
        # Modifying original dict should not affect step
        params["new_key"] = "new_value"
        assert "new_key" not in (step.params or {})


class TestReplayPage:
    def test_replay_page_construction(self) -> None:
        from app.replay.models import ReplayPage, ReplayStep, StepType

        steps = (
            ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z"),
            ReplayStep(step=2, type=StepType.agent_response, timestamp="2026-01-01T00:00:01Z"),
        )
        page = ReplayPage(
            thread_id="thread-123",
            total_steps=2,
            page=1,
            per_page=20,
            steps=steps,
        )
        assert page.thread_id == "thread-123"
        assert page.total_steps == 2
        assert page.page == 1
        assert page.per_page == 20
        assert len(page.steps) == 2

    def test_replay_page_is_frozen(self) -> None:
        from app.replay.models import ReplayPage

        page = ReplayPage(
            thread_id="t1",
            total_steps=0,
            page=1,
            per_page=20,
            steps=(),
        )
        with pytest.raises((AttributeError, TypeError)):
            page.page = 2  # type: ignore[misc]

    def test_replay_page_empty_steps(self) -> None:
        from app.replay.models import ReplayPage

        page = ReplayPage(
            thread_id="t1",
            total_steps=0,
            page=1,
            per_page=20,
            steps=(),
        )
        assert page.steps == ()
257 backend/tests/unit/replay/test_transformer.py Normal file
@@ -0,0 +1,257 @@
"""Unit tests for app.replay.transformer."""

from __future__ import annotations

import pytest

pytestmark = pytest.mark.unit


def _make_row(messages: list[dict], metadata: dict | None = None) -> dict:
    """Helper to build a checkpoint row with the given messages."""
    return {
        "thread_id": "thread-abc",
        "checkpoint_id": "cp-001",
        "checkpoint": {"channel_values": {"messages": messages}},
        "metadata": metadata or {},
    }
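# Note: this row shape (thread_id, checkpoint_id, checkpoint JSON carrying
# channel_values, metadata) is assumed to mirror the rows that the LangGraph
# Postgres checkpointer persists; the transformer relies only on these keys.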


class TestTransformCheckpoints:
    def test_empty_rows_returns_empty_list(self) -> None:
        from app.replay.transformer import transform_checkpoints

        result = transform_checkpoints([])
        assert result == []

    def test_human_message_produces_user_message_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "human", "content": "Hello, I need help"}])]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.user_message
        assert steps[0].content == "Hello, I need help"
        assert steps[0].step == 1

    def test_ai_message_with_content_produces_agent_response(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [{"type": "ai", "content": "I can help you with that.", "tool_calls": []}],
                metadata={"writes": {"some_agent": "response"}},
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.agent_response
        assert steps[0].content == "I can help you with that."

    def test_ai_message_with_tool_calls_produces_tool_call_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "ai",
                        "content": "",
                        "tool_calls": [
                            {
                                "name": "get_order_status",
                                "args": {"order_id": "ORD-123"},
                                "id": "call_abc",
                            }
                        ],
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.tool_call
        assert steps[0].tool == "get_order_status"
        assert steps[0].params == {"order_id": "ORD-123"}

    def test_tool_message_produces_tool_result_step(self) -> None:
        from app.replay.models import StepType
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": '{"status": "shipped"}',
                        "name": "get_order_status",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].type == StepType.tool_result
        assert steps[0].tool == "get_order_status"

    def test_multiple_messages_sequential_steps(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {"type": "human", "content": "Help"},
                    {"type": "ai", "content": "Sure!", "tool_calls": []},
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2

    def test_unknown_message_type_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "unknown_type", "content": "test"}])]
        steps = transform_checkpoints(rows)
        # Should not crash; unknown types may be skipped
        assert isinstance(steps, list)

    def test_row_missing_checkpoint_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": None, "metadata": {}}]
        steps = transform_checkpoints(rows)
        assert isinstance(steps, list)

    def test_row_missing_messages_key_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": {}, "metadata": {}}]
        steps = transform_checkpoints(rows)
        assert isinstance(steps, list)

    def test_multiple_rows_steps_are_continuous(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row([{"type": "human", "content": "Q1"}]),
            _make_row([{"type": "ai", "content": "A1", "tool_calls": []}]),
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2

    def test_timestamps_are_strings(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [_make_row([{"type": "human", "content": "Hi"}])]
        steps = transform_checkpoints(rows)
        assert isinstance(steps[0].timestamp, str)

    def test_list_content_joined_to_string(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "human",
                        "content": [
                            {"text": "Hello"},
                            {"text": " world"},
                        ],
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].content == "Hello world"

    def test_checkpoint_as_string_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp1",
                "checkpoint": "not-a-dict",
                "metadata": {},
            }
        ]
        steps = transform_checkpoints(rows)
        assert steps == []

    def test_channel_values_not_dict_skipped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            {
                "thread_id": "t1",
                "checkpoint_id": "cp1",
                "checkpoint": {"channel_values": "bad"},
                "metadata": {},
            }
        ]
        steps = transform_checkpoints(rows)
        assert steps == []

    def test_tool_result_valid_json_parsed(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": '{"order_id": "123", "status": "shipped"}',
                        "name": "get_order_status",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].result == {"order_id": "123", "status": "shipped"}

    def test_tool_result_invalid_json_wrapped(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {
                        "type": "tool",
                        "content": "not valid json",
                        "name": "some_tool",
                    }
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        assert len(steps) == 1
        assert steps[0].result == {"raw": "not valid json"}

    def test_malformed_message_skipped_gracefully(self) -> None:
        from app.replay.transformer import transform_checkpoints

        rows = [
            _make_row(
                [
                    {"type": "human", "content": "Good message"},
                    42,  # not a dict -- will raise in _step_from_message
                    {"type": "ai", "content": "Response", "tool_calls": []},
                ]
            )
        ]
        steps = transform_checkpoints(rows)
        # The malformed message is skipped; the other two produce steps.
        assert len(steps) == 2
        assert steps[0].step == 1
        assert steps[1].step == 2
@@ -7,10 +7,41 @@ import pytest
 from app.config import Settings
 
 
+def _isolated_settings(**kwargs: object) -> Settings:
+    """Create a Settings instance that ignores .env files and process env vars.
+
+    pydantic-settings reads from env_file and the environment by default,
+    which makes test results depend on the machine they run on. We build a
+    throwaway subclass with an overridden model_config so that every test
+    gets deterministic results.
+    """
+    # Build a throwaway subclass that disables env-file loading.
+    class _IsolatedSettings(Settings):
+        model_config = Settings.model_config.copy()
+        model_config["env_file"] = None  # type: ignore[assignment]
+        model_config["env_ignore_empty"] = True
+
+    import os
+
+    env_backup = os.environ.copy()
+    # Strip all env vars that Settings knows about so they can't leak in.
+    # pydantic-settings matches env vars to field names case-insensitively,
+    # which is why the lowercase comparison below is sufficient; with the
+    # vars absent, required fields will raise as expected.
+    settings_fields = set(Settings.model_fields)
+    for key in list(os.environ):
+        if key.lower() in settings_fields:
+            del os.environ[key]
+    try:
+        return _IsolatedSettings(**kwargs)  # type: ignore[return-value]
+    finally:
+        os.environ.clear()
+        os.environ.update(env_backup)
+
+
 @pytest.mark.unit
 class TestSettings:
     def test_default_values(self) -> None:
-        settings = Settings(
+        settings = _isolated_settings(
             database_url="postgresql://x:x@localhost/db",
             anthropic_api_key="key",
         )
@@ -20,7 +51,7 @@ class TestSettings:
         assert settings.interrupt_ttl_minutes == 30
 
     def test_custom_values(self) -> None:
-        settings = Settings(
+        settings = _isolated_settings(
             database_url="postgresql://x:x@localhost/db",
             llm_provider="openai",
             llm_model="gpt-4o",
@@ -33,18 +64,18 @@ class TestSettings:
 
     def test_invalid_provider_rejected(self) -> None:
         with pytest.raises(Exception):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="invalid",
             )
 
     def test_missing_database_url_rejected(self) -> None:
         with pytest.raises(Exception):
-            Settings(anthropic_api_key="key")
+            _isolated_settings(anthropic_api_key="key")
 
     def test_empty_api_key_for_provider_rejected(self) -> None:
         with pytest.raises(ValueError, match="API key"):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="anthropic",
                 anthropic_api_key="",
@@ -52,9 +83,27 @@ class TestSettings:
 
     def test_wrong_provider_key_rejected(self) -> None:
         with pytest.raises(ValueError, match="API key"):
-            Settings(
+            _isolated_settings(
                 database_url="postgresql://x:x@localhost/db",
                 llm_provider="openai",
                 anthropic_api_key="key",
                 openai_api_key="",
             )
+
+    def test_azure_openai_missing_endpoint_rejected(self) -> None:
+        with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
+            _isolated_settings(
+                database_url="postgresql://x:x@localhost/db",
+                llm_provider="azure_openai",
+                azure_openai_api_key="key",
+                azure_openai_deployment="my-deploy",
+            )
+
+    def test_azure_openai_missing_deployment_rejected(self) -> None:
+        with pytest.raises(ValueError, match="AZURE_OPENAI_DEPLOYMENT"):
+            _isolated_settings(
+                database_url="postgresql://x:x@localhost/db",
+                llm_provider="azure_openai",
+                azure_openai_api_key="key",
+                azure_openai_endpoint="https://example.openai.azure.com",
+            )
156 backend/tests/unit/test_conversation_tracker.py Normal file
@@ -0,0 +1,156 @@
"""Tests for app.conversation_tracker module."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.conversation_tracker import (
    ConversationTrackerProtocol,
    NoOpConversationTracker,
    PostgresConversationTracker,
)

pytestmark = pytest.mark.unit

def _make_pool() -> tuple[AsyncMock, AsyncMock]:
    """Create a mock async connection pool and its connection."""
    pool = AsyncMock()
    conn = AsyncMock()
    conn.execute = AsyncMock()
    pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
    return pool, conn


class _AsyncContextManager:
    """Async context manager helper."""

    def __init__(self, value: object) -> None:
        self._value = value

    async def __aenter__(self) -> object:
        return self._value

    async def __aexit__(self, *args: object) -> None:
        pass
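# Note: pool.connection is a MagicMock rather than an AsyncMock because
# `async with pool.connection()` needs the call itself to synchronously
# return an async context manager; an AsyncMock would return a coroutine.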


class TestConversationTrackerProtocol:
    def test_noop_satisfies_protocol(self) -> None:
        tracker = NoOpConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)

    def test_postgres_satisfies_protocol(self) -> None:
        tracker = PostgresConversationTracker()
        assert isinstance(tracker, ConversationTrackerProtocol)


class TestNoOpConversationTracker:
    @pytest.mark.asyncio
    async def test_ensure_conversation_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        # Should not raise
        await tracker.ensure_conversation(pool, "thread-1")

    @pytest.mark.asyncio
    async def test_record_turn_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", "agent_a", 100, 0.05)

    @pytest.mark.asyncio
    async def test_resolve_does_nothing(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.resolve(pool, "thread-1", "resolved")

    @pytest.mark.asyncio
    async def test_accepts_none_agent_name(self) -> None:
        tracker = NoOpConversationTracker()
        pool = AsyncMock()
        await tracker.record_turn(pool, "thread-1", None, 0, 0.0)


class TestPostgresConversationTracker:
    @pytest.mark.asyncio
    async def test_ensure_conversation_executes_insert(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.ensure_conversation(pool, "thread-abc")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "INSERT" in sql
        assert "ON CONFLICT" in sql
        assert params["thread_id"] == "thread-abc"

    @pytest.mark.asyncio
    async def test_record_turn_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "thread-abc", "order_agent", 250, 0.12)

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["agent_name"] == "order_agent"
        assert params["tokens"] == 250
        assert params["cost"] == 0.12

    @pytest.mark.asyncio
    async def test_record_turn_accepts_none_agent_name(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "thread-abc", None, 0, 0.0)

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert params["agent_name"] is None

    @pytest.mark.asyncio
    async def test_resolve_executes_update(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.resolve(pool, "thread-abc", "resolved")

        conn.execute.assert_awaited_once()
        sql, params = conn.execute.call_args[0]
        assert "UPDATE" in sql
        assert params["thread_id"] == "thread-abc"
        assert params["resolution_type"] == "resolved"

    @pytest.mark.asyncio
    async def test_resolve_sets_ended_at(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.resolve(pool, "thread-abc", "escalated")

        sql, params = conn.execute.call_args[0]
        assert "ended_at" in sql.lower()

    @pytest.mark.asyncio
    async def test_ensure_conversation_with_special_thread_id(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.ensure_conversation(pool, "thread-123-abc-XYZ")

        conn.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_record_turn_with_zero_cost(self) -> None:
        tracker = PostgresConversationTracker()
        pool, conn = _make_pool()

        await tracker.record_turn(pool, "t1", "agent", 0, 0.0)

        conn.execute.assert_awaited_once()
@@ -55,7 +55,7 @@ class TestDbModule:
         from app.db import setup_app_tables
 
         await setup_app_tables(mock_pool)
-        assert mock_conn.execute.await_count == 2
+        assert mock_conn.execute.await_count == 5
 
     def test_ddl_statements_valid(self) -> None:
         assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL
55 backend/tests/unit/test_db_phase4.py Normal file
@@ -0,0 +1,55 @@
"""Phase 4 DB migration tests -- analytics_events table and conversation columns."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

pytestmark = pytest.mark.unit


class TestAnalyticsEventsDDL:
    def test_analytics_events_ddl_exists(self) -> None:
        from app.db import _ANALYTICS_EVENTS_DDL

        assert "CREATE TABLE IF NOT EXISTS analytics_events" in _ANALYTICS_EVENTS_DDL

    def test_analytics_events_ddl_has_required_columns(self) -> None:
        from app.db import _ANALYTICS_EVENTS_DDL

        assert "thread_id" in _ANALYTICS_EVENTS_DDL
        assert "event_type" in _ANALYTICS_EVENTS_DDL
        assert "agent_name" in _ANALYTICS_EVENTS_DDL
        assert "tool_name" in _ANALYTICS_EVENTS_DDL
        assert "tokens_used" in _ANALYTICS_EVENTS_DDL
        assert "cost_usd" in _ANALYTICS_EVENTS_DDL
        assert "duration_ms" in _ANALYTICS_EVENTS_DDL
        assert "success" in _ANALYTICS_EVENTS_DDL
        assert "error_message" in _ANALYTICS_EVENTS_DDL
        assert "metadata" in _ANALYTICS_EVENTS_DDL

    def test_conversations_migration_ddl_exists(self) -> None:
        from app.db import _CONVERSATIONS_MIGRATION_DDL

        assert "ALTER TABLE" in _CONVERSATIONS_MIGRATION_DDL
        assert "resolution_type" in _CONVERSATIONS_MIGRATION_DDL
        assert "agents_used" in _CONVERSATIONS_MIGRATION_DDL
        assert "turn_count" in _CONVERSATIONS_MIGRATION_DDL
        assert "ended_at" in _CONVERSATIONS_MIGRATION_DDL
        assert "IF NOT EXISTS" in _CONVERSATIONS_MIGRATION_DDL

    @pytest.mark.asyncio
    async def test_setup_app_tables_executes_analytics_ddl(self) -> None:
        mock_conn = AsyncMock()
        mock_ctx = AsyncMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        mock_pool = MagicMock()
        mock_pool.connection.return_value = mock_ctx

        from app.db import setup_app_tables

        await setup_app_tables(mock_pool)
        # Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
        assert mock_conn.execute.await_count == 5
79 backend/tests/unit/test_discount.py Normal file
@@ -0,0 +1,79 @@
"""Tests for app.agents.discount module."""

from __future__ import annotations

from unittest.mock import patch

import pytest

from app.agents.discount import apply_discount, generate_coupon
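# The tests below patch app.agents.discount.interrupt, which is assumed to be
# LangGraph's interrupt() used for human-in-the-loop approval; the resume
# payload may arrive as a bare bool or as a {"approved": bool} dict, and both
# shapes are exercised.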


@pytest.mark.unit
class TestApplyDiscount:
    def test_invalid_discount_zero(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 0})
        assert result["status"] == "error"
        assert "Invalid" in result["message"]

    def test_invalid_discount_over_100(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_discount_negative(self) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": -5})
        assert result["status"] == "error"

    @patch("app.agents.discount.interrupt", return_value=True)
    def test_approved_discount(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "applied"
        assert result["discount_percent"] == 10
        assert "1042" in result["message"]

    @patch("app.agents.discount.interrupt", return_value=False)
    def test_rejected_discount(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
        assert result["status"] == "declined"

    @patch("app.agents.discount.interrupt", return_value={"approved": True})
    def test_approved_via_dict(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "applied"

    @patch("app.agents.discount.interrupt", return_value={"approved": False})
    def test_rejected_via_dict(self, mock_interrupt) -> None:
        result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
        assert result["status"] == "declined"


@pytest.mark.unit
class TestGenerateCoupon:
    def test_valid_coupon(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 15, "expiry_days": 7})
        assert result["status"] == "generated"
        assert result["discount_percent"] == 15
        assert result["expiry_days"] == 7
        assert result["coupon_code"].startswith("SAVE15-")

    def test_default_expiry(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 20})
        assert result["status"] == "generated"
        assert result["expiry_days"] == 30

    def test_invalid_discount_zero(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 0})
        assert result["status"] == "error"

    def test_invalid_discount_over_100(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 101})
        assert result["status"] == "error"

    def test_invalid_expiry(self) -> None:
        result = generate_coupon.invoke({"discount_percent": 10, "expiry_days": 0})
        assert result["status"] == "error"

    def test_coupon_codes_unique(self) -> None:
        r1 = generate_coupon.invoke({"discount_percent": 10})
        r2 = generate_coupon.invoke({"discount_percent": 10})
        assert r1["coupon_code"] != r2["coupon_code"]
210 backend/tests/unit/test_edge_cases.py Normal file
@@ -0,0 +1,210 @@
"""Edge case tests for ws_handler input validation and rate limiting."""

from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message

pytestmark = pytest.mark.unit


def _make_ws() -> AsyncMock:
    ws = AsyncMock()
    ws.send_json = AsyncMock()
    return ws


def _make_graph() -> MagicMock:
    graph = AsyncMock()

    class AsyncIterHelper:
        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    graph.astream = MagicMock(return_value=AsyncIterHelper())
    state = MagicMock()
    state.tasks = ()
    graph.aget_state = AsyncMock(return_value=state)
    return graph


def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
    graph = _make_graph()
    registry = MagicMock()
    registry.list_agents = MagicMock(return_value=())
    graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
    return WebSocketContext(
        graph_ctx=graph_ctx,
        session_manager=sm or SessionManager(),
        callback_handler=TokenUsageCallbackHandler(),
    )


@pytest.mark.unit
class TestEmptyMessageHandling:
    @pytest.mark.asyncio
    async def test_empty_message_content_returns_error(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        ws_ctx = _make_ws_ctx(sm=sm)

        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
        await dispatch_message(ws, ws_ctx, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        msg_lower = call_data["message"].lower()
        assert "content" in msg_lower or "missing" in msg_lower

    @pytest.mark.asyncio
    async def test_whitespace_only_message_treated_as_empty(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        ws_ctx = _make_ws_ctx(sm=sm)

        sm.touch("t1")
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
        await dispatch_message(ws, ws_ctx, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"


@pytest.mark.unit
class TestOversizedMessageHandling:
    @pytest.mark.asyncio
    async def test_content_over_10000_chars_returns_error(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        ws_ctx = _make_ws_ctx(sm=sm)

        sm.touch("t1")
        content = "x" * 10001
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
        await dispatch_message(ws, ws_ctx, msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too long" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_content_exactly_10000_chars_is_accepted(self) -> None:
        ws = _make_ws()
        sm = SessionManager()
        ws_ctx = _make_ws_ctx(sm=sm)

        sm.touch("t1")
        content = "x" * 10000
        msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
        await dispatch_message(ws, ws_ctx, msg)

        last_call = ws.send_json.call_args[0][0]
        # Should be processed, not an error about length
        msg_text = last_call.get("message", "").lower()
        assert last_call["type"] != "error" or "too long" not in msg_text

    @pytest.mark.asyncio
    async def test_raw_message_over_32kb_returns_error(self) -> None:
        ws = _make_ws()
        ws_ctx = _make_ws_ctx()

        large_msg = "x" * 40_000
        await dispatch_message(ws, ws_ctx, large_msg)

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "too large" in call_data["message"].lower()


@pytest.mark.unit
class TestInvalidJsonHandling:
    @pytest.mark.asyncio
    async def test_invalid_json_returns_error(self) -> None:
        ws = _make_ws()
        ws_ctx = _make_ws_ctx()

        await dispatch_message(ws, ws_ctx, "not valid json {{")

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"
        assert "invalid json" in call_data["message"].lower()

    @pytest.mark.asyncio
    async def test_empty_string_returns_json_error(self) -> None:
        ws = _make_ws()
        ws_ctx = _make_ws_ctx()

        await dispatch_message(ws, ws_ctx, "")

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"

    @pytest.mark.asyncio
    async def test_json_array_not_object_returns_error(self) -> None:
        ws = _make_ws()
        ws_ctx = _make_ws_ctx()

        await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')

        call_data = ws.send_json.call_args[0][0]
        assert call_data["type"] == "error"


@pytest.mark.unit
class TestRateLimiting:
    @pytest.mark.asyncio
    async def test_rapid_fire_messages_rate_limited(self) -> None:
        ws = _make_ws()
        sm = SessionManager()

        sm.touch("t1")

        # Simulate 11 rapid messages (exceeds the 10-per-10-seconds limit)
        rate_limit_triggered = False
        for i in range(11):
            ws_ctx = _make_ws_ctx(sm=sm)
            await dispatch_message(ws, ws_ctx, json.dumps({
                "type": "message",
                "thread_id": "t1",
                "content": f"message {i}",
            }))
            last_call = ws.send_json.call_args[0][0]
            if last_call["type"] == "error" and "rate" in last_call.get("message", "").lower():
                rate_limit_triggered = True
                break

        assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"

    @pytest.mark.asyncio
    async def test_different_threads_have_separate_rate_limits(self) -> None:
        ws = _make_ws()
        sm = SessionManager()

        sm.touch("t1")
        sm.touch("t2")

        # Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
        for i in range(5):
            ws_ctx1 = _make_ws_ctx(sm=sm)
            ws_ctx2 = _make_ws_ctx(sm=sm)
            await dispatch_message(ws, ws_ctx1, json.dumps({
                "type": "message", "thread_id": "t1", "content": f"msg {i}",
            }))
            await dispatch_message(ws, ws_ctx2, json.dumps({
                "type": "message", "thread_id": "t2", "content": f"msg {i}",
            }))

        last_call = ws.send_json.call_args[0][0]
        assert "rate" not in last_call.get("message", "").lower()
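# The per-thread budget exercised above (10 messages per 10 seconds) is
# assumed to be tracked by the shared SessionManager keyed on thread_id,
# which is why t1 and t2 do not share a limit even across fresh contexts.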
175 backend/tests/unit/test_error_handler.py Normal file
@@ -0,0 +1,175 @@
"""Tests for app.tools.error_handler module."""

from __future__ import annotations

from unittest.mock import AsyncMock, patch

import httpx
import pytest

from app.tools.error_handler import (
    ErrorCategory,
    classify_error,
    with_retry,
)

pytestmark = pytest.mark.unit


class TestErrorClassification:
    def test_timeout_exception_is_timeout(self) -> None:
        exc = httpx.TimeoutException("timed out")
        assert classify_error(exc) == ErrorCategory.TIMEOUT

    def test_connect_error_is_network(self) -> None:
        exc = httpx.ConnectError("connection refused")
        assert classify_error(exc) == ErrorCategory.NETWORK

    def test_401_is_auth_failure(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(401, request=request)
        exc = httpx.HTTPStatusError("401", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.AUTH_FAILURE

    def test_403_is_auth_failure(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(403, request=request)
        exc = httpx.HTTPStatusError("403", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.AUTH_FAILURE

    def test_429_is_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(429, request=request)
        exc = httpx.HTTPStatusError("429", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.RETRYABLE

    def test_500_is_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(500, request=request)
        exc = httpx.HTTPStatusError("500", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.RETRYABLE

    def test_502_is_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(502, request=request)
        exc = httpx.HTTPStatusError("502", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.RETRYABLE

    def test_503_is_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(503, request=request)
        exc = httpx.HTTPStatusError("503", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.RETRYABLE

    def test_404_is_non_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(404, request=request)
        exc = httpx.HTTPStatusError("404", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE

    def test_400_is_non_retryable(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(400, request=request)
        exc = httpx.HTTPStatusError("400", request=request, response=response)
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE

    def test_generic_exception_is_non_retryable(self) -> None:
        exc = ValueError("bad value")
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE

    def test_runtime_error_is_non_retryable(self) -> None:
        exc = RuntimeError("boom")
        assert classify_error(exc) == ErrorCategory.NON_RETRYABLE


class TestWithRetry:
    @pytest.mark.asyncio
    async def test_succeeds_on_first_try(self) -> None:
        fn = AsyncMock(return_value="ok")
        result = await with_retry(fn, max_retries=3, base_delay=0.0)
        assert result == "ok"
        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_retries_on_retryable_error(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(503, request=request)
        retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)

        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "success"])

        with patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock):
            result = await with_retry(fn, max_retries=3, base_delay=0.0)

        assert result == "success"
        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_non_retryable_error(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(404, request=request)
        non_retryable_exc = httpx.HTTPStatusError("404", request=request, response=response)

        fn = AsyncMock(side_effect=non_retryable_exc)

        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_does_not_retry_auth_failure(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(401, request=request)
        auth_exc = httpx.HTTPStatusError("401", request=request, response=response)

        fn = AsyncMock(side_effect=auth_exc)

        with pytest.raises(httpx.HTTPStatusError):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_raises_after_max_retries_exhausted(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(500, request=request)
        retryable_exc = httpx.HTTPStatusError("500", request=request, response=response)

        fn = AsyncMock(side_effect=retryable_exc)

        with (
            patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock),
            pytest.raises(httpx.HTTPStatusError),
        ):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 3

    @pytest.mark.asyncio
    async def test_does_not_retry_timeout(self) -> None:
        """TimeoutException is TIMEOUT category -- not retried by default."""
        fn = AsyncMock(side_effect=httpx.TimeoutException("timed out"))

        with pytest.raises(httpx.TimeoutException):
            await with_retry(fn, max_retries=3, base_delay=0.0)

        assert fn.call_count == 1

    @pytest.mark.asyncio
    async def test_exponential_backoff_increases_delay(self) -> None:
        request = httpx.Request("GET", "http://example.com")
        response = httpx.Response(503, request=request)
        retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)

        fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "done"])
        sleep_delays: list[float] = []

        async def capture_sleep(delay: float) -> None:
            sleep_delays.append(delay)

        with patch("app.tools.error_handler.asyncio.sleep", side_effect=capture_sleep):
            await with_retry(fn, max_retries=3, base_delay=1.0)

        assert len(sleep_delays) == 2
        assert sleep_delays[1] > sleep_delays[0]
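# The backoff test only asserts that successive delays grow; the usual
# schedule for such helpers is delay = base_delay * 2**attempt, but the exact
# formula is an implementation detail of with_retry and is deliberately not
# pinned by the assertions above.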
142 backend/tests/unit/test_error_responses.py Normal file
@@ -0,0 +1,142 @@
"""Tests for standardized error response envelope format."""

from __future__ import annotations

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field

from app.api_utils import envelope

pytestmark = pytest.mark.unit


def _build_test_app() -> FastAPI:
    """Build a minimal FastAPI app with the standard exception handlers."""
    app = FastAPI()

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=exc.status_code,
            content=envelope(None, success=False, error=exc.detail),
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=422,
            content=envelope(None, success=False, error=str(exc)),
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request, exc):  # type: ignore[no-untyped-def]
        return JSONResponse(
            status_code=500,
            content=envelope(None, success=False, error="Internal server error"),
        )

    class ItemRequest(BaseModel):
        name: str = Field(..., min_length=1)
        count: int = Field(..., gt=0)

    @app.get("/items/{item_id}")
    def get_item(item_id: int) -> dict:
        if item_id == 0:
            raise HTTPException(status_code=400, detail="Invalid item ID")
        if item_id == 999:
            raise HTTPException(status_code=404, detail="Item not found")
        if item_id == 401:
            raise HTTPException(status_code=401, detail="Not authenticated")
        return envelope({"id": item_id, "name": "test"})

    @app.post("/items")
    def create_item(item: ItemRequest) -> dict:
        return envelope({"id": 1, "name": item.name})

    @app.get("/crash")
    def crash() -> dict:
        msg = "unexpected failure"
        raise RuntimeError(msg)

    return app
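# envelope() is assumed to produce the shape the assertions below rely on:
# {"success": bool, "data": Any, "error": str | None}.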


class TestHttpExceptionEnvelope:
    """HTTPException responses use the standard envelope format."""

    def test_400_returns_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/items/0")
            assert resp.status_code == 400
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert body["error"] == "Invalid item ID"

    def test_404_returns_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/items/999")
            assert resp.status_code == 404
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert body["error"] == "Item not found"

    def test_401_returns_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/items/401")
            assert resp.status_code == 401
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert body["error"] == "Not authenticated"


class TestValidationErrorEnvelope:
    """Validation errors return 422 with envelope format."""

    def test_validation_error_returns_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post("/items", json={"name": "", "count": -1})
            assert resp.status_code == 422
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert isinstance(body["error"], str)
            assert len(body["error"]) > 0


class TestGeneralExceptionEnvelope:
    """Unhandled exceptions return 500 with safe envelope."""

    def test_unhandled_exception_returns_500_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.get("/crash")
            assert resp.status_code == 500
            body = resp.json()
            assert body["success"] is False
            assert body["data"] is None
            assert body["error"] == "Internal server error"


class TestSuccessResponseUnchanged:
    """Success responses still work normally."""

    def test_success_returns_envelope(self) -> None:
        app = _build_test_app()
        with TestClient(app) as client:
            resp = client.get("/items/42")
            assert resp.status_code == 200
            body = resp.json()
            assert body["success"] is True
            assert body["data"]["id"] == 42
            assert body["error"] is None
169 backend/tests/unit/test_escalation.py Normal file
@@ -0,0 +1,169 @@
"""Tests for app.escalation module."""

from __future__ import annotations

from unittest.mock import AsyncMock, patch

import httpx
import pytest

from app.escalation import (
    EscalationPayload,
    EscalationResult,
    NoOpEscalator,
    WebhookEscalator,
)


def _make_payload(**kwargs) -> EscalationPayload:
    defaults = {
        "thread_id": "t1",
        "reason": "Agent cannot resolve",
        "conversation_summary": "User asked about refund policy",
    }
    defaults.update(kwargs)
    return EscalationPayload(**defaults)


@pytest.mark.unit
class TestEscalationPayload:
    def test_frozen(self) -> None:
        payload = _make_payload()
        with pytest.raises(Exception):
            payload.thread_id = "t2"  # type: ignore[misc]

    def test_default_metadata(self) -> None:
        payload = _make_payload()
        assert payload.metadata == {}

    def test_model_dump(self) -> None:
        payload = _make_payload(metadata={"key": "val"})
        data = payload.model_dump()
        assert data["thread_id"] == "t1"
        assert data["metadata"] == {"key": "val"}


@pytest.mark.unit
class TestEscalationResult:
    def test_construction(self) -> None:
        result = EscalationResult(success=True, status_code=200, attempts=1, error=None)
        assert result.success
        assert result.status_code == 200


@pytest.mark.unit
class TestWebhookEscalator:
    @pytest.mark.asyncio
    async def test_empty_url_returns_failure(self) -> None:
        escalator = WebhookEscalator(url="", max_retries=3)
        result = await escalator.escalate(_make_payload())
        assert not result.success
        assert result.attempts == 0
        assert "not configured" in result.error

    @pytest.mark.asyncio
    async def test_successful_post(self) -> None:
        mock_response = AsyncMock()
        mock_response.status_code = 200

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
            escalator = WebhookEscalator(url="https://example.com/hook")
            result = await escalator.escalate(_make_payload())

        assert result.success
        assert result.status_code == 200
        assert result.attempts == 1

    @pytest.mark.asyncio
    async def test_retry_on_server_error(self) -> None:
        fail_response = AsyncMock()
        fail_response.status_code = 500
        success_response = AsyncMock()
        success_response.status_code = 200

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(side_effect=[fail_response, fail_response, success_response])
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
            result = await escalator.escalate(_make_payload())

        assert result.success
        assert result.attempts == 3

    @pytest.mark.asyncio
    async def test_all_retries_exhausted(self) -> None:
        fail_response = AsyncMock()
        fail_response.status_code = 500

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=fail_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
            result = await escalator.escalate(_make_payload())

        assert not result.success
        assert result.attempts == 3
        assert "500" in result.error

    @pytest.mark.asyncio
    async def test_timeout_error(self) -> None:
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(url="https://example.com/hook", max_retries=2)
            result = await escalator.escalate(_make_payload())

        assert not result.success
        assert "timed out" in result.error

    @pytest.mark.asyncio
    async def test_request_error(self) -> None:
        mock_client = AsyncMock()
        mock_client.post = AsyncMock(
            side_effect=httpx.RequestError("connection refused")
        )
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with (
            patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
            patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
        ):
            escalator = WebhookEscalator(url="https://example.com/hook", max_retries=1)
            result = await escalator.escalate(_make_payload())

        assert not result.success


@pytest.mark.unit
class TestNoOpEscalator:
    @pytest.mark.asyncio
    async def test_returns_disabled(self) -> None:
        escalator = NoOpEscalator()
        result = await escalator.escalate(_make_payload())
        assert not result.success
        assert result.attempts == 0
        assert "disabled" in result.error.lower()
@@ -6,8 +6,11 @@ from typing import TYPE_CHECKING
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from langgraph.checkpoint.memory import InMemorySaver
 
-from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
+from app.graph import build_agent_nodes, build_graph
+from app.graph_context import GraphContext
+from app.intent import ClassificationResult, IntentTarget
 
 if TYPE_CHECKING:
     from app.registry import AgentRegistry
@@ -33,12 +36,59 @@ class TestBuildGraph:
         mock_llm = MagicMock()
         mock_llm.bind_tools = MagicMock(return_value=mock_llm)
         mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
-        mock_checkpointer = AsyncMock()
+        checkpointer = InMemorySaver()
 
-        graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
-        assert graph is not None
+        graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
+        assert graph_ctx is not None
+        assert graph_ctx.graph is not None
 
-    def test_supervisor_prompt_contains_routing_info(self) -> None:
-        assert "order_lookup" in SUPERVISOR_PROMPT
-        assert "order_actions" in SUPERVISOR_PROMPT
-        assert "fallback" in SUPERVISOR_PROMPT
+    def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
+        mock_llm = MagicMock()
+        mock_llm.bind_tools = MagicMock(return_value=mock_llm)
+        mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
+        checkpointer = InMemorySaver()
+        mock_classifier = MagicMock()
+
+        graph_ctx = build_graph(
+            sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
+        )
+        assert graph_ctx.intent_classifier is mock_classifier
+        assert graph_ctx.registry is sample_registry
+
+    def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
+        mock_llm = MagicMock()
+        mock_llm.bind_tools = MagicMock(return_value=mock_llm)
+        mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
+        checkpointer = InMemorySaver()
+
+        graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
+        assert graph_ctx.intent_classifier is None
+
+
+@pytest.mark.unit
+class TestClassifyIntent:
+    @pytest.mark.asyncio
+    async def test_returns_none_without_classifier(self) -> None:
+        mock_registry = MagicMock()
+        mock_registry.list_agents = MagicMock(return_value=())
+        graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
+        result = await graph_ctx.classify_intent("hello")
+        assert result is None
+
+    @pytest.mark.asyncio
+    async def test_calls_classifier(self) -> None:
+        expected = ClassificationResult(
+            intents=(IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="test"),),
+        )
+        mock_classifier = AsyncMock()
+        mock_classifier.classify = AsyncMock(return_value=expected)
+
+        mock_registry = MagicMock()
+        mock_registry.list_agents = MagicMock(return_value=())
+        graph_ctx = GraphContext(
+            graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
+        )
+
+        result = await graph_ctx.classify_intent("check order")
+        assert result is not None
+        assert result.intents[0].agent_name == "order_lookup"
177 backend/tests/unit/test_intent.py Normal file
@@ -0,0 +1,177 @@
"""Tests for app.intent module."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from app.intent import (
    AMBIGUITY_THRESHOLD,
    ClassificationResult,
    IntentTarget,
    LLMIntentClassifier,
    _build_agent_list,
)
from app.registry import AgentConfig


def _make_agent(name: str, desc: str = "test", perm: str = "read") -> AgentConfig:
    return AgentConfig(
        name=name, description=desc, permission=perm, tools=["fallback_respond"],
    )


@pytest.mark.unit
class TestIntentModels:
    def test_intent_target_frozen(self) -> None:
        target = IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="order query")
        with pytest.raises(Exception):
            target.agent_name = "other"  # type: ignore[misc]

    def test_classification_result_frozen(self) -> None:
        result = ClassificationResult(
            intents=(IntentTarget(agent_name="a", confidence=0.9, reasoning="r"),),
        )
        assert not result.is_ambiguous
        assert result.clarification_question is None

    def test_classification_result_ambiguous(self) -> None:
        result = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="What do you mean?",
        )
        assert result.is_ambiguous

    def test_multi_intent(self) -> None:
        result = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.85, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.8, reasoning="discount"),
            ),
        )
        assert len(result.intents) == 2


@pytest.mark.unit
class TestBuildAgentList:
    def test_formats_agents(self) -> None:
        agents = (
            _make_agent("order_lookup", "Looks up orders", "read"),
            _make_agent("order_actions", "Modifies orders", "write"),
        )
        text = _build_agent_list(agents)
        assert "order_lookup" in text
        assert "order_actions" in text
        assert "read" in text
        assert "write" in text


@pytest.mark.unit
class TestLLMIntentClassifier:
    @pytest.mark.asyncio
    async def test_single_intent_classification(self) -> None:
        expected = ClassificationResult(
            intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("order_lookup"), _make_agent("fallback"))

        result = await classifier.classify("What is order 1042 status?", agents)
        assert len(result.intents) == 1
        assert result.intents[0].agent_name == "order_lookup"
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_multi_intent_classification(self) -> None:
        expected = ClassificationResult(
            intents=(
                IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
                IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
            ),
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("order_actions"), _make_agent("discount"), _make_agent("fallback"))

        result = await classifier.classify("Cancel order 1042 and give me 10% off", agents)
        assert len(result.intents) == 2
        assert not result.is_ambiguous

    @pytest.mark.asyncio
    async def test_ambiguous_classification(self) -> None:
        expected = ClassificationResult(
            intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
            is_ambiguous=False,
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("order_lookup"), _make_agent("fallback"))

        result = await classifier.classify("hmm", agents)
        # Low confidence triggers ambiguity
        assert result.is_ambiguous
        assert result.clarification_question is not None

    @pytest.mark.asyncio
    async def test_llm_error_returns_ambiguous(self) -> None:
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(side_effect=RuntimeError("LLM error"))
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("fallback"),)

        result = await classifier.classify("test", agents)
        assert result.is_ambiguous
        assert result.clarification_question is not None

    @pytest.mark.asyncio
    async def test_non_result_type_returns_ambiguous(self) -> None:
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value="not a ClassificationResult")
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("fallback"),)

        result = await classifier.classify("test", agents)
        assert result.is_ambiguous

    @pytest.mark.asyncio
    async def test_high_confidence_not_ambiguous(self) -> None:
        expected = ClassificationResult(
            intents=(
                IntentTarget(
                    agent_name="order_lookup",
                    confidence=AMBIGUITY_THRESHOLD + 0.1,
                    reasoning="clear",
                ),
            ),
        )
        mock_structured = MagicMock()
        mock_structured.ainvoke = AsyncMock(return_value=expected)
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=mock_structured)

        classifier = LLMIntentClassifier(mock_llm)
        agents = (_make_agent("order_lookup"),)

        result = await classifier.classify("order status 1042", agents)
        assert not result.is_ambiguous
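Read together, the `TestLLMIntentClassifier` cases fix the classifier's degradation contract: an LLM error, a malformed structured output, or uniformly low confidence must all collapse to an ambiguous result carrying a clarification question, and the model's own `is_ambiguous` flag is overridden when every confidence falls below `AMBIGUITY_THRESHOLD`. A sketch of post-processing consistent with these tests, assuming `ClassificationResult` and `_build_agent_list` are the module's own definitions (the threshold value, prompt construction, and fallback wording here are assumptions, not the real `app.intent` internals):

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

AMBIGUITY_THRESHOLD = 0.5  # assumed value; the tests only rely on its existence


class LLMIntentClassifier:
    def __init__(self, llm: Any) -> None:
        self._llm = llm

    async def classify(self, text: str, agents: Sequence[Any]) -> ClassificationResult:
        # Shared degraded result for every failure mode the tests exercise.
        fallback = ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you clarify what you need help with?",
        )
        structured = self._llm.with_structured_output(ClassificationResult)
        prompt = f"Route this message:\n{text}\n\nAvailable agents:\n{_build_agent_list(agents)}"
        try:
            raw = await structured.ainvoke(prompt)
        except Exception:
            return fallback  # any LLM failure degrades to a clarification request
        if not isinstance(raw, ClassificationResult):
            return fallback  # structured output did not parse into the model
        if raw.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in raw.intents):
            return fallback  # low confidence overrides the model's own flag
        return raw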
Some files were not shown because too many files have changed in this diff.