Compare commits


22 Commits

Author SHA1 Message Date
Yaojia Wang
f0699436c5 refactor: engineering improvements -- API versioning, structured logging, Alembic, error standardization, test coverage
- API versioning: all REST endpoints prefixed with /api/v1/
- Structured logging: replaced stdlib logging with structlog (console/JSON modes)
- Alembic migrations: versioned DB schema with initial migration
- Error standardization: global exception handlers for consistent envelope format
- Interrupt cleanup: asyncio background task for expired interrupt removal
- Integration tests: +30 tests (analytics, replay, openapi, error, session APIs)
- Frontend tests: +57 tests (all components, pages, useWebSocket hook)
- Backend: 557 tests, 89.75% coverage | Frontend: 80 tests, 16 test files
2026-04-06 23:19:29 +02:00
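A minimal sketch of the console/JSON structlog switch this commit describes; the function name and processor list are assumptions for illustration, not the project's actual wiring.

```python
# Hedged sketch: console renderer for dev, JSON renderer for production.
import logging
import structlog

def configure_logging(log_format: str = "console") -> None:
    """Configure structlog output mode ("console" or "json")."""
    renderer = (
        structlog.processors.JSONRenderer()
        if log_format == "json"
        else structlog.dev.ConsoleRenderer()
    )
    structlog.configure(
        processors=[
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
    )
```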
Yaojia Wang
af53111928 refactor: fix architectural issues across frontend and backend
Address all architecture review findings:

P0 fixes:
- Add API key authentication for admin endpoints (analytics, replay, openapi)
  and WebSocket connections via ADMIN_API_KEY env var
- Add PostgreSQL-backed PgSessionManager and PgInterruptManager for
  multi-worker production deployments (in-memory defaults preserved)

P1 fixes:
- Implement actual tool generation in OpenAPI approve_job endpoint
  using generate_tool_code() and generate_agent_yaml()
- Add missing clarification, interrupt_expired, and tool_result message
  handlers in frontend ChatPage

P2 fixes:
- Replace monkey-patching on CompiledStateGraph with typed GraphContext
- Replace 9-param dispatch_message with WebSocketContext dataclass
- Extract duplicate _envelope() into shared app/api_utils.py
- Replace mutable module-level counter with crypto.randomUUID()
- Remove hardcoded mock data from ReviewPage, use api.ts wrappers
- Remove `as any` type escape from ReplayPage

All 516 tests passing, 0 TypeScript errors.
2026-04-06 15:59:14 +02:00
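A hedged sketch of the WebSocketContext refactor named in the P2 list above -- bundling the nine separate dispatch_message parameters into one dataclass. The field names here are assumptions, not the project's actual signature.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class WebSocketContext:
    """Collaborators that dispatch_message previously took as 9 params."""
    thread_id: str
    graph: object             # compiled LangGraph (illustrative typing)
    session_manager: object
    interrupt_manager: object
    recorder: object          # analytics recorder
    tracker: object           # conversation tracker

async def dispatch_message(ctx: WebSocketContext, payload: dict) -> None:
    """Call sites now pass one context object instead of nine arguments."""
    ...
```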
Yaojia Wang
b8654aa31f feat: upgrade LangGraph to 1.x and migrate deprecated APIs
- Bump langgraph from 0.4 to 1.0+, langgraph-supervisor from 0.0.12 to 0.0.30+
- Bump langchain-core, langchain-anthropic, langchain-openai to 1.x
- Add langchain>=1.0 dependency for new create_agent location
- Migrate create_react_agent -> create_agent (prompt -> system_prompt)
- Fix create_supervisor positional arg to named agents= parameter
- Replace AsyncMock checkpointer with InMemorySaver in tests (v1 type validation)
- Update version references in README, ARCHITECTURE, eng-review-plan
2026-04-06 14:51:51 +02:00
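A sketch of the v0-to-v1 migration this commit performs. The new import location and the prompt -> system_prompt rename follow the commit message; the model string and empty tool list are placeholders.

```python
# Before (langgraph 0.x):
#   from langgraph.prebuilt import create_react_agent
#   agent = create_react_agent(model, tools, prompt="You are a support agent.")

# After (langchain >= 1.0) -- new location, renamed keyword:
from langchain.agents import create_agent

agent = create_agent(
    model="anthropic:claude-sonnet-4-6",  # placeholder model id
    tools=[],
    system_prompt="You are a support agent.",
)
```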
Yaojia Wang
be5c84bcff docs: reconcile README and docs with actual codebase
README:
- Remove duplicated agent config, safety, security sections (covered by docs)
- Add ux_design_system.md and safety.py to project structure and doc links
- Convert doc links to descriptive table

agent-config-guide.md:
- Replace fictional agents/tools with real ones from agents.yaml
- Remove nonexistent 'admin' permission level (only read/write)
- Fix template names (e-commerce, saas, fintech)
- List all available built-in tools

openapi-import-guide.md:
- Fix /result -> /classifications endpoint
- Fix POST /approve to show no request body
- Remove nonexistent 'admin' access type
- Update response examples to match actual API

demo-script.md:
- Fix agent names (order_agent -> order_lookup)
- Replace fictional refund scenario with real lookup+cancel flow

ARCHITECTURE.md:
- Fix langgraph-supervisor version (v1.1 -> 0.0.12+)

docker-compose.yml:
- Expose postgres on port 5433 for local dev
2026-04-06 13:55:45 +02:00
Yaojia Wang
19fc9f3289 test: close coverage gaps and add frontend test infrastructure
Backend (516 tests, 94% coverage):
- Add azure_openai endpoint/deployment validation tests (config.py -> 100%)
- Add _total_conversations and _avg_turns direct tests (queries.py -> 100%)
- Add transformer edge cases: list content, string checkpoint, invalid JSON,
  malformed message graceful skip (transformer.py -> 93%)
- Add safety combined status_code+error_message interaction tests
- Fix ambiguous 200/422 assertion to strict 422
- Add E2E pagination shape assertions (total, page, per_page, row count)
- Fix ReplayPool mock to respect LIMIT/OFFSET params

Frontend (23 tests, vitest + happy-dom + @testing-library/react):
- Add vitest infrastructure with happy-dom environment
- Add api.ts tests: success, HTTP error, success=false, URL encoding
- Add DashboardPage tests: loading, data, error, empty states
- Add ReplayListPage tests: loading, empty, data, error, status badge classes
- Add ReplayPage tests: loading, steps, empty, error states
2026-04-06 13:32:10 +02:00
Yaojia Wang
036e12349d refactor: formalize safety rules, extract shared styles, reconcile docs (P2)
- Add backend/app/safety.py with explicit confirmation policy, multi-intent
  semantics, and MCP error taxonomy with retry classification
- Add 26 unit tests for safety module (confirmation rules, error taxonomy)
- Extract repeated inline styles into shared CSS classes in index.css
  (section-card, stat-label, status-badge, data-table, empty/error-state,
  pagination-bar)
- Refactor DashboardPage, ReplayListPage, ReplayPage to use shared classes
- Update README: add missing API endpoints, document safety/confirmation rules
- Use proper HTML entities for arrow/dash characters to fix encoding glitches
2026-04-05 23:10:50 +02:00
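A hedged sketch of the explicit confirmation policy that app/safety.py formalizes; the action names and rule shape are assumptions based on the commit message.

```python
# Assumed rule: write-permission tools always pause for human approval.
WRITE_ACTIONS = frozenset({"cancel_order", "modify_order", "apply_discount"})

def requires_confirmation(tool_name: str, permission: str) -> bool:
    """Reads execute immediately; writes require a human in the loop."""
    return permission == "write" or tool_name in WRITE_ACTIONS
```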
Yaojia Wang
e0931daece feat: wire frontend pages to live APIs and standardize response contracts (P1)
- Backend: Add COUNT query and paginated response shape to conversations endpoint
  Returns { conversations: [...], total, page, per_page } instead of flat array
- Frontend: Replace mock data in DashboardPage with fetchAnalytics() API calls
- Frontend: Replace mock data in ReplayListPage with fetchConversations() API calls
- Frontend: Replace mock data in ReplayPage with fetchReplay() API calls
- Add proper loading, empty, and error states to all three pages
- Align ConversationSummary type with actual DB columns (created_at, status)
- Update unit and E2E tests for new paginated conversation response shape
- Add fetchone() to FakeCursor for COUNT query support in E2E tests
2026-04-05 23:06:00 +02:00
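A sketch of the paginated contract described above. The handler name and cursor helpers are assumptions (and a dict row factory is assumed); only the response shape -- { conversations, total, page, per_page } -- comes from the commit.

```python
async def list_conversations(pool, page: int = 1, per_page: int = 20) -> dict:
    """COUNT query plus LIMIT/OFFSET page, returned in the new shape."""
    offset = (page - 1) * per_page
    async with pool.connection() as conn:
        cur = await conn.execute("SELECT COUNT(*) AS total FROM conversations")
        total = (await cur.fetchone())["total"]
        cur = await conn.execute(
            "SELECT thread_id, created_at, status FROM conversations "
            "ORDER BY created_at DESC LIMIT %(limit)s OFFSET %(offset)s",
            {"limit": per_page, "offset": offset},
        )
        rows = await cur.fetchall()
    return {"conversations": rows, "total": total, "page": page, "per_page": per_page}
```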
Yaojia Wang
e55ec42ae5 fix: restore green builds and align frontend-backend contracts (P0)
- Isolate Settings tests from .env and process env leakage
- Fix analytics metadata test to unwrap psycopg Json wrapper
- Remove unused state variables causing frontend build failures
- Fix ReviewPage to use /classifications endpoint instead of nonexistent /result
- Normalize ReviewPage status enums (failed not error) and access_type values
- Align api.ts types with backend response shapes (ReplayPage, AnalyticsData, AgentUsage)
2026-04-05 23:00:39 +02:00
Yaojia Wang
189a0fad34 feat(ui): implement premium beige design system and ux refinements
2026-04-05 22:35:48 +02:00
Yaojia Wang
d2b4610df9 fix: address code and security review findings for Phase 5
- Add nginx security headers (X-Frame-Options, X-Content-Type-Options, etc.)
- Fix postgres networking: add to app_network, comment out host port exposure
- Fix rate limit memory leak: add bounded eviction for stale thread entries
- Use immutable update pattern in rate limit check (no .append mutation)
- Extract _VERSION constant to avoid duplicate hardcoded version string
2026-03-31 21:35:13 +02:00
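A hedged sketch of the rate-limit fix described above (10 msgs/10 s per thread, bounded eviction of stale threads, no in-place .append mutation); the constants and names are assumptions.

```python
import time

_WINDOW_S = 10.0
_MAX_MSGS = 10
_MAX_THREADS = 1000  # bounded eviction cap (assumed value)

_buckets: dict[str, tuple[float, ...]] = {}

def check_rate_limit(thread_id: str, now: float | None = None) -> bool:
    now = time.monotonic() if now is None else now
    recent = tuple(t for t in _buckets.get(thread_id, ()) if now - t < _WINDOW_S)
    if len(recent) >= _MAX_MSGS:
        _buckets[thread_id] = recent
        return False
    if thread_id not in _buckets and len(_buckets) >= _MAX_THREADS:
        # Evict the least-recently-active thread instead of growing forever.
        stale = min(_buckets, key=lambda k: _buckets[k][-1] if _buckets[k] else 0.0)
        del _buckets[stale]
    _buckets[thread_id] = (*recent, now)  # immutable update: new tuple each time
    return True
```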
Yaojia Wang
0e78e5b06b feat: complete phase 5 -- error hardening, frontend, Docker, demo, docs
Backend:
- ConversationTracker: Protocol + PostgresConversationTracker for lifecycle tracking
- Error handler: ErrorCategory enum, classify_error(), with_retry() exponential backoff
- Wire PostgresAnalyticsRecorder + ConversationTracker into ws_handler
- Rate limiting (10 msg/10s per thread), edge case hardening
- Health endpoint GET /api/health, version 0.5.0
- Demo seed data script + sample OpenAPI spec

Frontend (all new):
- React Router with NavBar (Chat / Replay / Dashboard / Review)
- ReplayListPage + ReplayPage with ReplayTimeline component
- DashboardPage with MetricCard, range selector, zero-state
- ReviewPage for OpenAPI classification review
- ErrorBanner for WebSocket disconnect handling
- API client (api.ts) with typed fetch wrappers

Infrastructure:
- Frontend Dockerfile (multi-stage node -> nginx)
- nginx.conf with SPA routing + API/WS proxy
- docker-compose.yml with frontend service + healthchecks
- .env.example files (root + backend)

Documentation:
- README.md with quick start and architecture
- Agent configuration guide
- OpenAPI import guide
- Deployment guide
- Demo script

48 new tests, 449 total passing, 92.87% coverage
2026-03-31 21:20:06 +02:00
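A hedged sketch of the with_retry() exponential-backoff helper named in the backend list above; the signature and delay constants are assumptions from the commit message.

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")

async def with_retry(
    fn: Callable[[], Awaitable[T]],
    *,
    max_attempts: int = 3,
    base_delay_s: float = 0.5,
) -> T:
    """Retry a coroutine factory with exponential backoff (0.5s, 1s, 2s...)."""
    for attempt in range(max_attempts):
        try:
            return await fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise
            await asyncio.sleep(base_delay_s * (2 ** attempt))
    raise RuntimeError("unreachable")
```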
Yaojia Wang
38644594d2 test: add thread_id validation tests for replay API
- Test invalid thread_id with spaces returns 400
- Test thread_id with special chars returns 400
- Tighten existing 404 test assertion
2026-03-31 13:44:04 +02:00
Yaojia Wang
ef6e5ac2be fix: address security findings in Phase 4 analytics and replay
- Fix CRITICAL: use parameterized INTERVAL arithmetic (%(days)s * INTERVAL '1 day')
  instead of string interpolation inside SQL literal
- Use asyncio.gather() for parallel query execution in get_analytics()
- Add range upper bound (max 365 days) to prevent DoS via full-table scans
- Add thread_id validation (alphanumeric, max 128 chars) in replay API
- Sanitize error messages to not reflect user input
2026-03-31 13:38:09 +02:00
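The vulnerable and fixed patterns from the CRITICAL item above, side by side; table and column names follow the diffs later in this page.

```python
days = 30
# Vulnerable: the value is interpolated into the SQL string itself.
unsafe = f"SELECT * FROM conversations WHERE created_at >= NOW() - INTERVAL '{days} days'"

# Fixed: the day count is bound as a parameter and multiplied by a
# constant INTERVAL, so the driver escapes it server-side.
safe_sql = (
    "SELECT * FROM conversations "
    "WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')"
)
# await conn.execute(safe_sql, {"days": days})
```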
Yaojia Wang
33db5aeb10 feat: complete phase 4 -- conversation replay API + analytics dashboard
- Replay models: StepType enum, ReplayStep, ReplayPage frozen dataclasses
- Checkpoint transformer: PostgresSaver JSONB -> structured timeline steps
- Replay API: GET /api/conversations (paginated), GET /api/replay/{thread_id}
- Analytics models: AgentUsage, InterruptStats, AnalyticsResult
- Analytics event recorder: Protocol + PostgresAnalyticsRecorder + NoOp
- Analytics queries: resolution_rate, agent_usage, escalation_rate, cost, interrupts
- Analytics API: GET /api/analytics?range=Xd with envelope response
- DB migration: analytics_events table + conversations column additions
- 74 new tests, 399 total passing, 92.87% coverage
2026-03-31 13:35:45 +02:00
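A hedged sketch of the frozen replay value objects named above; fields beyond those mentioned in the commit message are illustrative assumptions.

```python
from dataclasses import dataclass
from enum import Enum

class StepType(str, Enum):
    USER_MESSAGE = "user_message"
    AGENT_MESSAGE = "agent_message"
    TOOL_CALL = "tool_call"
    INTERRUPT = "interrupt"

@dataclass(frozen=True)
class ReplayStep:
    """One step in the reconstructed conversation timeline."""
    step_type: StepType
    timestamp: str
    agent: str | None
    content: str
```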
Yaojia Wang
a2f750269d fix: address critical security and code review findings in Phase 3
- Wire ImportOrchestrator into review_api start_import via BackgroundTasks
- Sanitize docstrings in generated tool code to prevent code injection
- Add Literal["read", "write"] validation for access_type
- Add regex validation for agent_group
- Validate URL scheme (http/https only) in ImportRequest
- Validate LLM output fields (clamp confidence, validate access_type)
- Use dataclasses.replace instead of manual reconstruction in importer
- Expand SSRF blocked networks (Carrier-Grade NAT, IPv4-mapped IPv6, etc.)
- Make _BLOCKED_NETWORKS immutable tuple
- Use yaml.safe_dump instead of yaml.dump
- Fix _to_snake_case for empty strings and Python keywords
2026-03-31 00:28:28 +02:00
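A hedged sketch of the _to_snake_case fix for empty strings and Python keywords mentioned above; the regexes and fallback name are assumptions.

```python
import keyword
import re

def _to_snake_case(name: str) -> str:
    s = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", name)   # split camelCase
    s = re.sub(r"[^a-zA-Z0-9]+", "_", s).strip("_").lower()
    if not s:
        return "operation"     # empty input gets a safe default (assumed)
    if keyword.iskeyword(s):
        return f"{s}_"         # e.g. 'import' -> 'import_'
    return s
```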
Yaojia Wang
a54eb224e0 feat: complete phase 3 -- OpenAPI auto-discovery, SSRF protection, tool generation
- SSRF protection: private IP blocking, DNS rebinding defense, redirect validation
- OpenAPI fetcher with SSRF guard, JSON/YAML auto-detection, 10MB limit
- Structural spec validator (3.0.x/3.1.x)
- Endpoint parser with $ref resolution, auto-generated operation IDs
- Heuristic + LLM endpoint classifier with Protocol interface
- Review API at /api/openapi (import, job status, classification CRUD, approve)
- @tool code generator + Agent YAML generator
- Import orchestrator (fetch -> validate -> parse -> classify pipeline)
- 125 new tests, 322 total passing, 93.23% coverage
2026-03-31 00:10:44 +02:00
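A hedged sketch of the private-IP portion of the SSRF guard described above; the real module also defends against DNS rebinding and validates redirects, which this stub does not show.

```python
import ipaddress
import socket
from urllib.parse import urlparse

def assert_url_safe(url: str) -> None:
    """Reject non-HTTP schemes and URLs resolving to internal addresses."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"Blocked scheme: {parsed.scheme!r}")
    host = parsed.hostname or ""
    for info in socket.getaddrinfo(host, None):
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            raise ValueError(f"Blocked address: {ip} (private/internal)")
```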
Yaojia Wang
006b4ee5d7 fix: resolve ruff lint errors in Phase 2 code
- Move intent imports to TYPE_CHECKING block in graph.py (TC001)
- Rename test classes to CapWords convention (N801)
- Fix line length violations across test files (E501)
- Auto-fix import sorting (I001)
2026-03-30 21:44:47 +02:00
Yaojia Wang
b861ff055f test: add routing integration tests for Phase 2 test requirements
9 tests covering the complete multi-agent routing flow:
- Single-intent routing to each agent (order_lookup, order_actions, discount, fallback)
- Multi-intent routing hint injection for sequential execution
- Ambiguity detection skips graph and returns clarification
- Low confidence threshold triggers ambiguity
- No-classifier fallback to supervisor prompt routing

Fills Phase 2 test requirement for integration-level routing coverage.
Total: 197 tests, 92.60% coverage.
2026-03-30 21:41:01 +02:00
Yaojia Wang
512f988dd0 test: add Phase 2 checkpoint acceptance tests
18 integration tests validating all 7 Phase 2 checkpoint criteria:
1. Order query routes to order_lookup agent
2. Multi-intent classification with routing hint injection
3. Ambiguous message triggers clarification prompt
4. 30-min interrupt TTL auto-cancel with retry prompt
5. Webhook POST escalation with retry on failure
6. E-commerce template loads 4 correctly configured agents
7. Coverage at 92.60% (188 tests total)
2026-03-30 21:38:25 +02:00
Yaojia Wang
6e7b824b64 test: add integration tests for WebSocket message flow
17 integration tests covering:
- Happy path: token streaming, tool calls, multi-message sessions
- Interrupt flow: approve and reject paths with manager tracking
- Session TTL: expiration, sliding window reset, interrupt extension
- Validation: invalid JSON, missing fields, size limits
- Interrupt TTL: expired interrupt sends retry prompt

Fills Phase 1 test gap for integration-level WebSocket coverage.
Total: 170 tests, 92.15% coverage.
2026-03-30 21:24:31 +02:00
Yaojia Wang
1050df780d feat: complete phase 2 -- multi-agent routing, interrupt TTL, escalation, templates
- Intent classification with LLM structured output (single/multi/ambiguous)
- Discount agent with apply_discount and generate_coupon tools
- Interrupt manager with 30-min TTL auto-expiration and retry prompts
- Webhook escalation module with exponential backoff retry (max 3)
- Three vertical industry templates (e-commerce, SaaS, fintech)
- Template loading in AgentRegistry
- Enhanced supervisor prompt with dynamic agent descriptions
- 153 tests passing, 90.18% coverage
2026-03-30 21:04:39 +02:00
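A hedged sketch of the 30-minute interrupt TTL introduced above; the manager's real interface is not shown in this compare, so the dataclass shape is an assumption.

```python
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta

TTL = timedelta(minutes=30)

@dataclass
class PendingInterrupt:
    interrupt_id: str
    thread_id: str
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))

    def is_expired(self, now: datetime | None = None) -> bool:
        """Expired interrupts are auto-cancelled and a retry prompt is sent."""
        return ((now or datetime.now(UTC)) - self.created_at) > TTL
```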
Yaojia Wang
7c3571b47d chore: update local Claude Code permission settings
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 15:14:16 +02:00
165 changed files with 19578 additions and 605 deletions

.claude/settings.local.json

@@ -2,7 +2,22 @@
   "permissions": {
     "allow": [
       "Bash(find:*)",
       "Bash(ruff:*)",
       "Bash(pytest:*)",
+      "Bash(git status:*)",
+      "Bash(git diff:*)",
+      "Bash(git log:*)",
+      "Bash(git branch:*)",
+      "Bash(git add:*)",
+      "Bash(git commit:*)",
+      "Bash(git checkout:*)",
+      "Bash(git merge:*)",
+      "Bash(git tag:*)",
+      "Bash(git show:*)",
+      "Bash(docker:*)",
+      "Bash(docker-compose:*)",
+      "WebSearch"
-    ]
+    ],
+    "defaultMode": "bypassPermissions"
   }
 }

.env.example (new file, 35 lines)

@@ -0,0 +1,35 @@
# Smart Support -- Docker Compose environment variables
# Copy to .env and fill in your values
# PostgreSQL password (used by both postgres and backend services)
POSTGRES_PASSWORD=dev_password
# LLM provider: anthropic | openai | azure_openai | google
LLM_PROVIDER=anthropic
LLM_MODEL=claude-sonnet-4-6
# API keys (provide the one matching LLM_PROVIDER)
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
GOOGLE_API_KEY=
# Azure OpenAI (required when LLM_PROVIDER=azure_openai)
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=
AZURE_OPENAI_DEPLOYMENT=
AZURE_OPENAI_API_VERSION=2024-12-01-preview
# Optional: webhook URL for escalation notifications
WEBHOOK_URL=
# Session and interrupt TTL in minutes
SESSION_TTL_MINUTES=30
INTERRUPT_TTL_MINUTES=30
# Optional: API key for admin endpoints (analytics, replay, openapi, websocket)
# Leave empty to disable authentication (dev mode)
ADMIN_API_KEY=
# Optional: load a named agent template instead of agents.yaml
# Available templates: ecommerce, saas, generic
TEMPLATE_NAME=

CLAUDE.md

@@ -30,7 +30,7 @@ pytest --cov=app --cov-report=term-missing
# - If any test fails, fix it before starting the new phase
# 3. Create checkpoint to snapshot the starting state
-/everything-claude-code:checkpoint create [phase name]
+/ecc:checkpoint create "phase-name"
# 4. Create the phase branch
git checkout main
@@ -50,25 +50,32 @@ git checkout -b phase-{N}/{short-description}
3. Identify all tasks, acceptance criteria, and dependencies for this phase
4. Create a phase dev log **skeleton** at `docs/phases/phase-{N}-dev-log.md` (date, branch name, plan link only -- content filled in Step 5)
-### Step 2: Develop Using Orchestrate Skill
+### Step 2: Develop Using ECC Skills
-Route to the correct orchestration mode based on work type:
+Route to the correct skill based on work type:
-| Work Type | Skill Command |
-|-----------|---------------|
-| New feature | `/everything-claude-code:orchestrate feature` |
-| Bug fix | `/everything-claude-code:orchestrate bugfix` |
-| Refactor | `/everything-claude-code:orchestrate refactor` |
+| Work Type | Skill Command | What It Does |
+|-----------|---------------|--------------|
+| New feature | `/ecc:feature-dev <desc>` | Discovery -> Exploration -> Architecture -> TDD -> Review -> Summary |
+| Bug fix | `/ecc:tdd` then `/ecc:code-review` | RED -> GREEN -> REFACTOR cycle, then review |
+| Refactor | `/ecc:plan` then `/ecc:tdd` then `/ecc:code-review` | Plan refactor scope, TDD, review |
+| Security-sensitive | Add `/ecc:security-review` after code-review | Auth, payments, user input, external APIs |
+| Final verification | `/ecc:verify` | Build + tests + lint + coverage + security scan |
-ALWAYS use the appropriate orchestrate skill. Never develop without it.
-A single phase may contain mixed work types (e.g., Phase 5 has feature + bugfix + refactor). Call the orchestrate skill **per sub-task** with the matching mode. Example:
+A single phase may contain mixed work types. Call the appropriate skill **per sub-task**:
```
-# Within Phase 5:
-/everything-claude-code:orchestrate feature   # for demo script
-/everything-claude-code:orchestrate bugfix    # for error handling fixes
-/everything-claude-code:orchestrate refactor  # for code cleanup
+# Within a phase:
+/ecc:feature-dev "demo script"         # for new features
+/ecc:tdd                               # for bug fixes (write failing test, then fix)
+/ecc:plan "consolidate error handling" # for refactors (plan first, then TDD)
```
+For full multi-phase autonomous execution, use GSD:
+```
+/gsd:autonomous        # execute all remaining phases
+/gsd:execute-phase 6   # execute a specific phase
+```
### Step 3: Module Independence (CRITICAL)
@@ -171,10 +178,10 @@ After all development and testing, run verification in this exact order:
```
# 1. Run the verification skill -- must pass
-/everything-claude-code:verify
+/ecc:verify
# 2. Verify the checkpoint -- validates all phase deliverables
-/everything-claude-code:checkpoint verify [phase name]
+/ecc:checkpoint verify "phase-name"
```
The checkpoint verify validates:
@@ -222,11 +229,11 @@ git push origin main --tags
All four markers must be consistent. If any is missed, the next phase's Step 0 regression gate will catch the discrepancy.
A checkpoint includes:
-- `/everything-claude-code:checkpoint create` at phase start
-- `/everything-claude-code:checkpoint verify` at phase end
+- `/ecc:checkpoint create` at phase start
+- `/ecc:checkpoint verify` at phase end
 - All tests passing (80%+ coverage)
 - Phase dev log written and linked
-- `/everything-claude-code:verify` passed
+- `/ecc:verify` passed
 - Git tag `checkpoint/phase-{N}` created
 - Phase marked COMPLETED in four locations
 - Branch merged to main
@@ -238,10 +245,10 @@ A checkpoint includes:
| Phase | Branch | Focus | Status |
|-------|--------|-------|--------|
| 1 | `phase-1/core-framework` | FastAPI + LangGraph + React chat loop + PostgresSaver | COMPLETED (2026-03-30) |
-| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | NOT STARTED |
-| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | NOT STARTED |
-| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | NOT STARTED |
-| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | NOT STARTED |
+| 2 | `phase-2/multi-agent-safety` | Supervisor routing + interrupts + templates | COMPLETED (2026-03-30) |
+| 3 | `phase-3/openapi-discovery` | OpenAPI parsing + MCP generation + SSRF protection | COMPLETED (2026-03-30) |
+| 4 | `phase-4/analytics-replay` | Replay API + analytics dashboard | COMPLETED (2026-03-31) |
+| 5 | `phase-5/polish-demo` | Error hardening + demo prep + Docker deploy | COMPLETED (2026-03-31) |
Status values: `NOT STARTED` -> `IN PROGRESS` -> `COMPLETED (YYYY-MM-DD)`
@@ -264,7 +271,7 @@ This project inherits from `~/.claude/rules/`. CLAUDE.md only contains project-s
### Hooks (ECC Plugin -- No Custom Hooks)
-All hooks come from the ECC plugin (`everything-claude-code`). No project-level hooks in `.claude/settings.local.json`.
+All hooks come from the ECC plugin (`ecc`). No project-level hooks in `.claude/settings.local.json`.
| ECC Hook | Type | What It Does |
|----------|------|-------------|
@@ -290,7 +297,7 @@ Controlled by `ECC_HOOK_PROFILE` env var in `~/.claude/settings.json` (currently
- Architecture doc: `docs/ARCHITECTURE.md`
- Phase dev logs: `docs/phases/phase-{N}-dev-log.md`
- Test command: `pytest --cov=app --cov-report=term-missing`
-- **Phase start:** `/everything-claude-code:checkpoint create [phase name]`
-- **Phase end:** `/everything-claude-code:checkpoint verify [phase name]`
-- Verify command: `/everything-claude-code:verify`
-- Orchestrate: `/everything-claude-code:orchestrate {feature|bugfix|refactor}`
+- **Phase start:** `/ecc:checkpoint create "phase-name"`
+- **Phase end:** `/ecc:checkpoint verify "phase-name"`
+- Verify command: `/ecc:verify`
+- Orchestrate: `/ecc:orchestrate {feature|bugfix|refactor}`

README.md (267 lines)

@@ -1,159 +1,174 @@
# Smart Support
-AI 客服行动层框架。粘贴你的 API,获得一个能执行真实操作的智能客服。
+AI customer support action layer. Paste your API spec, get an AI agent that executes real actions.
-## 问题
+## The Problem
-现有客服工具(Zendesk、Intercom、Ada)擅长回答 FAQ,但自动化率卡在 20-30%。剩下 70% 的工单需要人工登录内部系统,手动查订单、取消订单、发优惠券。
+Existing support tools (Zendesk, Intercom, Ada) answer FAQs well but automation
+rates stall at 20-30%. The remaining 70% of tickets require agents to manually
+log into internal systems to look up orders, cancel orders, issue coupons.
-Smart Support 是补全这个缺口的「行动层」。它不替代现有客服平台,而是让 AI 能直接调用内部系统完成操作。
+Smart Support fills that gap as the "action layer" -- it does not replace your
+existing support platform, it enables AI to directly call your internal systems.
-## 工作原理
+## How It Works
```
-客户消息 → Chat UI → FastAPI WebSocket → LangGraph Supervisor → 专业 Agent → MCP Tools → 你的内部系统
-                 Agent 注册表              interrupt()
-                 (YAML 配置)               (人工确认)
+User message -> Chat UI -> FastAPI WebSocket -> LangGraph Supervisor -> Specialist Agent -> MCP Tools -> Your systems
+                               |                        |
+                        Agent Registry             interrupt()
+                        (YAML config)            (human approval)
                                |
                         PostgresSaver
-                       (会话状态持久化)
+                       (session persistence)
```
-1. 客户在聊天界面发送消息
-2. LangGraph Supervisor 分析意图,路由到对应的专业 Agent
-3. Agent 通过 MCP 协议调用你的内部系统(查订单、取消订单、发折扣……)
-4. 涉及写操作时,自动触发人工确认流程
-5. 所有操作全程记录,支持回放和分析
+1. User sends a message in the chat UI.
+2. LangGraph Supervisor classifies intent and routes to the right agent.
+3. Agent calls your internal systems via MCP tools.
+4. Write operations trigger a human-in-the-loop approval gate.
+5. All operations are logged with full replay and analytics.
-## 核心特性
+## Key Features
-- **多 Agent 协作** - 不同操作由不同 Agent 处理,各自拥有独立的权限边界和工具集
-- **即插即用** - 粘贴 OpenAPI 规范 URL,自动生成 MCP 工具和 Agent 配置
-- **人工确认** - 所有写操作(取消、退款、修改)需要人工审批,读操作直接执行
-- **会话上下文** - 支持多轮对话,Agent 能理解「取消那个订单」这样的指代
-- **实时流式输出** - WebSocket 双向通信,逐 token 流式返回
-- **对话回放** - 逐步查看 Agent 决策过程、工具调用和返回结果
-- **数据分析** - 解决率、Agent 使用率、升级率、每次对话成本
-- **YAML 驱动配置** - Agent 定义、人设、垂直模板全部通过 YAML 配置
+- **Multi-agent routing** -- each operation goes to a specialist agent with its own tools and permissions
+- **Zero-config import** -- paste an OpenAPI 3.0 URL, agents are generated automatically
+- **Human-in-the-loop** -- all write operations (cancel, refund, modify) require approval; reads execute immediately
+- **Session context** -- multi-turn conversation with persistent state across reconnects
+- **Real-time streaming** -- WebSocket token streaming with live tool call visibility
+- **Conversation replay** -- step-by-step audit trail of every agent decision
+- **Analytics dashboard** -- resolution rate, agent usage, escalation rate, cost per conversation
+- **YAML-driven config** -- agents, personas, and vertical templates in a single file
-## 技术栈
+## Tech Stack
-| 组件 | 技术选型 |
-|------|---------|
-| 后端 | Python 3.11+, FastAPI |
-| Agent 编排 | LangGraph v1.1, langgraph-supervisor |
-| 工具集成 | langchain-mcp-adapters, @tool |
-| 状态持久化 | PostgreSQL + langgraph-checkpoint-postgres |
-| LLM | Claude Sonnet 4.6(可切换 OpenAI、Google 等) |
-| 前端 | React |
-| 部署 | Docker Compose |
+| Component | Technology |
+|-----------|-----------|
+| Backend | Python 3.11+, FastAPI |
+| Agent orchestration | LangGraph 1.x, langgraph-supervisor |
+| Session state | PostgreSQL 16 + langgraph-checkpoint-postgres |
+| LLM | Claude Sonnet 4.6 (configurable: OpenAI, Azure OpenAI, Google) |
+| Frontend | React 19, TypeScript, Vite |
+| Testing | pytest (backend), vitest + happy-dom (frontend) |
+| Deployment | Docker Compose |
-## 项目结构
+## Quick Start
+```bash
+git clone <repo-url>
+cd smart-support
+# Configure your LLM API key
+cp .env.example .env
+# Edit .env: set LLM_PROVIDER and the corresponding API key
+#   anthropic    -> ANTHROPIC_API_KEY
+#   openai       -> OPENAI_API_KEY
+#   azure_openai -> AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_DEPLOYMENT
+#   google       -> GOOGLE_API_KEY
+# Start all services
+docker compose up -d
+# Open the app
+open http://localhost
+```
+### Local Development
+```bash
+# Start only PostgreSQL via Docker (exposed on port 5433)
+docker compose up postgres -d
+# Backend (in one terminal)
+cd backend
+pip install -e ".[dev]"
+uvicorn app.main:app --host 0.0.0.0 --port 8001 --reload
+# Frontend (in another terminal)
+cd frontend
+npm install
+npm run dev   # http://localhost:5173 (proxies /api and /ws to :8001)
+```
+See [Deployment Guide](docs/deployment.md) for production setup, HTTPS, and scaling.
+## Project Structure
```
smart-support/
├── backend/
│   ├── app/
-│   │   ├── main.py          # FastAPI + WebSocket 入口
-│   │   ├── graph.py         # LangGraph Supervisor 配置
-│   │   ├── agents/          # Agent 定义 + 工具
-│   │   ├── registry.py      # YAML Agent 注册表加载器
-│   │   ├── openapi/         # OpenAPI 解析 + MCP 服务器生成
-│   │   ├── replay/          # 对话回放 API
-│   │   ├── analytics/       # 数据分析查询 + API
-│   │   └── callbacks.py     # Token 用量统计
-│   ├── agents.yaml          # Agent 注册表配置
-│   ├── templates/           # 垂直行业模板
-│   └── tests/
-├── frontend/                # React 聊天 UI + 回放 + 仪表盘
-├── docker-compose.yml       # PostgreSQL + 应用
-└── pyproject.toml
+│   │   ├── main.py          # FastAPI + WebSocket entry point
+│   │   ├── graph.py         # LangGraph Supervisor construction
+│   │   ├── graph_context.py # Typed wrapper for graph + classifier + registry
+│   │   ├── ws_handler.py    # WebSocket message dispatch + rate limiting
+│   │   ├── ws_context.py    # WebSocket dependency bundle
+│   │   ├── auth.py          # API key authentication middleware
+│   │   ├── api_utils.py     # Shared API response helpers
+│   │   ├── safety.py        # Confirmation rules + MCP error taxonomy
+│   │   ├── agents/          # Agent definitions and tools
+│   │   ├── registry.py      # YAML agent registry loader
+│   │   ├── openapi/         # OpenAPI parser, classifier, and review API
+│   │   ├── replay/          # Conversation replay API
+│   │   └── analytics/       # Analytics queries and API
+│   ├── agents.yaml          # Agent registry configuration
+│   ├── templates/           # Vertical industry templates
+│   └── tests/               # Unit, integration, and E2E tests
+├── frontend/
+│   ├── src/
+│   │   ├── pages/           # Chat, Replay, Dashboard, Review pages
+│   │   ├── components/      # NavBar, Layout, MetricCard, ReplayTimeline
+│   │   ├── hooks/           # useWebSocket with reconnect support
+│   │   └── api.ts           # Typed API client
+│   └── Dockerfile           # Multi-stage nginx build
+├── docs/                    # Architecture, deployment, guides
+├── docker-compose.yml       # Full-stack compose
+└── .env.example             # Environment variable template
```
-## 快速开始
+## API Endpoints
+| Method | Path | Auth | Description |
+|--------|------|------|-------------|
+| WS | `/ws` | Token | Main WebSocket chat endpoint (`?token=<key>`) |
+| GET | `/api/health` | No | Health check |
+| GET | `/api/conversations` | API Key | List conversations (paginated) |
+| GET | `/api/replay/{thread_id}` | API Key | Replay conversation steps (paginated) |
+| GET | `/api/analytics` | API Key | Analytics summary (`?range=7d`) |
+| POST | `/api/openapi/import` | API Key | Start OpenAPI import job |
+| GET | `/api/openapi/jobs/{id}` | API Key | Check import job status |
+| GET | `/api/openapi/jobs/{id}/classifications` | API Key | Get endpoint classifications |
+| PUT | `/api/openapi/jobs/{id}/classifications/{idx}` | API Key | Update a classification |
+| POST | `/api/openapi/jobs/{id}/approve` | API Key | Approve and generate tools |
+Authentication is controlled by the `ADMIN_API_KEY` environment variable.
+API Key endpoints require the `X-API-Key` header. When `ADMIN_API_KEY` is unset, auth is disabled.
+## Running Tests
```bash
-# 启动 PostgreSQL 和应用
-docker compose up
+# Backend (516 tests, 94% coverage)
+cd backend
+pytest --cov=app --cov-report=term-missing
-# 访问聊天界面
-open http://localhost:8000
+# Frontend (23 tests, vitest + happy-dom)
+cd frontend
+npm test
```
-## Agent 配置示例
+Backend coverage is enforced at 80%+.
-```yaml
-# agents.yaml
-agents:
-  - name: order_lookup
-    description: 查询订单状态、物流信息
-    permission: read
-    personality:
-      tone: professional
-      greeting: "您好,我来帮您查询订单信息。"
-    tools:
-      - get_order_status
-      - get_tracking_info
+## Documentation
-  - name: order_actions
-    description: 取消订单、修改订单
-    permission: write  # 触发人工确认
-    personality:
-      tone: careful
-      greeting: "我可以帮您处理订单变更,所有操作都会先经过您的确认。"
-    tools:
-      - cancel_order
-      - modify_order
-  - name: discount
-    description: 发放优惠券、折扣码
-    permission: write
-    tools:
-      - apply_discount
-      - generate_coupon
-```
-## OpenAPI 自动接入
-不需要手动写 MCP 连接器。粘贴你的 API 规范 URL:
-1. 框架解析 OpenAPI 3.0 规范
-2. LLM 自动分类每个端点(读/写、客户参数、Agent 分组)
-3. 运维人员审核分类结果
-4. 自动生成 MCP 服务器 + Agent YAML 配置
-5. 新工具立即可用
-## 安全设计
-- **人工确认** - 所有写操作需要客户或运维人员批准
-- **SSRF 防护** - OpenAPI URL 导入时屏蔽内网地址和 DNS 重绑定攻击
-- **操作审计** - 每个操作记录 Agent、参数、结果、时间戳
-- **权限隔离** - 每个 Agent 只能访问其配置的工具集
-- **中断超时** - 30 分钟未确认的操作自动取消,防止过期审批
-## 开发阶段
-| 阶段 | 周期 | 内容 |
-|------|------|------|
-| Phase 1 | 第 1-3 周 | 核心框架:Chat UI + Supervisor + Agent 注册表 + 中断流程 |
-| Phase 2 | 第 3-4 周 | 多 Agent 路由 + Webhook 升级 + 垂直模板 |
-| Phase 3 | 第 4-6 周 | OpenAPI 自动发现 + MCP 服务器生成 + SSRF 防护 |
-| Phase 4 | 第 6-7 周 | 对话回放 + 数据分析仪表盘 |
-## 目标用户
-中型电商公司(日均 500-5000 订单,5-20 名客服)的客户体验负责人。
-他们的痛点:客服需要在 Zendesk 和 Shopify 后台之间反复切换,手动执行查询和操作。Smart Support 让 AI 直接完成这些操作,人工只需审批关键步骤。
-## 相关文档
-- [设计文档](design-doc.md) - 问题定义、约束、方案选择
-- [CEO 计划](ceo-plan.md) - 产品愿景、范围决策
-- [工程评审计划](eng-review-plan.md) - 架构决策、测试策略、失败模式
-- [测试计划](eng-review-test-plan.md) - 测试路径、边界情况、E2E 流程
-- [待办事项](TODOS.md) - 延迟到后续阶段的工作
+| Document | Description |
+|----------|-------------|
+| [Architecture](docs/ARCHITECTURE.md) | System design, component diagram, data flow, ADRs |
+| [Development Plan](docs/DEVELOPMENT-PLAN.md) | Phase breakdown, task checklists, and status |
+| [Agent Config Guide](docs/agent-config-guide.md) | agents.yaml format, fields, templates, routing logic |
+| [OpenAPI Import Guide](docs/openapi-import-guide.md) | Auto-discovery workflow, REST API, SSRF protection |
+| [Deployment Guide](docs/deployment.md) | Docker, local dev, production, HTTPS, backups, scaling |
+| [Demo Script](docs/demo-script.md) | Step-by-step live demo walkthrough (5 scenes) |
+| [UX Design System](docs/ux_design_system.md) | Color palette, typography, component patterns, CSS tokens |
+## License

backend/.env.example

@@ -1,19 +1,34 @@
-# Database
+# Smart Support Backend -- environment variables
+# Copy to .env and fill in your values
+# Required: PostgreSQL connection string
DATABASE_URL=postgresql://smart_support:dev_password@localhost:5432/smart_support
-# LLM Provider: anthropic | openai | google
+# Required: LLM provider configuration
+# provider: anthropic | openai | google
LLM_PROVIDER=anthropic
LLM_MODEL=claude-sonnet-4-6
-# API Keys (set the one matching your LLM_PROVIDER)
+# API keys -- provide the one matching LLM_PROVIDER
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
GOOGLE_API_KEY=
-# Session
+# Optional: webhook endpoint for escalation notifications
+# The backend will POST a JSON payload when a conversation is escalated.
+WEBHOOK_URL=
+WEBHOOK_TIMEOUT_SECONDS=10
+WEBHOOK_MAX_RETRIES=3
+# Session management
SESSION_TTL_MINUTES=30
INTERRUPT_TTL_MINUTES=30
-# Server
+# Optional: load a named agent template instead of agents.yaml
+# Leave blank to use the default agents.yaml in the backend directory.
+# Available templates: ecommerce, saas, generic
+TEMPLATE_NAME=
+# Server binding
WS_HOST=0.0.0.0
WS_PORT=8000

backend/agents.yaml

@@ -20,6 +20,17 @@ agents:
    tools:
      - cancel_order
+  - name: discount
+    description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
+    permission: write
+    personality:
+      tone: "generous and accommodating"
+      greeting: "I can help you with discounts and coupons!"
+      escalation_message: "Let me connect you with our promotions team."
+    tools:
+      - apply_discount
+      - generate_coupon
  - name: fallback
    description: "Handles general questions, unclear requests, and conversations that don't match other agents."
    permission: read

backend/alembic.ini (new file, 149 lines)

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url =
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

backend/alembic/README (new file, 1 line)

@@ -0,0 +1 @@
Generic single-database configuration.

backend/alembic/env.py (new file, 67 lines)

@@ -0,0 +1,67 @@
"""Alembic environment configuration for smart-support."""
from __future__ import annotations
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# No SQLAlchemy ORM models -- we use raw DDL migrations
target_metadata = None
def _get_url() -> str:
"""Read DATABASE_URL from environment, falling back to alembic.ini."""
return os.environ.get("DATABASE_URL", "") or config.get_main_option(
"sqlalchemy.url", ""
)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
Configures the context with just a URL so that an Engine
is not required.
"""
url = _get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode with a live database connection."""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = _get_url()
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
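A usage sketch under stated assumptions: env.py reads DATABASE_URL from the environment, so migrations can also be applied programmatically instead of via the `alembic upgrade head` CLI. The ini path here is an assumption.

```python
# Hedged sketch: apply all pending migrations from Python.
from alembic import command
from alembic.config import Config

cfg = Config("backend/alembic.ini")  # path assumption; DATABASE_URL must be set
command.upgrade(cfg, "head")
```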

backend/alembic/script.py.mako (new file, 28 lines)

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

backend/alembic/versions -- initial schema migration (new file, 92 lines)

@@ -0,0 +1,92 @@
"""Initial schema -- all application tables.
Revision ID: a1b2c3d4e5f6
Revises:
Create Date: 2026-04-06
"""
from __future__ import annotations
from alembic import op
revision: str = "a1b2c3d4e5f6"
down_revision: str | None = None
branch_labels: tuple[str, ...] | None = None
depends_on: tuple[str, ...] | None = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE IF NOT EXISTS conversations (
thread_id TEXT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
total_tokens INTEGER NOT NULL DEFAULT 0,
total_cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
status TEXT NOT NULL DEFAULT 'active'
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS active_interrupts (
interrupt_id TEXT PRIMARY KEY,
thread_id TEXT NOT NULL REFERENCES conversations(thread_id),
action TEXT NOT NULL,
params JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
resolved_at TIMESTAMPTZ,
resolution TEXT
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS sessions (
thread_id TEXT PRIMARY KEY,
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
op.execute(
"""
CREATE TABLE IF NOT EXISTS analytics_events (
id BIGSERIAL PRIMARY KEY,
thread_id TEXT NOT NULL,
event_type TEXT NOT NULL,
agent_name TEXT,
tool_name TEXT,
tokens_used INTEGER NOT NULL DEFAULT 0,
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
duration_ms INTEGER,
success BOOLEAN,
error_message TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""
)
# Migration columns added in Phase 4
op.execute(
"""
ALTER TABLE conversations
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS analytics_events")
op.execute("DROP TABLE IF EXISTS sessions")
op.execute("DROP TABLE IF EXISTS active_interrupts")
op.execute("DROP TABLE IF EXISTS conversations")


@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.tools import BaseTool
+from app.agents.discount import apply_discount, generate_coupon
from app.agents.fallback import fallback_respond
from app.agents.order_actions import cancel_order
from app.agents.order_lookup import get_order_status, get_tracking_info
@@ -16,6 +17,8 @@ _TOOL_MAP: dict[str, BaseTool] = {
"get_tracking_info": get_tracking_info,
"cancel_order": cancel_order,
"fallback_respond": fallback_respond,
"apply_discount": apply_discount,
"generate_coupon": generate_coupon,
}

backend/app/agents/discount.py (new file, 79 lines)

@@ -0,0 +1,79 @@
"""Discount agent tools -- apply discounts and generate coupons."""
from __future__ import annotations
import uuid
from langchain_core.tools import tool
from langgraph.types import interrupt
@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
"""Apply a discount to an order. Requires human approval before execution."""
if discount_percent < 1 or discount_percent > 100:
return {
"status": "error",
"order_id": order_id,
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
}
response = interrupt(
{
"action": "apply_discount",
"order_id": order_id,
"discount_percent": discount_percent,
"message": (
f"Please confirm: apply {discount_percent}% discount to order {order_id}?"
),
}
)
if isinstance(response, bool):
approved = response
elif isinstance(response, dict):
approved = response.get("approved", False)
else:
approved = bool(response)
if approved:
return {
"status": "applied",
"order_id": order_id,
"discount_percent": discount_percent,
"message": (
f"{discount_percent}% discount applied to order {order_id}."
),
}
return {
"status": "declined",
"order_id": order_id,
"message": f"Discount for order {order_id} was declined.",
}
@tool
def generate_coupon(discount_percent: int, expiry_days: int = 30) -> dict:
"""Generate a coupon code with the specified discount percentage."""
if discount_percent < 1 or discount_percent > 100:
return {
"status": "error",
"message": f"Invalid discount: {discount_percent}%. Must be between 1 and 100.",
}
if expiry_days < 1:
return {
"status": "error",
"message": f"Invalid expiry: {expiry_days} days. Must be at least 1.",
}
coupon_code = f"SAVE{discount_percent}-{uuid.uuid4().hex[:8].upper()}"
return {
"status": "generated",
"coupon_code": coupon_code,
"discount_percent": discount_percent,
"expiry_days": expiry_days,
"message": (
f"Coupon {coupon_code} generated: {discount_percent}% off, "
f"valid for {expiry_days} days."
),
}
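A hedged usage sketch of how the interrupt() pause inside apply_discount is resumed. Command(resume=...) is the LangGraph resumption mechanism; the thread id and invocation shape here are illustrative, not taken from this diff.

```python
from langgraph.types import Command

config = {"configurable": {"thread_id": "demo-thread"}}  # illustrative
# First invocation pauses inside apply_discount at interrupt():
#   graph.invoke({"messages": [...]}, config)
# A later human decision resumes the paused tool call:
#   graph.invoke(Command(resume={"approved": True}), config)
```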

backend/app/agents/fallback.py

@@ -1,4 +1,4 @@
"""Fallback agent tools -- handles unmatched intents."""
"""Fallback agent tools -- handles unmatched intents and clarification requests."""
from __future__ import annotations
@@ -13,6 +13,7 @@ def fallback_respond(query: str) -> str:
"Here's what I can do:\n"
"- Check order status (e.g., 'What is the status of order 1042?')\n"
"- Get tracking information (e.g., 'Track order 1042')\n"
"- Cancel an order (e.g., 'Cancel order 1042')\n\n"
"- Cancel an order (e.g., 'Cancel order 1042')\n"
"- Apply discounts or generate coupons\n\n"
"Could you please rephrase your request?"
)

backend/app/analytics/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
"""Analytics module -- event recording and dashboard queries."""
from __future__ import annotations

backend/app/analytics/api.py (new file, 60 lines)

@@ -0,0 +1,60 @@
"""Analytics API router -- dashboard metrics endpoint."""
from __future__ import annotations
import re
from dataclasses import asdict
from typing import TYPE_CHECKING
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from app.analytics.queries import get_analytics
from app.api_utils import envelope
from app.auth import require_admin_api_key
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
router = APIRouter(
prefix="/api/v1/analytics",
tags=["analytics"],
dependencies=[Depends(require_admin_api_key)],
)
_RANGE_PATTERN = re.compile(r"^(\d+)d$")
_DEFAULT_RANGE = "7d"
_MAX_RANGE_DAYS = 365
async def _get_pool(request: Request) -> AsyncConnectionPool:
"""Dependency: extract the shared pool from app state."""
return request.app.state.pool
def _parse_range(range_str: str) -> int:
"""Parse 'Xd' range string to integer days. Raises 400 on invalid format."""
match = _RANGE_PATTERN.match(range_str)
if not match:
raise HTTPException(
status_code=400,
detail="Invalid range format. Expected: '<N>d' e.g. '7d', '30d'.",
)
days = int(match.group(1))
if days < 1 or days > _MAX_RANGE_DAYS:
raise HTTPException(
status_code=400,
detail=f"Range must be between 1 and {_MAX_RANGE_DAYS} days.",
)
return days
@router.get("")
async def analytics(
request: Request,
range: str = Query(default=_DEFAULT_RANGE, alias="range"), # noqa: A002
) -> dict:
"""Return aggregated analytics metrics for the given time range."""
range_days = _parse_range(range)
pool = await _get_pool(request)
result = await get_analytics(pool, range_days=range_days)
return envelope(asdict(result))

backend/app/analytics/recorder.py (new file, 97 lines)

@@ -0,0 +1,97 @@
"""Analytics event recorder -- Protocol and implementations."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from psycopg.types.json import Json
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
_INSERT_SQL = """
INSERT INTO analytics_events
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd,
duration_ms, success, error_message, metadata)
VALUES
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
%(tokens_used)s, %(cost_usd)s, %(duration_ms)s, %(success)s,
%(error_message)s, %(metadata)s)
"""
@runtime_checkable
class AnalyticsRecorder(Protocol):
"""Protocol for recording analytics events."""
async def record(
self,
*,
thread_id: str,
event_type: str,
agent_name: str | None = None,
tool_name: str | None = None,
tokens_used: int = 0,
cost_usd: float = 0.0,
duration_ms: int | None = None,
success: bool | None = None,
error_message: str | None = None,
metadata: dict | None = None,
) -> None: ...
class NoOpAnalyticsRecorder:
"""No-op implementation for testing or when the DB is unavailable."""
async def record(
self,
*,
thread_id: str,
event_type: str,
agent_name: str | None = None,
tool_name: str | None = None,
tokens_used: int = 0,
cost_usd: float = 0.0,
duration_ms: int | None = None,
success: bool | None = None,
error_message: str | None = None,
metadata: dict | None = None,
) -> None:
"""Do nothing."""
class PostgresAnalyticsRecorder:
"""Postgres-backed analytics recorder -- INSERTs into analytics_events."""
def __init__(self, pool: AsyncConnectionPool) -> None:
self._pool = pool
async def record(
self,
*,
thread_id: str,
event_type: str,
agent_name: str | None = None,
tool_name: str | None = None,
tokens_used: int = 0,
cost_usd: float = 0.0,
duration_ms: int | None = None,
success: bool | None = None,
error_message: str | None = None,
metadata: dict | None = None,
) -> None:
"""Insert one analytics event row."""
params: dict[str, Any] = {
"thread_id": thread_id,
"event_type": event_type,
"agent_name": agent_name,
"tool_name": tool_name,
"tokens_used": tokens_used,
"cost_usd": cost_usd,
"duration_ms": duration_ms,
"success": success,
"error_message": error_message,
"metadata": Json(metadata or {}),
}
async with self._pool.connection() as conn:
await conn.execute(_INSERT_SQL, params)
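A usage sketch for the recorder above: because both implementations satisfy the AnalyticsRecorder Protocol, call sites depend only on the interface. The event values here are illustrative.

```python
async def on_tool_finished(recorder: AnalyticsRecorder, thread_id: str) -> None:
    """Record one tool-call event (field values are examples only)."""
    await recorder.record(
        thread_id=thread_id,
        event_type="tool_call",
        agent_name="order_lookup",
        tool_name="get_order_status",
        duration_ms=42,
        success=True,
    )
```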

backend/app/analytics/models.py (new file, 38 lines)

@@ -0,0 +1,38 @@
"""Value objects for analytics dashboard."""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class AgentUsage:
"""Agent usage statistics within a time range."""
agent: str
count: int
percentage: float
@dataclass(frozen=True)
class InterruptStats:
"""Interrupt approval/rejection statistics within a time range."""
total: int = 0
approved: int = 0
rejected: int = 0
expired: int = 0
@dataclass(frozen=True)
class AnalyticsResult:
"""Full analytics result for a given time range."""
range: str
total_conversations: int
resolution_rate: float
escalation_rate: float
avg_turns_per_conversation: float
avg_cost_per_conversation_usd: float
agent_usage: tuple[AgentUsage, ...]
interrupt_stats: InterruptStats

backend/app/analytics/queries.py (new file, 184 lines)

@@ -0,0 +1,184 @@
"""Analytics query functions -- all async, take pool + range_days."""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
_RESOLUTION_RATE_SQL = """
SELECT
CASE WHEN COUNT(*) = 0 THEN 0.0
ELSE COUNT(*) FILTER (WHERE resolution_type = 'resolved')::float / COUNT(*)
END AS rate
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
_ESCALATION_RATE_SQL = """
SELECT
CASE WHEN COUNT(*) = 0 THEN 0.0
ELSE COUNT(*) FILTER (WHERE resolution_type = 'escalated')::float / COUNT(*)
END AS rate
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
_TOTAL_CONVERSATIONS_SQL = """
SELECT COUNT(*) AS total
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
_AVG_TURNS_SQL = """
SELECT COALESCE(AVG(turn_count), 0.0) AS avg_turns
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
_COST_PER_CONVERSATION_SQL = """
SELECT COALESCE(AVG(total_cost_usd), 0.0) AS avg_cost
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
_AGENT_USAGE_SQL = """
SELECT
agent,
COUNT(*) AS count,
ROUND(COUNT(*) * 100.0 / NULLIF(SUM(COUNT(*)) OVER (), 0), 2) AS percentage
FROM (
SELECT UNNEST(agents_used) AS agent
FROM conversations
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
AND agents_used IS NOT NULL
) sub
GROUP BY agent
ORDER BY count DESC
"""
_INTERRUPT_STATS_SQL = """
SELECT
COUNT(*) FILTER (WHERE event_type = 'interrupt') AS total,
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = TRUE) AS approved,
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND success = FALSE
AND error_message IS NULL) AS rejected,
COUNT(*) FILTER (WHERE event_type = 'interrupt' AND error_message = 'expired') AS expired
FROM analytics_events
WHERE created_at >= NOW() - (%(days)s * INTERVAL '1 day')
"""
async def resolution_rate(pool: AsyncConnectionPool, range_days: int) -> float:
"""Return the fraction of resolved conversations in the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_RESOLUTION_RATE_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return 0.0
return float(row.get("rate") or 0.0)
async def escalation_rate(pool: AsyncConnectionPool, range_days: int) -> float:
"""Return the fraction of escalated conversations in the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_ESCALATION_RATE_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return 0.0
return float(row.get("rate") or 0.0)
async def _total_conversations(pool: AsyncConnectionPool, range_days: int) -> int:
"""Return the total number of conversations in the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_TOTAL_CONVERSATIONS_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return 0
return int(row.get("total") or 0)
async def _avg_turns(pool: AsyncConnectionPool, range_days: int) -> float:
"""Return the average turn count per conversation in the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_AVG_TURNS_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return 0.0
return float(row.get("avg_turns") or 0.0)
async def cost_per_conversation(pool: AsyncConnectionPool, range_days: int) -> float:
"""Return the average cost per conversation in the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_COST_PER_CONVERSATION_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return 0.0
return float(row.get("avg_cost") or 0.0)
async def agent_usage(
pool: AsyncConnectionPool, range_days: int
) -> tuple[AgentUsage, ...]:
"""Return per-agent usage statistics for the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_AGENT_USAGE_SQL, {"days": range_days})
rows = await cursor.fetchall()
if not rows:
return ()
return tuple(
AgentUsage(
agent=row["agent"],
count=int(row["count"]),
percentage=float(row["percentage"]),
)
for row in rows
)
async def interrupt_stats(
pool: AsyncConnectionPool, range_days: int
) -> InterruptStats:
"""Return interrupt approval/rejection statistics for the given range."""
async with pool.connection() as conn:
cursor = await conn.execute(_INTERRUPT_STATS_SQL, {"days": range_days})
row = await cursor.fetchone()
if not row:
return InterruptStats()
return InterruptStats(
total=int(row.get("total") or 0),
approved=int(row.get("approved") or 0),
rejected=int(row.get("rejected") or 0),
expired=int(row.get("expired") or 0),
)
async def get_analytics(
pool: AsyncConnectionPool, range_days: int
) -> AnalyticsResult:
"""Aggregate all analytics metrics into a single AnalyticsResult."""
res_rate, esc_rate, cost, usage, i_stats, total, avg_t = await asyncio.gather(
resolution_rate(pool, range_days),
escalation_rate(pool, range_days),
cost_per_conversation(pool, range_days),
agent_usage(pool, range_days),
interrupt_stats(pool, range_days),
_total_conversations(pool, range_days),
_avg_turns(pool, range_days),
)
return AnalyticsResult(
range=f"{range_days}d",
total_conversations=total,
resolution_rate=res_rate,
escalation_rate=esc_rate,
avg_turns_per_conversation=avg_t,
avg_cost_per_conversation_usd=cost,
agent_usage=usage,
interrupt_stats=i_stats,
)

backend/app/api_utils.py (new file, 10 lines)

@@ -0,0 +1,10 @@
"""Shared API response helpers."""
from __future__ import annotations
from typing import Any
def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
"""Wrap API response data in a standard envelope format."""
return {"success": success, "data": data, "error": error}

backend/app/auth.py (new file, 72 lines)

@@ -0,0 +1,72 @@
"""API key authentication for admin endpoints and WebSocket connections."""
from __future__ import annotations
import secrets
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Query, Request, WebSocket, status
from fastapi.security import APIKeyHeader
logger = structlog.get_logger()
_API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
def _get_admin_api_key(request: Request) -> str:
"""Retrieve the configured admin API key from app settings.
Returns empty string if settings are not configured (test/dev mode).
"""
settings = getattr(request.app.state, "settings", None)
if settings is None:
return ""
key = getattr(settings, "admin_api_key", "")
return key if isinstance(key, str) else ""
async def require_admin_api_key(
request: Request,
api_key: Annotated[str | None, Depends(_API_KEY_HEADER)] = None,
) -> None:
"""Dependency that enforces API key authentication on admin endpoints.
Skips validation when no admin_api_key is configured (dev mode).
"""
expected = _get_admin_api_key(request)
if not expected:
return
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing X-API-Key header",
)
if not secrets.compare_digest(api_key, expected):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid API key",
)
async def verify_ws_token(
ws: WebSocket,
token: str | None = Query(default=None),
) -> None:
"""Verify WebSocket connection token from query parameter.
Skips validation when no admin_api_key is configured (dev mode).
Usage: ws://host/ws?token=<api_key>
"""
settings = ws.app.state.settings
expected = settings.admin_api_key
if not expected:
return
if token is None or not secrets.compare_digest(token, expected):
await ws.close(code=4001, reason="Unauthorized")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing WebSocket token",
)
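
A minimal client-side sketch, assuming a server on localhost:8000 and ADMIN_API_KEY set to "secret" (both values hypothetical); the X-API-Key header and the ?token= query parameter are the two mechanisms this module checks:

import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/analytics/summary",  # hypothetical admin route
    headers={"X-API-Key": "secret"},
)

# WebSocket clients pass the same key as a query parameter instead:
# ws://localhost:8000/ws?token=secret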


@@ -17,7 +17,7 @@ class Settings(BaseSettings):
database_url: str
llm_provider: Literal["anthropic", "openai", "google"] = "anthropic"
llm_provider: Literal["anthropic", "openai", "azure_openai", "google"] = "anthropic"
llm_model: str = "claude-sonnet-4-6"
session_ttl_minutes: int = 30
@@ -26,8 +26,22 @@ class Settings(BaseSettings):
ws_host: str = "0.0.0.0"
ws_port: int = 8000
webhook_url: str = ""
webhook_timeout_seconds: int = 10
webhook_max_retries: int = 3
template_name: str = ""
log_format: str = "console" # "console" for dev, "json" for production
admin_api_key: str = ""
anthropic_api_key: str = ""
openai_api_key: str = ""
azure_openai_api_key: str = ""
azure_openai_endpoint: str = ""
azure_openai_api_version: str = "2024-12-01-preview"
azure_openai_deployment: str = ""
google_api_key: str = ""
@model_validator(mode="after")
@@ -35,6 +49,7 @@ class Settings(BaseSettings):
key_map = {
"anthropic": self.anthropic_api_key,
"openai": self.openai_api_key,
"azure_openai": self.azure_openai_api_key,
"google": self.google_api_key,
}
key = key_map.get(self.llm_provider, "")
@@ -43,4 +58,13 @@ class Settings(BaseSettings):
f"API key for provider '{self.llm_provider}' is required. "
f"Set the corresponding environment variable."
)
if self.llm_provider == "azure_openai":
if not self.azure_openai_endpoint:
raise ValueError(
"AZURE_OPENAI_ENDPOINT is required for azure_openai provider."
)
if not self.azure_openai_deployment:
raise ValueError(
"AZURE_OPENAI_DEPLOYMENT is required for azure_openai provider."
)
return self
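
A sketch of the fields the azure_openai provider requires, shown as direct construction (in practice these come from environment variables such as AZURE_OPENAI_API_KEY via pydantic BaseSettings; all values below are hypothetical):

settings = Settings(
    database_url="postgresql://app:app@localhost:5433/app",
    llm_provider="azure_openai",
    azure_openai_api_key="...",
    azure_openai_endpoint="https://my-resource.openai.azure.com",
    azure_openai_deployment="gpt-4o",  # omitting this or the endpoint raises ValueError
)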


@@ -0,0 +1,135 @@
"""Conversation tracker -- Protocol and implementations for tracking conversation state."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
_ENSURE_SQL = """
INSERT INTO conversations
(thread_id, created_at, last_activity)
VALUES
(%(thread_id)s, NOW(), NOW())
ON CONFLICT (thread_id) DO NOTHING
"""
_RECORD_TURN_SQL = """
UPDATE conversations
SET
turn_count = turn_count + 1,
agents_used = CASE
WHEN %(agent_name)s IS NOT NULL AND NOT (agents_used @> ARRAY[%(agent_name)s]::text[])
THEN agents_used || ARRAY[%(agent_name)s]::text[]
ELSE agents_used
END,
total_tokens = total_tokens + %(tokens)s,
total_cost_usd = total_cost_usd + %(cost)s,
last_activity = NOW()
WHERE thread_id = %(thread_id)s
"""
_RESOLVE_SQL = """
UPDATE conversations
SET
resolution_type = %(resolution_type)s,
ended_at = NOW()
WHERE thread_id = %(thread_id)s
"""
@runtime_checkable
class ConversationTrackerProtocol(Protocol):
"""Protocol for tracking conversation lifecycle and metrics."""
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
"""Create conversation row if it does not already exist."""
...
async def record_turn(
self,
pool: AsyncConnectionPool,
thread_id: str,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Increment turn count and update aggregated metrics."""
...
async def resolve(
self,
pool: AsyncConnectionPool,
thread_id: str,
resolution_type: str,
) -> None:
"""Mark conversation as resolved with a resolution type."""
...
class NoOpConversationTracker:
"""No-op implementation -- used in tests or when DB is unavailable."""
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
"""Do nothing."""
async def record_turn(
self,
pool: AsyncConnectionPool,
thread_id: str,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Do nothing."""
async def resolve(
self,
pool: AsyncConnectionPool,
thread_id: str,
resolution_type: str,
) -> None:
"""Do nothing."""
class PostgresConversationTracker:
"""Postgres-backed conversation tracker."""
async def ensure_conversation(self, pool: AsyncConnectionPool, thread_id: str) -> None:
"""Insert conversation row; do nothing if already exists (ON CONFLICT DO NOTHING)."""
params = {"thread_id": thread_id}
async with pool.connection() as conn:
await conn.execute(_ENSURE_SQL, params)
async def record_turn(
self,
pool: AsyncConnectionPool,
thread_id: str,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Increment turn count, append agent if new, update token/cost totals."""
params = {
"thread_id": thread_id,
"agent_name": agent_name,
"tokens": tokens,
"cost": cost,
}
async with pool.connection() as conn:
await conn.execute(_RECORD_TURN_SQL, params)
async def resolve(
self,
pool: AsyncConnectionPool,
thread_id: str,
resolution_type: str,
) -> None:
"""Set resolution_type and ended_at on the conversation row."""
params = {
"thread_id": thread_id,
"resolution_type": resolution_type,
}
async with pool.connection() as conn:
await conn.execute(_RESOLVE_SQL, params)
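
A sketch of the per-turn flow (pool, thread_id, and the token/cost figures are hypothetical; order_lookup is one of the real agents):

tracker = PostgresConversationTracker()
await tracker.ensure_conversation(pool, thread_id)
await tracker.record_turn(pool, thread_id, agent_name="order_lookup", tokens=412, cost=0.0031)
await tracker.resolve(pool, thread_id, resolution_type="resolved")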


@@ -2,6 +2,7 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
@@ -34,6 +35,40 @@ CREATE TABLE IF NOT EXISTS active_interrupts (
);
"""
_ANALYTICS_EVENTS_DDL = """
CREATE TABLE IF NOT EXISTS analytics_events (
id BIGSERIAL PRIMARY KEY,
thread_id TEXT NOT NULL,
event_type TEXT NOT NULL,
agent_name TEXT,
tool_name TEXT,
tokens_used INTEGER NOT NULL DEFAULT 0,
cost_usd DOUBLE PRECISION NOT NULL DEFAULT 0.0,
duration_ms INTEGER,
success BOOLEAN,
error_message TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_SESSIONS_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
thread_id TEXT PRIMARY KEY,
last_activity TIMESTAMPTZ NOT NULL DEFAULT NOW(),
has_pending_interrupt BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_CONVERSATIONS_MIGRATION_DDL = """
ALTER TABLE conversations
ADD COLUMN IF NOT EXISTS resolution_type TEXT,
ADD COLUMN IF NOT EXISTS agents_used TEXT[],
ADD COLUMN IF NOT EXISTS turn_count INTEGER NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS ended_at TIMESTAMPTZ;
"""
async def create_pool(settings: Settings) -> AsyncConnectionPool:
"""Create an async connection pool with the required psycopg settings."""
@@ -54,8 +89,22 @@ async def create_checkpointer(pool: AsyncConnectionPool) -> AsyncPostgresSaver:
return checkpointer
def run_alembic_migrations(database_url: str) -> None:
"""Run Alembic migrations to head."""
from alembic.config import Config
from alembic import command
alembic_cfg = Config(str(Path(__file__).parent.parent / "alembic.ini"))
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
command.upgrade(alembic_cfg, "head")
async def setup_app_tables(pool: AsyncConnectionPool) -> None:
"""Create application-specific tables (conversations, active_interrupts)."""
"""Create application-specific tables and apply migrations."""
async with pool.connection() as conn:
await conn.execute(_CONVERSATIONS_DDL)
await conn.execute(_INTERRUPTS_DDL)
await conn.execute(_SESSIONS_DDL)
await conn.execute(_ANALYTICS_EVENTS_DDL)
await conn.execute(_CONVERSATIONS_MIGRATION_DDL)

backend/app/escalation.py (new file)

@@ -0,0 +1,140 @@
"""Webhook escalation module -- HTTP POST with exponential backoff retry."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Protocol
import httpx
import structlog
from pydantic import BaseModel
logger = structlog.get_logger()
class EscalationPayload(BaseModel, frozen=True):
"""Immutable payload sent to the escalation webhook."""
thread_id: str
reason: str
conversation_summary: str
metadata: dict = {}
@dataclass(frozen=True)
class EscalationResult:
"""Immutable result of an escalation attempt."""
success: bool
status_code: int | None
attempts: int
error: str | None
class EscalationService(Protocol):
"""Protocol for escalation implementations."""
async def escalate(self, payload: EscalationPayload) -> EscalationResult: ...
class WebhookEscalator:
"""Sends escalation requests via HTTP POST with exponential backoff retry."""
def __init__(
self,
url: str,
timeout_seconds: int = 10,
max_retries: int = 3,
) -> None:
self._url = url
self._timeout = timeout_seconds
self._max_retries = max_retries
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
"""POST the escalation payload to the configured webhook URL."""
if not self._url:
return EscalationResult(
success=False,
status_code=None,
attempts=0,
error="Webhook URL not configured",
)
last_error: str | None = None
async with httpx.AsyncClient(timeout=self._timeout) as client:
for attempt in range(1, self._max_retries + 1):
try:
response = await client.post(
self._url,
json=payload.model_dump(),
)
if 200 <= response.status_code < 300:
logger.info(
"Escalation succeeded for thread %s (attempt %d)",
payload.thread_id,
attempt,
)
return EscalationResult(
success=True,
status_code=response.status_code,
attempts=attempt,
error=None,
)
last_error = f"HTTP {response.status_code}"
logger.warning(
"Escalation attempt %d/%d failed: %s",
attempt,
self._max_retries,
last_error,
)
except httpx.TimeoutException:
last_error = "Request timed out"
logger.warning(
"Escalation attempt %d/%d timed out",
attempt,
self._max_retries,
)
except httpx.RequestError as exc:
last_error = str(exc)
logger.warning(
"Escalation attempt %d/%d error: %s",
attempt,
self._max_retries,
last_error,
)
# Exponential backoff: skip delay after last attempt
if attempt < self._max_retries:
delay = 2**attempt
await asyncio.sleep(delay)
logger.error(
"Escalation failed for thread %s after %d attempts: %s",
payload.thread_id,
self._max_retries,
last_error,
)
return EscalationResult(
success=False,
status_code=None,
attempts=self._max_retries,
error=last_error,
)
class NoOpEscalator:
"""Escalator that does nothing -- used when webhook URL is not configured."""
async def escalate(self, payload: EscalationPayload) -> EscalationResult:
logger.info("Escalation disabled (no webhook URL). Thread: %s", payload.thread_id)
return EscalationResult(
success=False,
status_code=None,
attempts=0,
error="Escalation disabled",
)
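
A usage sketch (webhook URL and payload values hypothetical); note that escalate() never raises, it reports failure through EscalationResult:

escalator = WebhookEscalator(
    url="https://hooks.example.com/support",
    timeout_seconds=10,
    max_retries=3,
)
result = await escalator.escalate(
    EscalationPayload(
        thread_id="t-1",
        reason="customer requested a human",
        conversation_summary="Order 42 cancellation dispute",
    )
)
if not result.success:
    logger.warning("Escalation failed after %d attempts: %s", result.attempts, result.error)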


@@ -4,27 +4,46 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from langgraph.prebuilt import create_react_agent
from langchain.agents import create_agent
from langgraph_supervisor import create_supervisor
from app.agents import get_tools_by_names
from app.graph_context import GraphContext
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph.state import CompiledStateGraph
from app.intent import IntentClassifier
from app.registry import AgentRegistry
import structlog
logger = structlog.get_logger()
SUPERVISOR_PROMPT = (
"You are a customer support supervisor. "
"Route customer requests to the appropriate agent based on their description. "
"For order status and tracking queries, use the order_lookup agent. "
"For order modifications like cancellations, use the order_actions agent. "
"For anything else, use the fallback agent."
"Route customer requests to the appropriate agent based on their description.\n\n"
"Available agents and their roles:\n"
"{agent_descriptions}\n\n"
"Routing rules:\n"
"- For order status and tracking queries, use the order_lookup agent.\n"
"- For order modifications like cancellations, use the order_actions agent.\n"
"- For discounts, promotions, or coupon codes, use the discount agent.\n"
"- For anything else or when uncertain, use the fallback agent.\n"
"- If the user's request involves multiple actions, execute them in order.\n"
"- If a previous intent classification is provided, follow it.\n"
)
def _format_agent_descriptions(registry: AgentRegistry) -> str:
"""Build agent description text for the supervisor prompt."""
lines = []
for agent in registry.list_agents():
lines.append(f"- {agent.name}: {agent.description}")
return "\n".join(lines)
def build_agent_nodes(
registry: AgentRegistry,
llm: BaseChatModel,
@@ -41,11 +60,11 @@ def build_agent_nodes(
f"Permission level: {agent_config.permission}."
)
agent_node = create_react_agent(
agent_node = create_agent(
model=llm,
tools=tools,
name=agent_config.name,
prompt=system_prompt,
system_prompt=system_prompt,
)
agent_nodes.append(agent_node)
@@ -56,15 +75,29 @@ def build_graph(
registry: AgentRegistry,
llm: BaseChatModel,
checkpointer: AsyncPostgresSaver,
) -> CompiledStateGraph:
"""Build and compile the LangGraph supervisor graph."""
intent_classifier: IntentClassifier | None = None,
) -> GraphContext:
"""Build and compile the LangGraph supervisor graph.
Returns a GraphContext that bundles the compiled graph with its
associated registry and intent classifier.
"""
agent_nodes = build_agent_nodes(registry, llm)
agent_descriptions = _format_agent_descriptions(registry)
prompt = SUPERVISOR_PROMPT.format(agent_descriptions=agent_descriptions)
workflow = create_supervisor(
agent_nodes,
agents=agent_nodes,
model=llm,
prompt=SUPERVISOR_PROMPT,
prompt=prompt,
output_mode="full_history",
)
return workflow.compile(checkpointer=checkpointer)
compiled = workflow.compile(checkpointer=checkpointer)
return GraphContext(
graph=compiled,
registry=registry,
intent_classifier=intent_classifier,
)


@@ -0,0 +1,36 @@
"""GraphContext -- typed wrapper around the compiled graph and its dependencies."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
from app.intent import ClassificationResult, IntentClassifier
from app.registry import AgentRegistry
@dataclass(frozen=True)
class GraphContext:
"""Bundles the compiled LangGraph graph with its associated services.
Replaces the previous pattern of monkey-patching attributes onto the
third-party CompiledStateGraph instance.
"""
graph: CompiledStateGraph
registry: AgentRegistry
intent_classifier: IntentClassifier | None = None
async def classify_intent(self, message: str) -> ClassificationResult | None:
"""Classify user intent using the attached classifier.
Returns None if no classifier is configured.
"""
if self.intent_classifier is None:
return None
agents = self.registry.list_agents()
return await self.intent_classifier.classify(message, agents)
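
A sketch of how a handler consumes the context instead of a bare graph (thread id hypothetical):

ctx = build_graph(registry, llm, checkpointer, intent_classifier=LLMIntentClassifier(llm))
classification = await ctx.classify_intent("cancel order 42")
result = await ctx.graph.ainvoke(
    {"messages": [{"role": "user", "content": "cancel order 42"}]},
    config={"configurable": {"thread_id": "t-1"}},
)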

backend/app/intent.py (new file)

@@ -0,0 +1,119 @@
"""Intent classification using LLM structured output."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from pydantic import BaseModel
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from app.registry import AgentConfig
import structlog
logger = structlog.get_logger()
CLASSIFICATION_PROMPT = (
"You are an intent classifier for a customer support system.\n"
"Given a user message, determine which agent(s) should handle it.\n\n"
"Available agents:\n{agent_list}\n\n"
"Rules:\n"
"- If the message clearly maps to one agent, return a single intent.\n"
"- If the message contains multiple distinct requests, return multiple intents "
"in execution order.\n"
"- If the message is vague or doesn't match any agent, set is_ambiguous=True "
"and provide a clarification_question.\n"
"- Never route to the fallback agent unless truly ambiguous.\n"
"- confidence should be between 0.0 and 1.0.\n"
)
AMBIGUITY_THRESHOLD = 0.5
class IntentTarget(BaseModel, frozen=True):
"""A single classified intent targeting a specific agent."""
agent_name: str
confidence: float
reasoning: str
class ClassificationResult(BaseModel, frozen=True):
"""Result of intent classification -- may contain multiple intents."""
intents: tuple[IntentTarget, ...]
is_ambiguous: bool = False
clarification_question: str | None = None
class IntentClassifier(Protocol):
"""Protocol for intent classification implementations."""
async def classify(
self,
message: str,
available_agents: tuple[AgentConfig, ...],
) -> ClassificationResult: ...
def _build_agent_list(agents: tuple[AgentConfig, ...]) -> str:
"""Format agent descriptions for the classification prompt."""
lines = []
for agent in agents:
lines.append(f"- {agent.name}: {agent.description} (permission: {agent.permission})")
return "\n".join(lines)
class LLMIntentClassifier:
"""Classifies user intent using LLM structured output."""
def __init__(self, llm: BaseChatModel) -> None:
self._llm = llm
async def classify(
self,
message: str,
available_agents: tuple[AgentConfig, ...],
) -> ClassificationResult:
"""Classify user message into one or more agent intents."""
agent_list = _build_agent_list(available_agents)
system_prompt = CLASSIFICATION_PROMPT.format(agent_list=agent_list)
structured_llm = self._llm.with_structured_output(ClassificationResult)
try:
result = await structured_llm.ainvoke(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": message},
]
)
except Exception:
logger.exception("Intent classification failed, returning ambiguous")
return ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="I'm not sure I understood. Could you please rephrase?",
)
if not isinstance(result, ClassificationResult):
return ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="I'm not sure I understood. Could you please rephrase?",
)
# Apply ambiguity threshold
if result.intents and all(i.confidence < AMBIGUITY_THRESHOLD for i in result.intents):
return ClassificationResult(
intents=result.intents,
is_ambiguous=True,
clarification_question=(
result.clarification_question
or "I'm not sure I understood. Could you please rephrase?"
),
)
return result
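
A usage sketch (the message is hypothetical; registry.list_agents() is the same tuple the supervisor prompt is built from):

classifier = LLMIntentClassifier(llm)
result = await classifier.classify(
    "where is my order, and do you have any coupons?",
    registry.list_agents(),
)
if result.is_ambiguous:
    ask_user(result.clarification_question)  # hypothetical helper
else:
    for intent in result.intents:  # already in execution order
        route_to(intent.agent_name)  # hypothetical helper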


@@ -0,0 +1,268 @@
"""Interrupt TTL management -- tracks pending interrupts with auto-expiration.
Provides both in-memory (InterruptManager) and PostgreSQL-backed
(PgInterruptManager) implementations behind a common Protocol.
"""
from __future__ import annotations
import time
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True)
class InterruptRecord:
"""Immutable record of a pending interrupt."""
interrupt_id: str
thread_id: str
action: str
params: dict
created_at: float
ttl_seconds: int
@dataclass(frozen=True)
class InterruptStatus:
"""Current status of a tracked interrupt."""
is_expired: bool
remaining_seconds: float
record: InterruptRecord
class InterruptManagerProtocol(Protocol):
"""Protocol for interrupt TTL management."""
def register(self, thread_id: str, action: str, params: dict) -> InterruptRecord: ...
def check_status(self, thread_id: str) -> InterruptStatus | None: ...
def resolve(self, thread_id: str) -> None: ...
def has_pending(self, thread_id: str) -> bool: ...
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict: ...
def _build_retry_prompt(expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action."""
return {
"type": "interrupt_expired",
"thread_id": expired_record.thread_id,
"action": expired_record.action,
"message": (
f"The approval request for '{expired_record.action}' has expired "
f"after {expired_record.ttl_seconds // 60} minutes. "
f"Would you like to try again?"
),
}
class InterruptManager:
"""In-memory interrupt manager for single-worker development.
Complements SessionManager -- this tracks interrupt-specific TTL
while SessionManager handles session-level TTL.
"""
def __init__(self, ttl_seconds: int = 1800) -> None:
self._ttl_seconds = ttl_seconds
self._interrupts: dict[str, InterruptRecord] = {}
def register(
self,
thread_id: str,
action: str,
params: dict,
) -> InterruptRecord:
"""Register a new pending interrupt with TTL tracking."""
record = InterruptRecord(
interrupt_id=uuid.uuid4().hex,
thread_id=thread_id,
action=action,
params=dict(params),
created_at=time.time(),
ttl_seconds=self._ttl_seconds,
)
self._interrupts = {**self._interrupts, thread_id: record}
return record
def check_status(self, thread_id: str) -> InterruptStatus | None:
"""Check the TTL status of a pending interrupt."""
record = self._interrupts.get(thread_id)
if record is None:
return None
elapsed = time.time() - record.created_at
remaining = max(0.0, record.ttl_seconds - elapsed)
is_expired = elapsed > record.ttl_seconds
return InterruptStatus(
is_expired=is_expired,
remaining_seconds=remaining,
record=record,
)
def resolve(self, thread_id: str) -> None:
"""Remove a resolved interrupt from tracking."""
self._interrupts = {
k: v for k, v in self._interrupts.items() if k != thread_id
}
def cleanup_expired(self) -> tuple[InterruptRecord, ...]:
"""Find and remove all expired interrupts. Returns the expired records."""
now = time.time()
expired: list[InterruptRecord] = []
active: dict[str, InterruptRecord] = {}
for thread_id, record in self._interrupts.items():
if now - record.created_at > record.ttl_seconds:
expired.append(record)
else:
active[thread_id] = record
self._interrupts = active
return tuple(expired)
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
"""Generate a WebSocket message prompting the user to retry an expired action."""
return _build_retry_prompt(expired_record)
def has_pending(self, thread_id: str) -> bool:
"""Check if a thread has a pending (non-expired) interrupt."""
status = self.check_status(thread_id)
if status is None:
return False
return not status.is_expired
# Alias for explicit naming
InMemoryInterruptManager = InterruptManager
class PgInterruptManager:
"""PostgreSQL-backed interrupt manager for multi-worker production.
Uses the existing active_interrupts table defined in db.py.
"""
def __init__(
self,
pool: AsyncConnectionPool,
ttl_seconds: int = 1800,
) -> None:
self._pool = pool
self._ttl_seconds = ttl_seconds
def register(
self,
thread_id: str,
action: str,
params: dict,
) -> InterruptRecord:
        # Sync facade over the async implementation, kept to satisfy the
        # shared Protocol. run_until_complete() cannot be called from an
        # already running event loop, so async call sites should await the
        # underscore-prefixed coroutines (e.g. _register) directly.
        import asyncio
        return asyncio.get_event_loop().run_until_complete(
            self._register(thread_id, action, params)
        )
async def _register(
self, thread_id: str, action: str, params: dict
) -> InterruptRecord:
import json
record = InterruptRecord(
interrupt_id=uuid.uuid4().hex,
thread_id=thread_id,
action=action,
params=dict(params),
created_at=time.time(),
ttl_seconds=self._ttl_seconds,
)
async with self._pool.connection() as conn:
await conn.execute(
"""
INSERT INTO active_interrupts (interrupt_id, thread_id, action, params)
VALUES (%(iid)s, %(tid)s, %(action)s, %(params)s)
ON CONFLICT (thread_id) WHERE resolved_at IS NULL
DO UPDATE SET
interrupt_id = %(iid)s,
action = %(action)s,
params = %(params)s,
created_at = NOW(),
resolved_at = NULL
""",
{
"iid": record.interrupt_id,
"tid": thread_id,
"action": action,
"params": json.dumps(params),
},
)
return record
def check_status(self, thread_id: str) -> InterruptStatus | None:
import asyncio
return asyncio.get_event_loop().run_until_complete(
self._check_status(thread_id)
)
async def _check_status(self, thread_id: str) -> InterruptStatus | None:
async with self._pool.connection() as conn:
cursor = await conn.execute(
"""
SELECT interrupt_id, action, params, created_at
FROM active_interrupts
WHERE thread_id = %(tid)s AND resolved_at IS NULL
ORDER BY created_at DESC LIMIT 1
""",
{"tid": thread_id},
)
row = await cursor.fetchone()
if row is None:
return None
created_at = row["created_at"].timestamp()
elapsed = time.time() - created_at
remaining = max(0.0, self._ttl_seconds - elapsed)
is_expired = elapsed > self._ttl_seconds
record = InterruptRecord(
interrupt_id=row["interrupt_id"],
thread_id=thread_id,
action=row["action"],
params=row["params"] if isinstance(row["params"], dict) else {},
created_at=created_at,
ttl_seconds=self._ttl_seconds,
)
return InterruptStatus(
is_expired=is_expired,
remaining_seconds=remaining,
record=record,
)
def resolve(self, thread_id: str) -> None:
import asyncio
asyncio.get_event_loop().run_until_complete(self._resolve(thread_id))
async def _resolve(self, thread_id: str) -> None:
async with self._pool.connection() as conn:
await conn.execute(
"""
UPDATE active_interrupts
SET resolved_at = NOW(), resolution = 'resolved'
WHERE thread_id = %(tid)s AND resolved_at IS NULL
""",
{"tid": thread_id},
)
def generate_retry_prompt(self, expired_record: InterruptRecord) -> dict:
return _build_retry_prompt(expired_record)
def has_pending(self, thread_id: str) -> bool:
status = self.check_status(thread_id)
if status is None:
return False
return not status.is_expired
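
A sketch of the in-memory lifecycle (thread id and params hypothetical; 1800 seconds matches the default TTL):

mgr = InterruptManager(ttl_seconds=1800)
record = mgr.register("t-1", action="cancel_order", params={"order_id": "42"})
status = mgr.check_status("t-1")
if status is not None and status.is_expired:
    await ws.send_json(mgr.generate_retry_prompt(status.record))  # ws is hypothetical
else:
    mgr.resolve("t-1")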


@@ -31,6 +31,16 @@ def create_llm(settings: Settings) -> BaseChatModel:
api_key=settings.openai_api_key,
)
if provider == "azure_openai":
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(
azure_deployment=settings.azure_openai_deployment,
azure_endpoint=settings.azure_openai_endpoint,
api_key=settings.azure_openai_api_key,
api_version=settings.azure_openai_api_version,
)
if provider == "google":
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -39,4 +49,7 @@ def create_llm(settings: Settings) -> BaseChatModel:
google_api_key=settings.google_api_key,
)
raise ValueError(f"Unknown LLM provider: '{provider}'. Use 'anthropic', 'openai', or 'google'.")
raise ValueError(
f"Unknown LLM provider: '{provider}'. "
"Use 'anthropic', 'openai', 'azure_openai', or 'google'."
)


@@ -0,0 +1,57 @@
"""Structured logging configuration using structlog."""
from __future__ import annotations
import logging
import sys
import structlog
def configure_logging(log_format: str = "console") -> None:
"""Configure structlog with stdlib integration.
Args:
log_format: "console" for human-readable dev output,
"json" for machine-parseable production output.
"""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
]
if log_format == "json":
renderer: structlog.types.Processor = structlog.processors.JSONRenderer()
else:
renderer = structlog.dev.ConsoleRenderer()
structlog.configure(
processors=[
*shared_processors,
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
renderer,
],
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
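
A usage sketch: the same logger call renders as pretty console output in dev and as one JSON object per line in production:

configure_logging(log_format="json")
logger = structlog.get_logger()
logger.info("message_dispatched", thread_id="t-1", agent="order_lookup")
# console mode would render the identical event human-readably instead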


@@ -2,79 +2,211 @@
from __future__ import annotations
import logging
import asyncio
import contextlib
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from app.analytics.api import router as analytics_router
from app.analytics.event_recorder import PostgresAnalyticsRecorder
from app.api_utils import envelope
from app.callbacks import TokenUsageCallbackHandler
from app.config import Settings
from app.db import create_checkpointer, create_pool, setup_app_tables
from app.conversation_tracker import PostgresConversationTracker
from app.db import create_checkpointer, create_pool, run_alembic_migrations
from app.escalation import NoOpEscalator, WebhookEscalator
from app.graph import build_graph
from app.intent import LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.llm import create_llm
from app.logging_config import configure_logging
from app.openapi.review_api import router as openapi_router
from app.registry import AgentRegistry
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
AGENTS_YAML = Path(__file__).parent.parent / "agents.yaml"
FRONTEND_DIST = Path(__file__).parent.parent.parent / "frontend" / "dist"
async def _interrupt_cleanup_loop(
interrupt_manager: InterruptManager,
interval: int = 60,
) -> None:
"""Periodically remove expired interrupts in the background.
Runs until cancelled. Catches all exceptions to prevent the task
from dying unexpectedly.
"""
while True:
await asyncio.sleep(interval)
try:
expired = interrupt_manager.cleanup_expired()
if expired:
logger.info(
"Cleaned up %d expired interrupt(s)",
len(expired),
)
except Exception:
logger.exception("Error during interrupt cleanup")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
settings = Settings()
configure_logging(settings.log_format)
pool = await create_pool(settings)
checkpointer = await create_checkpointer(pool)
await setup_app_tables(pool)
run_alembic_migrations(settings.database_url)
# Load agents from template or default YAML
if settings.template_name:
registry = AgentRegistry.load_template(settings.template_name)
else:
registry = AgentRegistry.load(AGENTS_YAML)
llm = create_llm(settings)
graph = build_graph(registry, llm, checkpointer)
intent_classifier = LLMIntentClassifier(llm)
graph_ctx = build_graph(registry, llm, checkpointer, intent_classifier=intent_classifier)
session_manager = SessionManager(
session_ttl_seconds=settings.session_ttl_minutes * 60,
)
interrupt_manager = InterruptManager(
ttl_seconds=settings.interrupt_ttl_minutes * 60,
)
app.state.graph = graph
# Configure escalation
if settings.webhook_url:
escalator = WebhookEscalator(
url=settings.webhook_url,
timeout_seconds=settings.webhook_timeout_seconds,
max_retries=settings.webhook_max_retries,
)
else:
escalator = NoOpEscalator()
app.state.graph_ctx = graph_ctx
app.state.session_manager = session_manager
app.state.interrupt_manager = interrupt_manager
app.state.escalator = escalator
app.state.settings = settings
app.state.pool = pool
app.state.analytics_recorder = PostgresAnalyticsRecorder(pool=pool)
app.state.conversation_tracker = PostgresConversationTracker()
logger.info(
"Smart Support started: %d agents loaded, LLM=%s/%s",
"Smart Support started: %d agents loaded, LLM=%s/%s, template=%s",
len(registry),
settings.llm_provider,
settings.llm_model,
settings.template_name or "(default)",
)
cleanup_task = asyncio.create_task(
_interrupt_cleanup_loop(interrupt_manager),
)
yield
cleanup_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await cleanup_task
await pool.close()
app = FastAPI(title="Smart Support", version="0.1.0", lifespan=lifespan)
_VERSION = "0.6.0"
app = FastAPI(title="Smart Support", version=_VERSION, lifespan=lifespan)
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
"""Wrap HTTPException in standard envelope format."""
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
"""Wrap validation errors in standard envelope format."""
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
"""Catch-all handler -- never leak stack traces."""
logger.exception("Unhandled exception: %s", exc)
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
@app.get("/api/v1/health")
def health_check() -> dict:
"""Health check endpoint for load balancers and monitoring."""
return {"status": "ok", "version": _VERSION}
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
graph = app.state.graph
session_manager = app.state.session_manager
async def websocket_endpoint(
ws: WebSocket,
token: str | None = Query(default=None),
) -> None:
settings = app.state.settings
# Verify WebSocket token when admin_api_key is configured
if settings.admin_api_key:
import secrets as _secrets
if token is None or not _secrets.compare_digest(token, settings.admin_api_key):
await ws.close(code=4001, reason="Unauthorized")
return
await ws.accept()
callback_handler = TokenUsageCallbackHandler(model_name=settings.llm_model)
ws_ctx = WebSocketContext(
graph_ctx=app.state.graph_ctx,
session_manager=app.state.session_manager,
callback_handler=callback_handler,
interrupt_manager=app.state.interrupt_manager,
analytics_recorder=app.state.analytics_recorder,
conversation_tracker=app.state.conversation_tracker,
pool=app.state.pool,
)
try:
while True:
raw_data = await ws.receive_text()
await dispatch_message(ws, graph, session_manager, callback_handler, raw_data)
await dispatch_message(ws, ws_ctx, raw_data)
except WebSocketDisconnect:
logger.info("WebSocket client disconnected")


@@ -0,0 +1,2 @@
# OpenAPI auto-discovery module
# Parses OpenAPI specs, classifies endpoints via LLM, generates tools


@@ -0,0 +1,169 @@
"""OpenAPI endpoint classifier.
Classifies endpoints into read/write access types and identifies
customer-identifying parameters. Provides a rule-based heuristic
classifier and an LLM-backed classifier with heuristic fallback.
"""
from __future__ import annotations
import json
import re
from typing import Protocol
import structlog
from app.openapi.models import ClassificationResult, EndpointInfo
logger = structlog.get_logger()
_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
_INTERRUPT_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# Parameter names that identify the customer/order context
_CUSTOMER_PARAM_PATTERNS = re.compile(
r"(customer_id|order_id|user_id|account_id|client_id|subscriber_id)",
re.IGNORECASE,
)
class ClassifierProtocol(Protocol):
"""Protocol for endpoint classifiers."""
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]: ...
class HeuristicClassifier:
"""Rule-based endpoint classifier.
GET -> read, no interrupt.
POST/PUT/PATCH/DELETE -> write, needs_interrupt=True.
"""
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Classify endpoints using HTTP method heuristics."""
if not endpoints:
return ()
return tuple(_classify_one(ep) for ep in endpoints)
def _classify_one(ep: EndpointInfo) -> ClassificationResult:
"""Classify a single endpoint using heuristics."""
access_type = "write" if ep.method in _WRITE_METHODS else "read"
needs_interrupt = ep.method in _INTERRUPT_METHODS
customer_params = _detect_customer_params(ep)
agent_group = "write_agent" if access_type == "write" else "read_agent"
return ClassificationResult(
endpoint=ep,
access_type=access_type,
customer_params=customer_params,
agent_group=agent_group,
confidence=0.7,
needs_interrupt=needs_interrupt,
)
def _detect_customer_params(ep: EndpointInfo) -> tuple[str, ...]:
"""Extract parameter names that identify the customer/order context."""
return tuple(
p.name
for p in ep.parameters
if _CUSTOMER_PARAM_PATTERNS.fullmatch(p.name)
)
class LLMClassifier:
"""LLM-backed endpoint classifier with heuristic fallback.
Uses an LLM to classify endpoints with higher accuracy.
Falls back to HeuristicClassifier on any LLM error.
"""
def __init__(self, llm: object) -> None:
self._llm = llm
self._fallback = HeuristicClassifier()
async def classify(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Classify endpoints using LLM with heuristic fallback."""
if not endpoints:
return ()
try:
return await self._classify_with_llm(endpoints)
except Exception:
logger.warning(
"LLM classification failed, falling back to heuristic",
exc_info=True,
)
return await self._fallback.classify(endpoints)
async def _classify_with_llm(
self, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Attempt LLM-based classification."""
prompt = _build_classification_prompt(endpoints)
response = await self._llm.ainvoke(prompt)
parsed = _parse_llm_response(response.content, endpoints)
return parsed
def _build_classification_prompt(endpoints: tuple[EndpointInfo, ...]) -> str:
"""Build a prompt for classifying endpoints."""
items = []
for i, ep in enumerate(endpoints):
items.append(
f"{i}: {ep.method} {ep.path} - {ep.summary or ep.operation_id}"
)
endpoint_list = "\n".join(items)
return (
"Classify each API endpoint as 'read' or 'write'. "
"For each, determine if it needs human interrupt approval, "
"identify customer-identifying parameters, and assign an agent_group.\n\n"
f"Endpoints:\n{endpoint_list}\n\n"
"Respond with a JSON array with one object per endpoint:\n"
'[{"access_type": "read|write", "agent_group": "...", '
'"needs_interrupt": true|false, "customer_params": [...], "confidence": 0.0-1.0}]'
)
def _parse_llm_response(
content: str, endpoints: tuple[EndpointInfo, ...]
) -> tuple[ClassificationResult, ...]:
"""Parse LLM JSON response into ClassificationResult instances.
Raises ValueError if the response cannot be parsed or is mismatched.
"""
# Extract JSON array from response
match = re.search(r"\[.*\]", content, re.DOTALL)
if not match:
raise ValueError(f"No JSON array found in LLM response: {content!r}")
items = json.loads(match.group())
if not isinstance(items, list) or len(items) != len(endpoints):
raise ValueError(
f"LLM returned {len(items)} items for {len(endpoints)} endpoints"
)
results: list[ClassificationResult] = []
for ep, item in zip(endpoints, items, strict=True):
raw_access = item.get("access_type", "read")
access_type = raw_access if raw_access in {"read", "write"} else "read"
confidence = max(0.0, min(1.0, float(item.get("confidence", 0.8))))
raw_group = str(item.get("agent_group", "support"))
agent_group = raw_group if raw_group.strip() else "support"
results.append(
ClassificationResult(
endpoint=ep,
access_type=access_type,
customer_params=tuple(item.get("customer_params", [])),
agent_group=agent_group,
confidence=confidence,
needs_interrupt=bool(item.get("needs_interrupt", False)),
)
)
return tuple(results)
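
A sketch of the heuristic path on a single write endpoint (endpoint values hypothetical):

from app.openapi.models import EndpointInfo, ParameterInfo

ep = EndpointInfo(
    path="/orders/{order_id}",
    method="DELETE",
    operation_id="cancel_order",
    summary="Cancel an order",
    description="",
    parameters=(
        ParameterInfo(name="order_id", location="path", required=True, schema_type="string"),
    ),
)
results = await HeuristicClassifier().classify((ep,))
# -> access_type="write", needs_interrupt=True, customer_params=("order_id",),
#    agent_group="write_agent", confidence=0.7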


@@ -0,0 +1,93 @@
"""OpenAPI spec fetcher with SSRF protection.
Fetches OpenAPI spec documents from remote URLs, validates them against
SSRF policy, and parses JSON or YAML format automatically.
"""
from __future__ import annotations
import json
import yaml
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
_MAX_SIZE_BYTES = 10 * 1024 * 1024 # 10MB
async def fetch_spec(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> dict:
"""Fetch an OpenAPI spec from a URL and return as a dict.
Auto-detects JSON or YAML format from content-type header or URL extension.
Enforces a 10MB size limit.
Raises:
SSRFError: If the URL is blocked by SSRF policy.
ValueError: If the response is too large or cannot be parsed.
"""
from app.openapi.ssrf import safe_fetch
response = await safe_fetch(url, policy=policy)
response.raise_for_status()
content = response.text
if len(content.encode("utf-8")) > _MAX_SIZE_BYTES:
raise ValueError(
f"Response too large: {len(content.encode('utf-8'))} bytes "
f"(max {_MAX_SIZE_BYTES} bytes)"
)
content_type = response.headers.get("content-type", "")
return _parse_content(content, content_type, url)
def _parse_content(content: str, content_type: str, url: str) -> dict:
"""Parse content as JSON or YAML based on content-type or URL extension."""
if _is_yaml_format(content_type, url):
return _parse_yaml(content)
if _is_json_format(content_type, url):
return _parse_json(content)
# Fall back: try JSON first, then YAML
try:
return _parse_json(content)
except ValueError:
return _parse_yaml(content)
def _is_yaml_format(content_type: str, url: str) -> bool:
"""Check if the content is YAML format."""
yaml_types = {"application/x-yaml", "text/yaml", "application/yaml"}
if any(t in content_type for t in yaml_types):
return True
lower_url = url.lower().split("?")[0]
return lower_url.endswith(".yaml") or lower_url.endswith(".yml")
def _is_json_format(content_type: str, url: str) -> bool:
"""Check if the content is JSON format."""
if "application/json" in content_type:
return True
lower_url = url.lower().split("?")[0]
return lower_url.endswith(".json")
def _parse_json(content: str) -> dict:
"""Parse content as JSON, raising ValueError on failure."""
try:
result = json.loads(content)
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid JSON: {exc}") from exc
if not isinstance(result, dict):
raise ValueError(f"Expected a JSON object, got {type(result).__name__}")
return result
def _parse_yaml(content: str) -> dict:
"""Parse content as YAML, raising ValueError on failure."""
try:
result = yaml.safe_load(content)
except yaml.YAMLError as exc:
raise ValueError(f"Invalid YAML: {exc}") from exc
if not isinstance(result, dict):
raise ValueError(f"Expected a YAML mapping, got {type(result).__name__}")
return result


@@ -0,0 +1,164 @@
"""Tool code generator for classified OpenAPI endpoints.
Generates Python source code for LangChain @tool functions and
YAML agent configurations from classification results.
"""
from __future__ import annotations
import keyword
import re
import yaml
from app.openapi.models import ClassificationResult, EndpointInfo, GeneratedTool, ParameterInfo
_INDENT = "    "
def generate_tool_code(classification: ClassificationResult, base_url: str) -> GeneratedTool:
"""Generate a LangChain @tool function for a classified endpoint.
Returns a GeneratedTool with the function source code as a string.
"""
ep = classification.endpoint
func_name = _to_snake_case(ep.operation_id)
params = _collect_params(ep)
sig = _build_signature(params, ep.request_body_schema)
docstring = _sanitize_docstring(ep.summary or ep.description or ep.operation_id)
interrupt_comment = _interrupt_comment(classification)
http_call = _build_http_call(ep, base_url, params)
lines = [
"@tool",
f"async def {func_name}({sig}) -> str:",
f'{_INDENT}"""{docstring}"""',
]
if interrupt_comment:
lines.append(f"{_INDENT}{interrupt_comment}")
lines += [
f"{_INDENT}async with httpx.AsyncClient() as client:",
f"{_INDENT}{_INDENT}{http_call}",
f"{_INDENT}{_INDENT}return response.text",
]
code = "\n".join(lines)
return GeneratedTool(
function_name=func_name,
endpoint=ep,
classification=classification,
code=code,
)
def generate_agent_yaml(
classifications: tuple[ClassificationResult, ...],
base_url: str,
) -> str:
"""Generate an agents.yaml string from a set of classification results.
Groups tools by agent_group, creating one agent entry per group.
"""
if not classifications:
return yaml.safe_dump({"agents": []})
groups: dict[str, dict] = {}
for clf in classifications:
group = clf.agent_group
func_name = _to_snake_case(clf.endpoint.operation_id)
if group not in groups:
permission = "read" if clf.access_type == "read" else "write"
groups[group] = {
"name": group,
"description": f"Agent for {group} operations",
"permission": permission,
"tools": [],
}
groups[group]["tools"].append(func_name)
return yaml.safe_dump({"agents": list(groups.values())}, sort_keys=False)
# --- Private helpers ---
def _collect_params(ep: EndpointInfo) -> list[ParameterInfo]:
"""Return path params first, then query params."""
path_params = [p for p in ep.parameters if p.location == "path"]
other_params = [p for p in ep.parameters if p.location != "path"]
return path_params + other_params
def _build_signature(params: list[ParameterInfo], body_schema: dict | None) -> str:
"""Build a Python function signature string from parameters."""
parts: list[str] = []
for p in params:
py_type = _schema_type_to_python(p.schema_type)
if p.required:
parts.append(f"{p.name}: {py_type}")
else:
parts.append(f"{p.name}: {py_type} | None = None")
if body_schema:
parts.append("body: dict | None = None")
return ", ".join(parts)
def _build_http_call(
ep: EndpointInfo, base_url: str, params: list[ParameterInfo]
) -> str:
"""Build the httpx client call line."""
method = ep.method.lower()
path = ep.path
    # Path parameters such as {order_id} are left in the path verbatim: the
    # URL below is emitted as an f-string, so each {name} placeholder is
    # filled from the generated function's argument of the same name.
url_expr = f'f"{base_url}{path}"'
query_params = [p for p in params if p.location == "query"]
extra_args = []
if query_params:
qp_dict = "{" + ", ".join(f'"{p.name}": {p.name}' for p in query_params) + "}"
extra_args.append(f"params={qp_dict}")
if ep.request_body_schema and method in ("post", "put", "patch"):
extra_args.append("json=body")
args_str = ", ".join([url_expr] + extra_args)
return f"response = await client.{method}({args_str})"
def _interrupt_comment(classification: ClassificationResult) -> str:
"""Return a comment line if the endpoint requires interrupt/approval."""
if classification.needs_interrupt:
return "# INTERRUPT: requires human approval before execution"
return ""
def _schema_type_to_python(schema_type: str) -> str:
"""Map OpenAPI schema type to Python type annotation."""
mapping = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"array": "list",
"object": "dict",
}
return mapping.get(schema_type, "str")
def _sanitize_docstring(text: str) -> str:
"""Escape triple-quotes and newlines to prevent docstring injection."""
return text.replace("\\", "\\\\").replace('"""', r"\"\"\"").replace("\n", " ")
def _to_snake_case(name: str) -> str:
"""Convert operationId to a valid snake_case Python identifier."""
clean = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_")
result = clean.lower() or "unnamed_tool"
if keyword.iskeyword(result):
result = f"{result}_tool"
return result
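
Feeding the DELETE endpoint from the classifier sketch above through the generator would produce (base URL hypothetical; output reconstructed from the template logic):

tool = generate_tool_code(results[0], base_url="https://api.example.com")
print(tool.code)
# @tool
# async def cancel_order(order_id: str) -> str:
#     """Cancel an order"""
#     # INTERRUPT: requires human approval before execution
#     async with httpx.AsyncClient() as client:
#         response = await client.delete(f"https://api.example.com/orders/{order_id}")
#         return response.text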


@@ -0,0 +1,111 @@
"""Import orchestrator for OpenAPI auto-discovery pipeline.
Orchestrates: fetch -> validate -> parse -> classify
Each stage updates the job status and calls the on_progress callback.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import replace
import structlog
from app.openapi.classifier import ClassifierProtocol, HeuristicClassifier
from app.openapi.fetcher import fetch_spec
from app.openapi.models import ImportJob
from app.openapi.parser import parse_endpoints
from app.openapi.ssrf import DEFAULT_POLICY, SSRFPolicy
from app.openapi.validator import validate_spec
logger = structlog.get_logger()
ProgressCallback = Callable[[str, ImportJob], None] | None
class ImportOrchestrator:
"""Orchestrates the full OpenAPI import pipeline.
Stages:
1. fetching -- download and parse spec from URL
2. validating -- check spec structure
3. parsing -- extract endpoint definitions
4. classifying -- classify endpoints for agent routing
5. done / failed
"""
def __init__(
self,
classifier: ClassifierProtocol | None = None,
policy: SSRFPolicy = DEFAULT_POLICY,
) -> None:
self._classifier = classifier or HeuristicClassifier()
self._policy = policy
async def start_import(
self,
url: str,
job_id: str,
on_progress: ProgressCallback,
) -> ImportJob:
"""Run the full import pipeline for a spec URL.
Returns an ImportJob reflecting final status (done or failed).
on_progress is called with (stage_name, current_job) at each stage.
Passing None for on_progress is safe.
"""
job = ImportJob(
job_id=job_id,
status="pending",
spec_url=url,
)
try:
# Stage 1: fetch
job = _update(job, status="fetching")
_notify(on_progress, "fetching", job)
spec_dict = await fetch_spec(url, self._policy)
# Stage 2: validate
job = _update(job, status="validating")
_notify(on_progress, "validating", job)
errors = validate_spec(spec_dict)
if errors:
raise ValueError(f"Invalid spec: {'; '.join(errors)}")
# Stage 3: parse
job = _update(job, status="parsing")
_notify(on_progress, "parsing", job)
endpoints = parse_endpoints(spec_dict)
# Stage 4: classify
job = _update(job, status="classifying", total_endpoints=len(endpoints))
_notify(on_progress, "classifying", job)
classifications = await self._classifier.classify(endpoints)
# Done
job = _update(
job,
status="done",
total_endpoints=len(endpoints),
classified_count=len(classifications),
)
_notify(on_progress, "done", job)
return job
except Exception as exc:
logger.exception("Import pipeline failed for job %s", job_id)
job = _update(job, status="failed", error_message=str(exc))
_notify(on_progress, "failed", job)
return job
def _update(job: ImportJob, **kwargs: object) -> ImportJob:
"""Return a new ImportJob with updated fields (immutable update)."""
return replace(job, **kwargs)
def _notify(callback: ProgressCallback, stage: str, job: ImportJob) -> None:
"""Call the progress callback if provided."""
if callback is not None:
callback(stage, job)
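
A sketch of a one-shot run (spec URL hypothetical; the callback is synchronous, matching ProgressCallback):

import uuid

def log_progress(stage: str, job: ImportJob) -> None:
    logger.info("import_progress", stage=stage, job_id=job.job_id)

orchestrator = ImportOrchestrator()  # heuristic classifier by default
job = await orchestrator.start_import(
    url="https://api.example.com/openapi.json",
    job_id=uuid.uuid4().hex,
    on_progress=log_progress,
)
# job.status is "done" or "failed"; failures never propagate as exceptions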


@@ -0,0 +1,67 @@
"""Data models for OpenAPI auto-discovery module.
Frozen dataclasses for all value objects to ensure immutability.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class ParameterInfo:
"""Describes a single endpoint parameter."""
name: str
location: str # "path", "query", "header", "cookie"
required: bool
schema_type: str # "string", "integer", "boolean", etc.
description: str = ""
@dataclass(frozen=True)
class EndpointInfo:
"""Describes a single API endpoint."""
path: str
method: str # uppercase: GET, POST, PUT, DELETE, PATCH
operation_id: str
summary: str
description: str
parameters: tuple[ParameterInfo, ...] = field(default_factory=tuple)
request_body_schema: dict | None = None
response_schema: dict | None = None
@dataclass(frozen=True)
class ClassificationResult:
"""Result of classifying an endpoint for agent routing."""
endpoint: EndpointInfo
access_type: str # "read" or "write"
customer_params: tuple[str, ...] # param names that identify the customer
agent_group: str # which agent group handles this endpoint
confidence: float # 0.0 to 1.0
needs_interrupt: bool # requires human approval before execution
@dataclass(frozen=True)
class ImportJob:
"""Tracks the state of an OpenAPI import job."""
job_id: str
status: str # "pending", "fetching", "validating", "parsing", "classifying", "done", "failed"
spec_url: str
total_endpoints: int = 0
classified_count: int = 0
error_message: str | None = None
@dataclass(frozen=True)
class GeneratedTool:
"""A generated LangChain tool from a classified endpoint."""
function_name: str
endpoint: EndpointInfo
classification: ClassificationResult
code: str


@@ -0,0 +1,152 @@
"""OpenAPI spec endpoint parser.
Extracts all endpoint definitions from a parsed OpenAPI spec dict,
resolving $ref references from components.
"""
from __future__ import annotations
import re
from app.openapi.models import EndpointInfo, ParameterInfo
_HTTP_METHODS = ("get", "post", "put", "patch", "delete", "head", "options")
def parse_endpoints(spec_dict: dict) -> tuple[EndpointInfo, ...]:
"""Parse all endpoints from a validated OpenAPI spec dict.
Returns an immutable tuple of EndpointInfo instances.
"""
paths = spec_dict.get("paths", {})
if not isinstance(paths, dict) or not paths:
return ()
endpoints: list[EndpointInfo] = []
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method in _HTTP_METHODS:
operation = path_item.get(method)
if operation is None:
continue
endpoint = _parse_operation(path, method.upper(), operation, spec_dict)
endpoints.append(endpoint)
return tuple(endpoints)
def _parse_operation(
path: str,
method: str,
operation: dict,
spec_dict: dict,
) -> EndpointInfo:
"""Parse a single operation dict into an EndpointInfo."""
operation_id = operation.get("operationId") or _generate_operation_id(path, method)
summary = operation.get("summary", "")
description = operation.get("description", "")
parameters = _parse_parameters(operation.get("parameters", []), spec_dict)
request_body_schema = _parse_request_body(operation.get("requestBody"), spec_dict)
response_schema = _parse_response_schema(operation.get("responses", {}), spec_dict)
return EndpointInfo(
path=path,
method=method,
operation_id=operation_id,
summary=summary,
description=description,
parameters=tuple(parameters),
request_body_schema=request_body_schema,
response_schema=response_schema,
)
def _parse_parameters(
params_list: list,
spec_dict: dict,
) -> list[ParameterInfo]:
"""Parse list of parameter dicts into ParameterInfo instances."""
result: list[ParameterInfo] = []
for param in params_list:
if not isinstance(param, dict):
continue
schema = param.get("schema", {})
schema_type = schema.get("type", "string") if isinstance(schema, dict) else "string"
result.append(
ParameterInfo(
name=param.get("name", ""),
location=param.get("in", "query"),
required=bool(param.get("required", False)),
schema_type=schema_type,
description=param.get("description", ""),
)
)
return result
def _parse_request_body(request_body: dict | None, spec_dict: dict) -> dict | None:
"""Extract schema from requestBody, resolving $ref if present."""
if not isinstance(request_body, dict):
return None
content = request_body.get("content", {})
if not isinstance(content, dict):
return None
# Prefer application/json
for media_type in ("application/json", *content.keys()):
media = content.get(media_type)
if not isinstance(media, dict):
continue
schema = media.get("schema")
if schema:
return _resolve_ref(schema, spec_dict)
return None
def _parse_response_schema(responses: dict, spec_dict: dict) -> dict | None:
"""Extract schema from the first 2xx response."""
if not isinstance(responses, dict):
return None
for status_code in ("200", "201", "202", "204"):
response = responses.get(status_code)
if not isinstance(response, dict):
continue
content = response.get("content", {})
if not isinstance(content, dict):
continue
for media_type in ("application/json", *content.keys()):
media = content.get(media_type)
if not isinstance(media, dict):
continue
schema = media.get("schema")
if schema:
return _resolve_ref(schema, spec_dict)
return None
def _resolve_ref(schema: object, spec_dict: dict) -> dict:
"""Resolve a $ref to its target schema, or return the schema as-is."""
if not isinstance(schema, dict):
return {}
ref = schema.get("$ref")
if not ref:
return schema
# Only handle local refs like #/components/schemas/Foo
if not isinstance(ref, str) or not ref.startswith("#/"):
return schema
parts = ref.lstrip("#/").split("/")
target: object = spec_dict
for part in parts:
if not isinstance(target, dict):
return schema
target = target.get(part)
return target if isinstance(target, dict) else schema
def _generate_operation_id(path: str, method: str) -> str:
"""Generate a snake_case operation_id from path and method."""
    # Replace each {param} segment with "by_param", then collapse remaining
    # non-alphanumeric characters (slashes etc.) into underscores
clean = re.sub(r"\{[^}]+\}", "by_param", path)
clean = re.sub(r"[^a-zA-Z0-9]+", "_", clean).strip("_")
return f"{method.lower()}_{clean}" if clean else method.lower()
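
A sketch on a minimal inline spec (no $refs), illustrative only:

spec = {
    "openapi": "3.0.0",
    "info": {"title": "demo", "version": "1"},
    "paths": {
        "/orders/{order_id}": {
            "get": {
                "operationId": "getOrder",
                "summary": "Fetch one order",
                "parameters": [
                    {"name": "order_id", "in": "path", "required": True,
                     "schema": {"type": "string"}},
                ],
            },
        },
    },
}
endpoints = parse_endpoints(spec)
# -> one EndpointInfo: GET /orders/{order_id}, operation_id="getOrder",
#    with a single required path parameter of schema_type "string"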


@@ -0,0 +1,282 @@
"""FastAPI router for OpenAPI import review workflow.
Exposes endpoints for:
- Starting an import job (triggers background pipeline)
- Querying job status
- Reviewing and editing classifications
- Approving a job to trigger tool generation
"""
from __future__ import annotations
import asyncio
import re
import uuid
from typing import Literal
import structlog
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from pydantic import BaseModel, field_validator
from app.auth import require_admin_api_key
from app.openapi.generator import generate_agent_yaml, generate_tool_code
from app.openapi.importer import ImportOrchestrator
from app.openapi.models import ClassificationResult, ImportJob
logger = structlog.get_logger()
router = APIRouter(
prefix="/api/v1/openapi",
tags=["openapi"],
dependencies=[Depends(require_admin_api_key)],
)
# In-memory store: job_id -> job dict. All mutations below are single
# dict-item assignments on the event loop, which are atomic; _store_lock
# is available for multi-step read-modify-write sequences.
_job_store: dict[str, dict] = {}
_store_lock = asyncio.Lock()
# Shared orchestrator instance
_orchestrator = ImportOrchestrator()
# --- Request / Response schemas ---
class ImportRequest(BaseModel):
url: str
@field_validator("url")
@classmethod
def url_must_be_valid(cls, value: str) -> str:
stripped = value.strip()
if not stripped:
raise ValueError("url must not be empty")
if not stripped.startswith(("http://", "https://")):
raise ValueError("url must start with http:// or https://")
return stripped
class JobResponse(BaseModel):
job_id: str
status: str
spec_url: str
total_endpoints: int = 0
classified_count: int = 0
error_message: str | None = None
class ClassificationResponse(BaseModel):
index: int
access_type: str
needs_interrupt: bool
agent_group: str
confidence: float
customer_params: list[str]
endpoint: dict
class UpdateClassificationRequest(BaseModel):
access_type: Literal["read", "write"]
needs_interrupt: bool
agent_group: str
@field_validator("agent_group")
@classmethod
def agent_group_must_be_safe(cls, value: str) -> str:
if not value.strip() or not re.fullmatch(r"[a-zA-Z0-9_\-]+", value):
raise ValueError(
"agent_group must be non-empty and contain only "
"alphanumeric characters, underscores, or hyphens"
)
return value
# --- Helpers ---
def _job_to_response(job: dict) -> dict:
return {
"job_id": job["job_id"],
"status": job["status"],
"spec_url": job["spec_url"],
"total_endpoints": job.get("total_endpoints", 0),
"classified_count": job.get("classified_count", 0),
"error_message": job.get("error_message"),
}
def _classification_to_response(idx: int, clf: ClassificationResult) -> dict:
ep = clf.endpoint
return {
"index": idx,
"access_type": clf.access_type,
"needs_interrupt": clf.needs_interrupt,
"agent_group": clf.agent_group,
"confidence": clf.confidence,
"customer_params": list(clf.customer_params),
"endpoint": {
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"description": ep.description,
},
}
async def _run_import(job_id: str, url: str) -> None:
"""Run the import pipeline as a background task."""
def on_progress(stage: str, result_job: ImportJob) -> None:
if job_id in _job_store:
_job_store[job_id] = {
**_job_store[job_id],
"status": result_job.status,
"total_endpoints": result_job.total_endpoints,
"classified_count": result_job.classified_count,
"error_message": result_job.error_message,
}
try:
result = await _orchestrator.start_import(
url=url, job_id=job_id, on_progress=on_progress,
)
if job_id in _job_store:
_job_store[job_id] = {
**_job_store[job_id],
"status": result.status,
"total_endpoints": result.total_endpoints,
"classified_count": result.classified_count,
"error_message": result.error_message,
}
except Exception:
logger.exception("Background import failed for job %s", job_id)
if job_id in _job_store:
_job_store[job_id] = {
**_job_store[job_id],
"status": "failed",
"error_message": "Import failed. Please check the URL and try again.",
}
# --- Endpoints ---
@router.post("/import", status_code=202)
async def start_import(
request: ImportRequest,
background_tasks: BackgroundTasks,
) -> dict:
"""Start an OpenAPI import job for the given spec URL."""
job_id = str(uuid.uuid4())
job: dict = {
"job_id": job_id,
"status": "pending",
"spec_url": request.url,
"total_endpoints": 0,
"classified_count": 0,
"error_message": None,
"classifications": [],
}
_job_store[job_id] = job
background_tasks.add_task(_run_import, job_id, request.url)
return _job_to_response(job)
@router.get("/jobs/{job_id}")
async def get_job(job_id: str) -> dict:
"""Get the status of an import job."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
return _job_to_response(job)
@router.get("/jobs/{job_id}/classifications")
async def get_classifications(job_id: str) -> list:
"""Get all classifications for an import job."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
return [
_classification_to_response(i, clf)
for i, clf in enumerate(classifications)
]
@router.put("/jobs/{job_id}/classifications/{idx}")
async def update_classification(
job_id: str,
idx: int,
request: UpdateClassificationRequest,
) -> dict:
"""Update a specific classification by index."""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
if idx < 0 or idx >= len(classifications):
raise HTTPException(
status_code=404,
detail=f"Classification index {idx} out of range",
)
original = classifications[idx]
updated = ClassificationResult(
endpoint=original.endpoint,
access_type=request.access_type,
customer_params=original.customer_params,
agent_group=request.agent_group,
confidence=original.confidence,
needs_interrupt=request.needs_interrupt,
)
new_classifications = list(classifications)
new_classifications[idx] = updated
_job_store[job_id] = {**job, "classifications": new_classifications}
return _classification_to_response(idx, updated)
@router.post("/jobs/{job_id}/approve")
async def approve_job(job_id: str) -> dict:
"""Approve a job's classifications and trigger tool generation.
Generates Python tool code for each classified endpoint and
produces an agent YAML configuration snippet.
"""
job = _job_store.get(job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found")
classifications: list[ClassificationResult] = job.get("classifications", [])
if not classifications:
raise HTTPException(
status_code=400,
detail="No classifications to approve. Import must complete first.",
)
base_url = job["spec_url"].rsplit("/", 1)[0]
generated_tools = []
for clf in classifications:
tool = generate_tool_code(clf, base_url)
generated_tools.append({
"function_name": tool.function_name,
"agent_group": clf.agent_group,
"code": tool.code,
})
agent_yaml = generate_agent_yaml(tuple(classifications), base_url)
updated_job = {
**job,
"status": "approved",
"generated_tools": generated_tools,
"agent_yaml": agent_yaml,
}
_job_store[job_id] = updated_job
response = _job_to_response(updated_job)
response["generated_tools_count"] = len(generated_tools)
return response
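# End-to-end usage sketch for the router above. The admin-key header name
# and the intermediate job status values are assumptions; the real header
# is whatever app.auth.require_admin_api_key reads.
import asyncio

import httpx

async def demo_import_flow() -> None:
    headers = {"X-Admin-API-Key": "change-me"}  # assumed header name
    async with httpx.AsyncClient(base_url="http://localhost:8000", headers=headers) as client:
        job = (await client.post(
            "/api/v1/openapi/import",
            json={"url": "https://api.example-shop.com/openapi.json"},
        )).json()
        job_id = job["job_id"]
        while True:  # poll the background pipeline
            status = (await client.get(f"/api/v1/openapi/jobs/{job_id}")).json()["status"]
            if status not in ("pending", "classifying"):  # intermediate names assumed
                break
            await asyncio.sleep(0.5)
        # Optionally correct a classification before approving.
        await client.put(
            f"/api/v1/openapi/jobs/{job_id}/classifications/0",
            json={"access_type": "write", "needs_interrupt": True, "agent_group": "order-actions"},
        )
        approved = (await client.post(f"/api/v1/openapi/jobs/{job_id}/approve")).json()
        print(approved["generated_tools_count"], "tools generated")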

167
backend/app/openapi/ssrf.py Normal file
View File

@@ -0,0 +1,167 @@
"""SSRF protection module.
Validates URLs before making external HTTP requests.
Blocks private IPs, loopback addresses, link-local, and DNS rebinding attacks.
"""
from __future__ import annotations
import ipaddress
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
import httpx
class SSRFError(Exception):
"""Raised when a URL fails SSRF validation."""
@dataclass(frozen=True)
class SSRFPolicy:
"""Configuration for SSRF protection."""
allowed_schemes: frozenset[str] = frozenset({"http", "https"})
allowed_hosts: frozenset[str] | None = None # None = all public hosts allowed
max_redirects: int = 5
timeout_seconds: float = 30.0
_BLOCKED_NETWORKS = (
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("0.0.0.0/32"),
ipaddress.ip_network("100.64.0.0/10"), # Carrier-Grade NAT
ipaddress.ip_network("198.18.0.0/15"), # Benchmarking
ipaddress.ip_network("240.0.0.0/4"), # Reserved
ipaddress.ip_network("255.255.255.255/32"), # Broadcast
# IPv6
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::ffff:0:0/96"), # IPv4-mapped IPv6
ipaddress.ip_network("2001:db8::/32"), # Documentation
)
DEFAULT_POLICY = SSRFPolicy()
def is_private_ip(ip_str: str) -> bool:
"""Check if an IP address is private/reserved."""
try:
addr = ipaddress.ip_address(ip_str)
except ValueError:
return True # Invalid IP treated as blocked
return any(addr in network for network in _BLOCKED_NETWORKS)
def validate_url(url: str, policy: SSRFPolicy = DEFAULT_POLICY) -> str:
"""Validate a URL against SSRF policy.
Returns the validated URL string.
Raises SSRFError if the URL is blocked.
"""
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in policy.allowed_schemes:
raise SSRFError(
f"URL scheme '{parsed.scheme}' is not allowed. "
f"Allowed: {', '.join(sorted(policy.allowed_schemes))}"
)
# Check hostname exists
hostname = parsed.hostname
if not hostname:
raise SSRFError("URL has no hostname")
# Check allowed hosts whitelist
if policy.allowed_hosts is not None and hostname not in policy.allowed_hosts:
raise SSRFError(f"Host '{hostname}' is not in the allowed hosts list")
# DNS resolution -- resolve before making any request
resolved_ips = resolve_hostname(hostname)
if not resolved_ips:
raise SSRFError(f"Could not resolve hostname '{hostname}'")
# Check all resolved IPs against blocked networks
for ip_str in resolved_ips:
if is_private_ip(ip_str):
raise SSRFError(
f"Host '{hostname}' resolves to private/reserved IP {ip_str}. "
"Request blocked for SSRF protection."
)
return url
def resolve_hostname(hostname: str) -> list[str]:
"""Resolve hostname to IP addresses via DNS."""
try:
results = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
return list({result[4][0] for result in results})
except socket.gaierror:
return []
async def safe_fetch(
url: str,
*,
policy: SSRFPolicy = DEFAULT_POLICY,
method: str = "GET",
headers: dict[str, str] | None = None,
) -> httpx.Response:
"""Fetch a URL with SSRF protection.
Validates the URL, resolves DNS, checks IPs, then makes the request.
After receiving the response, verifies the actual connected IP
to guard against DNS rebinding.
"""
validate_url(url, policy)
# Make the request with redirect following disabled so we can check each hop
async with httpx.AsyncClient(
follow_redirects=False,
timeout=httpx.Timeout(policy.timeout_seconds),
) as client:
current_url = url
for _redirect_count in range(policy.max_redirects + 1):
response = await client.request(
method,
current_url,
headers=headers,
)
if response.is_redirect:
redirect_url = str(response.next_request.url) if response.next_request else None
if not redirect_url:
raise SSRFError("Redirect with no target URL")
# Validate the redirect target
validate_url(redirect_url, policy)
current_url = redirect_url
continue
return response
raise SSRFError(
f"Too many redirects (max {policy.max_redirects}). "
"Possible redirect loop or evasion attempt."
)
async def safe_fetch_text(
url: str,
*,
policy: SSRFPolicy = DEFAULT_POLICY,
headers: dict[str, str] | None = None,
) -> str:
"""Fetch a URL and return text content with SSRF protection."""
response = await safe_fetch(url, policy=policy, headers=headers)
response.raise_for_status()
return response.text
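# Usage sketch: fetching a spec with the strictest posture -- an explicit
# host allowlist layered on top of the default IP blocklist. The host
# name here is hypothetical.
async def fetch_spec_safely() -> str:
    policy = SSRFPolicy(allowed_hosts=frozenset({"api.example-shop.com"}))
    try:
        return await safe_fetch_text("https://api.example-shop.com/openapi.json", policy=policy)
    except SSRFError as exc:
        # Blocked scheme/host/IP, or a redirect hop to a blocked target.
        raise RuntimeError(f"Spec fetch blocked: {exc}") from exc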

View File

@@ -0,0 +1,51 @@
"""OpenAPI spec validator.
Validates an OpenAPI spec dict for required fields and basic structural
correctness. Returns a list of human-readable error strings.
"""
from __future__ import annotations
_SUPPORTED_VERSIONS = ("3.0.", "3.1.")
_REQUIRED_FIELDS = ("openapi", "info", "paths")
def validate_spec(spec_dict: object) -> list[str]:
"""Validate an OpenAPI spec dict.
Returns a list of error strings. An empty list means the spec is valid.
Does not raise; all errors are captured and returned.
"""
if not isinstance(spec_dict, dict):
return [f"Spec must be a dict, got {type(spec_dict).__name__}"]
errors: list[str] = []
# Check required top-level fields
for field in _REQUIRED_FIELDS:
if field not in spec_dict:
errors.append(f"Missing required field: '{field}'")
# Validate openapi version if present
if "openapi" in spec_dict:
version = spec_dict["openapi"]
if not isinstance(version, str):
errors.append(f"'openapi' must be a string, got {type(version).__name__}")
elif not any(version.startswith(prefix) for prefix in _SUPPORTED_VERSIONS):
errors.append(
f"Unsupported OpenAPI version '{version}'. "
f"Supported versions start with: {', '.join(_SUPPORTED_VERSIONS)}"
)
# Validate info object if present
if "info" in spec_dict and isinstance(spec_dict["info"], dict):
info = spec_dict["info"]
for sub_field in ("title", "version"):
if sub_field not in info:
errors.append(f"Missing required field in 'info': '{sub_field}'")
# Validate paths object if present
if "paths" in spec_dict and not isinstance(spec_dict["paths"], dict):
errors.append(f"'paths' must be a dict, got {type(spec_dict['paths']).__name__}")
return errors
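# Behaviour sketch on a deliberately broken spec (hypothetical input):
# problems are collected and returned rather than raised.
_errors = validate_spec({"openapi": "2.0", "info": {"title": "Legacy API"}})
assert _errors == [
    "Missing required field: 'paths'",
    "Unsupported OpenAPI version '2.0'. Supported versions start with: 3.0., 3.1.",
    "Missing required field in 'info': 'version'",
]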

View File

@@ -100,5 +100,41 @@ class AgentRegistry:
def get_agents_by_permission(self, permission: str) -> tuple[AgentConfig, ...]:
return tuple(a for a in self._agents.values() if a.permission == permission)
@classmethod
def load_template(
cls,
template_name: str,
templates_dir: str | Path | None = None,
) -> AgentRegistry:
"""Load agent configurations from a named template."""
if templates_dir is None:
templates_dir = Path(__file__).parent.parent / "templates"
templates_dir = Path(templates_dir)
yaml_path = templates_dir / f"{template_name}.yaml"
if not yaml_path.exists():
available = cls.list_templates(templates_dir)
raise FileNotFoundError(
f"Template '{template_name}' not found. "
f"Available: {', '.join(available) if available else 'none'}"
)
return cls.load(yaml_path)
@classmethod
def list_templates(
cls,
templates_dir: str | Path | None = None,
) -> tuple[str, ...]:
"""List available template names from the templates directory."""
if templates_dir is None:
templates_dir = Path(__file__).parent.parent / "templates"
templates_dir = Path(templates_dir)
if not templates_dir.is_dir():
return ()
return tuple(
sorted(p.stem for p in templates_dir.glob("*.yaml"))
)
def __len__(self) -> int:
return len(self._agents)
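# Usage sketch, assuming the three template files added later in this
# diff (e-commerce.yaml, fintech.yaml, saas.yaml) sit in the default
# templates directory:
templates = AgentRegistry.list_templates()    # ("e-commerce", "fintech", "saas")
registry = AgentRegistry.load_template("e-commerce")
assert len(registry) == 4                     # order_lookup, order_actions, discount, fallback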

View File

@@ -0,0 +1,3 @@
"""Replay module -- conversation replay API and transformer."""
from __future__ import annotations

125
backend/app/replay/api.py Normal file
View File

@@ -0,0 +1,125 @@
"""Replay API router -- conversation listing and step-by-step replay."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from app.api_utils import envelope
from app.auth import require_admin_api_key
_THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
router = APIRouter(
prefix="/api/v1",
tags=["replay"],
dependencies=[Depends(require_admin_api_key)],
)
_COUNT_CONVERSATIONS_SQL = """
SELECT COUNT(*) FROM conversations
"""
_LIST_CONVERSATIONS_SQL = """
SELECT thread_id, created_at, last_activity, status, total_tokens, total_cost_usd
FROM conversations
ORDER BY last_activity DESC
LIMIT %(limit)s OFFSET %(offset)s
"""
_GET_CHECKPOINTS_SQL = """
SELECT thread_id, checkpoint_id, checkpoint, metadata
FROM checkpoints
WHERE thread_id = %(thread_id)s
ORDER BY checkpoint_id ASC
"""
async def get_pool(request: Request) -> AsyncConnectionPool:
    """Dependency: extract the shared pool from app state.

    Rows are consumed as mappings below (dict(row)), so the pool is
    expected to be created with psycopg's dict_row row factory.
    """
    return request.app.state.pool
@router.get("/conversations")
async def list_conversations(
request: Request,
page: Annotated[int, Query(ge=1)] = 1,
per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
"""List conversations with pagination."""
pool = await get_pool(request)
offset = (page - 1) * per_page
async with pool.connection() as conn:
count_cursor = await conn.execute(_COUNT_CONVERSATIONS_SQL)
count_row = await count_cursor.fetchone()
total = count_row[0] if count_row else 0
cursor = await conn.execute(
_LIST_CONVERSATIONS_SQL,
{"limit": per_page, "offset": offset},
)
rows = await cursor.fetchall()
return envelope({
"conversations": [dict(row) for row in rows],
"total": total,
"page": page,
"per_page": per_page,
})
@router.get("/replay/{thread_id}")
async def get_replay(
thread_id: str,
request: Request,
page: Annotated[int, Query(ge=1)] = 1,
per_page: Annotated[int, Query(ge=1, le=100)] = 20,
) -> dict:
"""Return paginated replay steps for a conversation thread."""
from app.replay.transformer import transform_checkpoints
if not _THREAD_ID_PATTERN.match(thread_id):
raise HTTPException(status_code=400, detail="Invalid thread_id format")
pool = await get_pool(request)
async with pool.connection() as conn:
cursor = await conn.execute(_GET_CHECKPOINTS_SQL, {"thread_id": thread_id})
rows = await cursor.fetchall()
if not rows:
raise HTTPException(status_code=404, detail="Thread not found")
all_steps = transform_checkpoints([dict(row) for row in rows])
total_steps = len(all_steps)
start = (page - 1) * per_page
end = start + per_page
page_steps = all_steps[start:end]
data = {
"thread_id": thread_id,
"total_steps": total_steps,
"page": page,
"per_page": per_page,
"steps": [
{
"step": s.step,
"type": s.type.value,
"timestamp": s.timestamp,
"content": s.content,
"agent": s.agent,
"tool": s.tool,
"params": s.params,
"result": s.result,
"reasoning": s.reasoning,
"tokens": s.tokens,
"duration_ms": s.duration_ms,
}
for s in page_steps
],
}
return envelope(data)

View File

@@ -0,0 +1,52 @@
"""Value objects for conversation replay."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
class StepType(str, Enum):
"""Types of steps in a conversation replay."""
user_message = "user_message"
supervisor_routing = "supervisor_routing"
tool_call = "tool_call"
tool_result = "tool_result"
agent_response = "agent_response"
interrupt = "interrupt"
@dataclass(frozen=True)
class ReplayStep:
"""A single step in a conversation replay."""
step: int
type: StepType
timestamp: str
content: str = ""
agent: str | None = None
tool: str | None = None
params: dict | None = None
result: dict | None = None
reasoning: str | None = None
tokens: int | None = None
duration_ms: int | None = None
def __post_init__(self) -> None:
# Store params as a frozen copy to prevent mutation from the outside
if self.params is not None:
object.__setattr__(self, "params", dict(self.params))
if self.result is not None:
object.__setattr__(self, "result", dict(self.result))
@dataclass(frozen=True)
class ReplayPage:
"""A paginated page of replay steps for a conversation thread."""
thread_id: str
total_steps: int
page: int
per_page: int
steps: tuple[ReplayStep, ...]
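# Sketch of the defensive copy made in __post_init__: mutating the
# caller's dict after construction does not leak into the frozen step.
_params = {"order_id": "1042"}
_step = ReplayStep(step=1, type=StepType.tool_call, timestamp="2026-01-01T00:00:00Z", params=_params)
_params["order_id"] = "9999"
assert _step.params == {"order_id": "1042"}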

View File

@@ -0,0 +1,116 @@
"""Transforms PostgresSaver checkpoint rows into ReplayStep list."""
from __future__ import annotations
import structlog
from app.replay.models import ReplayStep, StepType
logger = structlog.get_logger()
_EMPTY_TIMESTAMP = "1970-01-01T00:00:00Z"
def _extract_messages(row: dict) -> list[dict]:
"""Safely extract messages list from a checkpoint row."""
checkpoint = row.get("checkpoint")
if not checkpoint or not isinstance(checkpoint, dict):
return []
channel_values = checkpoint.get("channel_values")
if not channel_values or not isinstance(channel_values, dict):
return []
messages = channel_values.get("messages")
if not messages or not isinstance(messages, list):
return []
return messages
def _step_from_message(msg: dict, step_number: int) -> ReplayStep | None:
"""Convert a single message dict to a ReplayStep. Returns None for unknown types."""
msg_type = msg.get("type", "")
timestamp = msg.get("created_at") or _EMPTY_TIMESTAMP
content = msg.get("content") or ""
if isinstance(content, list):
# LangChain may encode content as a list of parts
content = " ".join(
part.get("text", "") if isinstance(part, dict) else str(part)
for part in content
)
if msg_type == "human":
return ReplayStep(
step=step_number,
type=StepType.user_message,
timestamp=timestamp,
content=content,
)
if msg_type == "ai":
tool_calls = msg.get("tool_calls") or []
if tool_calls:
first = tool_calls[0]
return ReplayStep(
step=step_number,
type=StepType.tool_call,
timestamp=timestamp,
content=content,
tool=first.get("name"),
params=dict(first.get("args") or {}),
)
return ReplayStep(
step=step_number,
type=StepType.agent_response,
timestamp=timestamp,
content=content,
agent=msg.get("name"),
)
if msg_type == "tool":
raw = content
result: dict | None = None
try:
import json
result = json.loads(raw)
except (ValueError, TypeError):
result = {"raw": raw}
return ReplayStep(
step=step_number,
type=StepType.tool_result,
timestamp=timestamp,
tool=msg.get("name"),
result=result,
)
logger.debug("Skipping unknown message type: %s", msg_type)
return None
def transform_checkpoints(rows: list[dict]) -> list[ReplayStep]:
"""Transform a list of checkpoint rows into an ordered list of ReplaySteps.
Steps are numbered sequentially starting from 1 across all rows.
Unknown or malformed messages are silently skipped.
"""
steps: list[ReplayStep] = []
step_number = 1
for row in rows:
try:
messages = _extract_messages(row)
except Exception: # noqa: BLE001
logger.exception("Error extracting messages from checkpoint row")
continue
for msg in messages:
try:
step = _step_from_message(msg, step_number)
except Exception: # noqa: BLE001
logger.exception("Error converting message to ReplayStep")
step = None
if step is not None:
steps.append(step)
step_number += 1
return steps
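# Worked example on a hypothetical checkpoint row shaped like the
# PostgresSaver rows this transformer expects: one user turn that
# triggers a tool call and its result.
_rows = [{
    "checkpoint": {"channel_values": {"messages": [
        {"type": "human", "content": "Where is order 1042?"},
        {"type": "ai", "content": "", "tool_calls": [
            {"name": "get_order_status", "args": {"order_id": "1042"}}]},
        {"type": "tool", "name": "get_order_status", "content": '{"status": "shipped"}'},
        {"type": "ai", "content": "Order 1042 has shipped.", "name": "order_lookup"},
    ]}},
}]
_steps = transform_checkpoints(_rows)
assert [s.type for s in _steps] == [
    StepType.user_message,
    StepType.tool_call,
    StepType.tool_result,
    StepType.agent_response,
]
assert _steps[1].params == {"order_id": "1042"}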

131
backend/app/safety.py Normal file
View File

@@ -0,0 +1,131 @@
"""Safety policy for destructive-action confirmation rules.
This module makes the confirmation rules explicit and auditable. Every tool
call passes through ``requires_confirmation`` before execution to decide
whether human-in-the-loop approval is needed.
Policy summary
--------------
- ``read`` actions: execute immediately, no confirmation required.
- ``write`` actions: require human approval via interrupt gate.
- OpenAPI-imported endpoints: use ``needs_interrupt`` from classification.
- When an endpoint classification is present it takes precedence over
  the agent permission level; otherwise the permission decides alone.
Multi-intent semantics
----------------------
When a user message contains multiple intents (e.g. "cancel my order and
apply a refund"), the supervisor routes them sequentially. Each action is
evaluated independently:
- If a write action is blocked by an interrupt, subsequent actions in the
same message are paused until the interrupt is resolved.
- Read actions that follow a blocked write are also paused (sequential,
not best-effort) to preserve causal ordering.
- If an interrupt is rejected, the remaining actions are skipped and the
agent informs the user.
MCP error taxonomy
------------------
Tool execution errors are classified into categories for retry decisions:
- ``transient``: network timeouts, rate limits, 5xx -- retryable up to 3 times.
- ``validation``: bad parameters, 4xx -- not retryable, report to user.
- ``auth``: 401/403 -- not retryable, escalate.
- ``unknown``: unclassified -- not retryable, log and escalate.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
@dataclass(frozen=True)
class ConfirmationPolicy:
"""Result of evaluating whether an action needs confirmation."""
requires_confirmation: bool
reason: str
def requires_confirmation(
*,
agent_permission: Literal["read", "write"],
needs_interrupt: bool | None = None,
) -> ConfirmationPolicy:
"""Determine whether an action requires human confirmation.
Parameters
----------
agent_permission:
The permission level of the agent executing the action.
needs_interrupt:
Override from OpenAPI classification. When ``None``, the decision
is based solely on ``agent_permission``.
"""
if needs_interrupt is not None:
if needs_interrupt:
return ConfirmationPolicy(
requires_confirmation=True,
reason="Endpoint classified as requiring human approval",
)
return ConfirmationPolicy(
requires_confirmation=False,
reason="Endpoint classified as safe (no interrupt needed)",
)
if agent_permission == "write":
return ConfirmationPolicy(
requires_confirmation=True,
reason="Write-permission agent actions require confirmation",
)
return ConfirmationPolicy(
requires_confirmation=False,
reason="Read-only agent actions execute immediately",
)
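# The decision table above, as executable examples:
assert requires_confirmation(agent_permission="read").requires_confirmation is False
assert requires_confirmation(agent_permission="write").requires_confirmation is True
# A present OpenAPI classification overrides the permission-based default:
assert requires_confirmation(agent_permission="write", needs_interrupt=False).requires_confirmation is False
assert requires_confirmation(agent_permission="read", needs_interrupt=True).requires_confirmation is True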
# --- MCP Error Taxonomy ---
MCP_ERROR_CATEGORY = Literal["transient", "validation", "auth", "unknown"]
_TRANSIENT_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
_AUTH_STATUS_CODES = frozenset({401, 403})
_MAX_RETRIES = 3
def classify_mcp_error(
*,
status_code: int | None = None,
error_message: str = "",
) -> MCP_ERROR_CATEGORY:
"""Classify an MCP tool error for retry decisions."""
if status_code is not None:
if status_code in _TRANSIENT_STATUS_CODES:
return "transient"
if status_code in _AUTH_STATUS_CODES:
return "auth"
if 400 <= status_code < 500:
return "validation"
lower_msg = error_message.lower()
if any(kw in lower_msg for kw in ("timeout", "timed out", "rate limit")):
return "transient"
if any(kw in lower_msg for kw in ("unauthorized", "forbidden")):
return "auth"
if any(kw in lower_msg for kw in ("invalid", "missing", "bad request")):
return "validation"
return "unknown"
def is_retryable(category: MCP_ERROR_CATEGORY) -> bool:
"""Return whether a given error category is retryable."""
return category == "transient"
def max_retries() -> int:
"""Maximum retry attempts for transient errors."""
return _MAX_RETRIES
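# Taxonomy sketch: explicit status codes win over message heuristics.
assert classify_mcp_error(status_code=503) == "transient"
assert classify_mcp_error(status_code=404) == "validation"
assert classify_mcp_error(error_message="upstream timed out") == "transient"
assert is_retryable(classify_mcp_error(status_code=401)) is False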

View File

@@ -1,9 +1,18 @@
"""Session TTL management with sliding window and interrupt extension."""
"""Session TTL management with sliding window and interrupt extension.
Provides both in-memory (SessionManager) and PostgreSQL-backed
(PgSessionManager) implementations behind a common Protocol.
"""
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from psycopg_pool import AsyncConnectionPool
@dataclass(frozen=True)
@@ -13,8 +22,19 @@ class SessionState:
has_pending_interrupt: bool
class SessionManagerProtocol(Protocol):
"""Protocol for session TTL management."""
def touch(self, thread_id: str) -> SessionState: ...
def is_expired(self, thread_id: str) -> bool: ...
def extend_for_interrupt(self, thread_id: str) -> SessionState: ...
def resolve_interrupt(self, thread_id: str) -> SessionState: ...
def get_state(self, thread_id: str) -> SessionState | None: ...
def remove(self, thread_id: str) -> None: ...
class SessionManager:
"""Manages session TTL with sliding window and interrupt extensions.
"""In-memory session manager for single-worker development.
- Each message resets the TTL (sliding window).
- A pending interrupt suspends expiration until resolved.
@@ -40,10 +60,8 @@ class SessionManager:
state = self._sessions.get(thread_id)
if state is None:
return True
if state.has_pending_interrupt:
return False
elapsed = time.time() - state.last_activity
return elapsed > self._session_ttl
@@ -52,7 +70,6 @@ class SessionManager:
existing = self._sessions.get(thread_id)
if existing is None:
return self.touch(thread_id)
new_state = SessionState(
thread_id=thread_id,
last_activity=existing.last_activity,
@@ -76,3 +93,120 @@ class SessionManager:
def remove(self, thread_id: str) -> None:
self._sessions = {k: v for k, v in self._sessions.items() if k != thread_id}
# Alias for explicit naming
InMemorySessionManager = SessionManager
class PgSessionManager:
"""PostgreSQL-backed session manager for multi-worker production."""
def __init__(
self,
pool: AsyncConnectionPool,
session_ttl_seconds: int = 1800,
) -> None:
self._pool = pool
self._session_ttl = session_ttl_seconds
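    # NOTE: the sync methods below satisfy SessionManagerProtocol by driving
    # the event loop with run_until_complete, so they must not be called from
    # inside an already-running loop (e.g. an async FastAPI handler); async
    # callers should await _touch/_set_interrupt/_get_state/_remove directly.
    # Row access below assumes the pool is created with row_factory=dict_row.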
def touch(self, thread_id: str) -> SessionState:
return asyncio.get_event_loop().run_until_complete(self._touch(thread_id))
async def _touch(self, thread_id: str) -> SessionState:
now = datetime.now(timezone.utc)
async with self._pool.connection() as conn:
await conn.execute(
"""
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
VALUES (%(tid)s, %(now)s, FALSE)
ON CONFLICT (thread_id) DO UPDATE
SET last_activity = %(now)s
""",
{"tid": thread_id, "now": now},
)
return SessionState(
thread_id=thread_id,
last_activity=now.timestamp(),
has_pending_interrupt=False,
)
def is_expired(self, thread_id: str) -> bool:
state = self.get_state(thread_id)
if state is None:
return True
if state.has_pending_interrupt:
return False
elapsed = time.time() - state.last_activity
return elapsed > self._session_ttl
def extend_for_interrupt(self, thread_id: str) -> SessionState:
return asyncio.get_event_loop().run_until_complete(
self._set_interrupt(thread_id, True)
)
def resolve_interrupt(self, thread_id: str) -> SessionState:
return asyncio.get_event_loop().run_until_complete(
self._set_interrupt(thread_id, False)
)
async def _set_interrupt(
self, thread_id: str, has_interrupt: bool
) -> SessionState:
now = datetime.now(timezone.utc)
async with self._pool.connection() as conn:
await conn.execute(
"""
INSERT INTO sessions (thread_id, last_activity, has_pending_interrupt)
VALUES (%(tid)s, %(now)s, %(interrupt)s)
ON CONFLICT (thread_id) DO UPDATE
SET last_activity = %(now)s,
has_pending_interrupt = %(interrupt)s
""",
{"tid": thread_id, "now": now, "interrupt": has_interrupt},
)
return SessionState(
thread_id=thread_id,
last_activity=now.timestamp(),
has_pending_interrupt=has_interrupt,
)
def get_state(self, thread_id: str) -> SessionState | None:
return asyncio.get_event_loop().run_until_complete(
self._get_state(thread_id)
)
async def _get_state(self, thread_id: str) -> SessionState | None:
async with self._pool.connection() as conn:
cursor = await conn.execute(
"SELECT last_activity, has_pending_interrupt FROM sessions WHERE thread_id = %(tid)s",
{"tid": thread_id},
)
row = await cursor.fetchone()
if row is None:
return None
return SessionState(
thread_id=thread_id,
last_activity=row["last_activity"].timestamp(),
has_pending_interrupt=row["has_pending_interrupt"],
)
def remove(self, thread_id: str) -> None:
asyncio.get_event_loop().run_until_complete(self._remove(thread_id))
async def _remove(self, thread_id: str) -> None:
async with self._pool.connection() as conn:
await conn.execute(
"DELETE FROM sessions WHERE thread_id = %(tid)s",
{"tid": thread_id},
)

View File

@@ -0,0 +1,3 @@
"""Tools package for smart-support backend."""
from __future__ import annotations

View File

@@ -0,0 +1,72 @@
"""Error classification and retry logic for tool calls."""
from __future__ import annotations
import asyncio
from enum import Enum
from typing import TYPE_CHECKING, Any
import httpx

if TYPE_CHECKING:
    from collections.abc import Callable
class ErrorCategory(Enum):
"""Categories for error classification to guide retry decisions."""
RETRYABLE = "retryable"
NON_RETRYABLE = "non_retryable"
AUTH_FAILURE = "auth_failure"
TIMEOUT = "timeout"
NETWORK = "network"
def classify_error(exc: Exception) -> ErrorCategory:
"""Classify an exception into an ErrorCategory.
Rules:
- httpx.TimeoutException -> TIMEOUT
- httpx.ConnectError -> NETWORK
- httpx.HTTPStatusError 401/403 -> AUTH_FAILURE
- httpx.HTTPStatusError 429/500/502/503 -> RETRYABLE
- anything else -> NON_RETRYABLE
"""
if isinstance(exc, httpx.TimeoutException):
return ErrorCategory.TIMEOUT
if isinstance(exc, httpx.ConnectError):
return ErrorCategory.NETWORK
if isinstance(exc, httpx.HTTPStatusError):
code = exc.response.status_code
if code in (401, 403):
return ErrorCategory.AUTH_FAILURE
if code in (429, 500, 502, 503):
return ErrorCategory.RETRYABLE
return ErrorCategory.NON_RETRYABLE
return ErrorCategory.NON_RETRYABLE
async def with_retry(
fn: Callable[..., Any],
max_retries: int = 3,
base_delay: float = 1.0,
) -> Any:
"""Execute an async callable with exponential backoff for RETRYABLE errors.
Only ErrorCategory.RETRYABLE errors trigger retries. All other error
categories raise immediately after the first attempt.
"""
last_exc: Exception | None = None
for attempt in range(1, max_retries + 1):
try:
return await fn()
except Exception as exc:
category = classify_error(exc)
if category != ErrorCategory.RETRYABLE:
raise
last_exc = exc
if attempt < max_retries:
delay = base_delay * (2 ** (attempt - 1))
await asyncio.sleep(delay)
raise last_exc # type: ignore[misc]
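# Usage sketch (hypothetical endpoint): with_retry expects a zero-argument
# async callable, so bind parameters with a closure or functools.partial.
async def _fetch_order() -> dict:
    async with httpx.AsyncClient() as client:
        resp = await client.get("https://api.example-shop.com/v1/orders/1042")
        resp.raise_for_status()  # HTTPStatusError feeds classify_error above
        return resp.json()

# order = await with_retry(_fetch_order, max_retries=3, base_delay=0.5)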

30
backend/app/ws_context.py Normal file
View File

@@ -0,0 +1,30 @@
"""WebSocketContext -- bundles all dependencies needed by dispatch_message."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.analytics.event_recorder import AnalyticsRecorder
from app.callbacks import TokenUsageCallbackHandler
from app.conversation_tracker import ConversationTrackerProtocol
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
@dataclass(frozen=True)
class WebSocketContext:
"""All dependencies required for WebSocket message processing.
Replaces the previous 9-parameter function signature in dispatch_message.
"""
graph_ctx: GraphContext
session_manager: SessionManager
callback_handler: TokenUsageCallbackHandler
interrupt_manager: InterruptManager | None = None
analytics_recorder: AnalyticsRecorder | None = None
conversation_tracker: ConversationTrackerProtocol | None = None
pool: Any = None
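# Wiring sketch at connection time. Attribute names mirror the e2e
# fixtures later in this diff, which stash these objects on app.state;
# treat them as assumptions for the real app factory.
from app.callbacks import TokenUsageCallbackHandler
from app.ws_handler import dispatch_message

async def on_ws_message(ws, app, raw_data: str) -> None:
    ctx = WebSocketContext(
        graph_ctx=app.state.graph_ctx,
        session_manager=app.state.session_manager,
        callback_handler=TokenUsageCallbackHandler(),
        interrupt_manager=app.state.interrupt_manager,
        analytics_recorder=app.state.analytics_recorder,
        conversation_tracker=app.state.conversation_tracker,
        pool=app.state.pool,
    )
    await dispatch_message(ws, ctx, raw_data)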

View File

@@ -3,47 +3,98 @@
from __future__ import annotations
import json
import logging
import re
from typing import TYPE_CHECKING, Any
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from langchain_core.messages import HumanMessage
from langgraph.types import Command
if TYPE_CHECKING:
from fastapi import WebSocket
from langgraph.graph.state import CompiledStateGraph
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
logger = logging.getLogger(__name__)
import structlog
logger = structlog.get_logger()
MAX_MESSAGE_SIZE = 32_768 # 32 KB
MAX_CONTENT_LENGTH = 8_000 # characters
MAX_CONTENT_LENGTH = 10_000 # characters
THREAD_ID_PATTERN = re.compile(r"^[a-zA-Z0-9\-_]{1,128}$")
# Rate limiting: max 10 messages per 10-second window, per thread
_RATE_LIMIT_MAX = 10
_RATE_LIMIT_WINDOW = 10.0
_MAX_TRACKED_THREADS = 10_000
_thread_timestamps: dict[str, list[float]] = defaultdict(list)
def _evict_stale_threads(cutoff: float) -> None:
"""Remove thread entries with no recent timestamps to prevent memory leak."""
stale = [tid for tid, ts in _thread_timestamps.items() if not ts or ts[-1] < cutoff]
for tid in stale:
del _thread_timestamps[tid]
async def handle_user_message(
ws: WebSocket,
graph: CompiledStateGraph,
ctx: GraphContext,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
content: str,
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Process a user message through the graph and stream results back."""
if session_manager.is_expired(thread_id):
existing = session_manager.get_state(thread_id)
if existing is not None and session_manager.is_expired(thread_id):
msg = "Session expired. Please start a new conversation."
await _send_json(ws, {"type": "error", "message": msg})
return
session_manager.touch(thread_id)
classification = await ctx.classify_intent(content)
if classification is not None:
logger.info(
    "intent_classification",
    thread_id=thread_id,
    ambiguous=classification.is_ambiguous,
    intents=[i.agent_name for i in classification.intents],
)
if classification.is_ambiguous and classification.clarification_question:
await _send_json(
ws,
{
"type": "clarification",
"thread_id": thread_id,
"message": classification.clarification_question,
},
)
await _send_json(ws, {"type": "message_complete", "thread_id": thread_id})
return
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
if classification and len(classification.intents) > 1:
agent_names = [i.agent_name for i in classification.intents]
hint = (
f"\n[System: This request involves multiple actions. "
f"Execute in order: {', '.join(agent_names)}]"
)
input_msg = {"messages": [HumanMessage(content=content + hint)]}
else:
input_msg = {"messages": [HumanMessage(content=content)]}
try:
async for chunk in graph.astream(input_msg, config=config, stream_mode="messages"):
async for chunk in ctx.graph.astream(input_msg, config=config, stream_mode="messages"):
msg_chunk, metadata = chunk
node = metadata.get("langgraph_node", "")
@@ -68,10 +119,18 @@ async def handle_user_message(
},
)
state = await graph.aget_state(config)
state = await ctx.graph.aget_state(config)
if _has_interrupt(state):
interrupt_data = _extract_interrupt(state)
session_manager.extend_for_interrupt(thread_id)
if interrupt_manager is not None:
interrupt_manager.register(
thread_id=thread_id,
action=interrupt_data.get("action", "unknown"),
params=interrupt_data.get("params", {}),
)
await _send_json(
ws,
{
@@ -91,20 +150,32 @@ async def handle_user_message(
async def handle_interrupt_response(
ws: WebSocket,
graph: CompiledStateGraph,
ctx: GraphContext,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
thread_id: str,
approved: bool,
interrupt_manager: InterruptManager | None = None,
) -> None:
"""Resume graph execution after interrupt approval/rejection."""
if interrupt_manager is not None:
status = interrupt_manager.check_status(thread_id)
if status is not None and status.is_expired:
retry_prompt = interrupt_manager.generate_retry_prompt(status.record)
interrupt_manager.resolve(thread_id)
session_manager.resolve_interrupt(thread_id)
await _send_json(ws, retry_prompt)
return
interrupt_manager.resolve(thread_id)
session_manager.resolve_interrupt(thread_id)
session_manager.touch(thread_id)
config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
try:
async for chunk in graph.astream(
async for chunk in ctx.graph.astream(
Command(resume=approved),
config=config,
stream_mode="messages",
@@ -132,9 +203,7 @@ async def handle_interrupt_response(
async def dispatch_message(
ws: WebSocket,
graph: CompiledStateGraph,
session_manager: SessionManager,
callback_handler: TokenUsageCallbackHandler,
ctx: WebSocketContext,
raw_data: str,
) -> None:
"""Parse and route an incoming WebSocket message."""
@@ -144,10 +213,14 @@ async def dispatch_message(
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
except (json.JSONDecodeError, ValueError):
await _send_json(ws, {"type": "error", "message": "Invalid JSON"})
return
if not isinstance(data, dict):
await _send_json(ws, {"type": "error", "message": "Invalid JSON: expected object"})
return
msg_type = data.get("type")
thread_id = data.get("thread_id", "")
@@ -161,24 +234,81 @@ async def dispatch_message(
if msg_type == "message":
content = data.get("content", "")
if not content:
if not content or not content.strip():
await _send_json(ws, {"type": "error", "message": "Missing message content"})
return
if len(content) > MAX_CONTENT_LENGTH:
await _send_json(ws, {"type": "error", "message": "Message content too long"})
return
await handle_user_message(ws, graph, session_manager, callback_handler, thread_id, content)
# Rate limiting check (per-thread, with bounded memory)
now = time.time()
cutoff = now - _RATE_LIMIT_WINDOW
if len(_thread_timestamps) > _MAX_TRACKED_THREADS:
_evict_stale_threads(cutoff)
recent = [t for t in _thread_timestamps[thread_id] if t >= cutoff]
if len(recent) >= _RATE_LIMIT_MAX:
await _send_json(ws, {"type": "error", "message": "Rate limit exceeded"})
return
_thread_timestamps[thread_id] = [*recent, now]
await handle_user_message(
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
thread_id, content,
interrupt_manager=ctx.interrupt_manager,
)
await _fire_and_forget_tracking(
thread_id=thread_id,
pool=ctx.pool,
analytics_recorder=ctx.analytics_recorder,
conversation_tracker=ctx.conversation_tracker,
agent_name=None,
tokens=0,
cost=0.0,
)
elif msg_type == "interrupt_response":
approved = data.get("approved", False)
await handle_interrupt_response(
ws, graph, session_manager, callback_handler, thread_id, approved
ws, ctx.graph_ctx, ctx.session_manager, ctx.callback_handler,
thread_id, approved,
interrupt_manager=ctx.interrupt_manager,
)
else:
await _send_json(ws, {"type": "error", "message": "Unknown message type"})
async def _fire_and_forget_tracking(
thread_id: str,
pool: object,
analytics_recorder: object | None,
conversation_tracker: object | None,
agent_name: str | None,
tokens: int,
cost: float,
) -> None:
"""Fire-and-forget analytics/tracking; failures must NOT break chat."""
try:
if conversation_tracker is not None and pool is not None:
await conversation_tracker.ensure_conversation(pool, thread_id)
await conversation_tracker.record_turn(pool, thread_id, agent_name, tokens, cost)
except Exception:
logger.exception("Conversation tracker error for thread %s (suppressed)", thread_id)
try:
if analytics_recorder is not None:
await analytics_recorder.record(
thread_id=thread_id,
event_type="message",
agent_name=agent_name,
tokens_used=tokens,
cost_usd=cost,
)
except Exception:
logger.exception("Analytics recorder error for thread %s (suppressed)", thread_id)
def _has_interrupt(state: Any) -> bool:
"""Check if the graph state has a pending interrupt."""
tasks = getattr(state, "tasks", ())

View File

@@ -0,0 +1,153 @@
"""Seed script -- inserts sample conversations and analytics events for demo purposes.
Usage:
cd backend
python fixtures/demo_data.py
"""
from __future__ import annotations
import asyncio
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import psycopg
DATABASE_URL = os.environ.get(
"DATABASE_URL",
"postgresql://smart_support:dev_password@localhost:5432/smart_support",
)
SAMPLE_CONVERSATIONS = [
{
"thread_id": "demo-thread-001",
"agents_used": ["order_agent"],
"turn_count": 3,
"total_tokens": 1250,
"total_cost_usd": 0.00375,
"resolution_type": "resolved",
"minutes_ago": 5,
},
{
"thread_id": "demo-thread-002",
"agents_used": ["order_agent", "refund_agent"],
"turn_count": 6,
"total_tokens": 3200,
"total_cost_usd": 0.0096,
"resolution_type": "resolved",
"minutes_ago": 30,
},
{
"thread_id": "demo-thread-003",
"agents_used": ["general_agent"],
"turn_count": 2,
"total_tokens": 800,
"total_cost_usd": 0.0024,
"resolution_type": None,
"minutes_ago": 60,
},
{
"thread_id": "demo-thread-004",
"agents_used": ["order_agent", "general_agent"],
"turn_count": 8,
"total_tokens": 4500,
"total_cost_usd": 0.0135,
"resolution_type": "escalated",
"minutes_ago": 120,
},
{
"thread_id": "demo-thread-005",
"agents_used": ["refund_agent"],
"turn_count": 4,
"total_tokens": 2100,
"total_cost_usd": 0.0063,
"resolution_type": "resolved",
"minutes_ago": 240,
},
]
SAMPLE_EVENTS = [
{"thread_id": "demo-thread-001", "event_type": "message", "agent_name": "order_agent", "tokens_used": 400, "cost_usd": 0.0012, "success": True},
{"thread_id": "demo-thread-001", "event_type": "tool_call", "agent_name": "order_agent", "tool_name": "get_order_status", "tokens_used": 0, "cost_usd": 0.0, "success": True},
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "order_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
{"thread_id": "demo-thread-002", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 1600, "cost_usd": 0.0048, "success": True},
{"thread_id": "demo-thread-002", "event_type": "tool_call", "agent_name": "refund_agent", "tool_name": "process_refund", "tokens_used": 0, "cost_usd": 0.0, "success": True},
{"thread_id": "demo-thread-003", "event_type": "message", "agent_name": "general_agent", "tokens_used": 800, "cost_usd": 0.0024, "success": True},
{"thread_id": "demo-thread-004", "event_type": "message", "agent_name": "order_agent", "tokens_used": 2000, "cost_usd": 0.006, "success": True},
{"thread_id": "demo-thread-004", "event_type": "escalation", "agent_name": "general_agent", "tokens_used": 2500, "cost_usd": 0.0075, "success": False},
{"thread_id": "demo-thread-005", "event_type": "message", "agent_name": "refund_agent", "tokens_used": 2100, "cost_usd": 0.0063, "success": True},
]
_INSERT_CONVERSATION = """
INSERT INTO conversations
(thread_id, started_at, last_activity, turn_count, agents_used,
total_tokens, total_cost_usd, resolution_type, ended_at)
VALUES
(%(thread_id)s, %(started_at)s, %(last_activity)s, %(turn_count)s,
%(agents_used)s, %(total_tokens)s, %(total_cost_usd)s,
%(resolution_type)s, %(ended_at)s)
ON CONFLICT (thread_id) DO NOTHING
"""
_INSERT_EVENT = """
INSERT INTO analytics_events
(thread_id, event_type, agent_name, tool_name, tokens_used, cost_usd, success)
VALUES
(%(thread_id)s, %(event_type)s, %(agent_name)s, %(tool_name)s,
%(tokens_used)s, %(cost_usd)s, %(success)s)
"""
async def seed() -> None:
now = datetime.now(tz=timezone.utc)
async with await psycopg.AsyncConnection.connect(DATABASE_URL) as conn:
print("Seeding conversations...")
for conv in SAMPLE_CONVERSATIONS:
started_at = now - timedelta(minutes=conv["minutes_ago"])
last_activity = started_at + timedelta(minutes=conv["turn_count"] * 2)
ended_at = last_activity if conv["resolution_type"] else None
await conn.execute(
_INSERT_CONVERSATION,
{
"thread_id": conv["thread_id"],
"started_at": started_at,
"last_activity": last_activity,
"turn_count": conv["turn_count"],
"agents_used": conv["agents_used"],
"total_tokens": conv["total_tokens"],
"total_cost_usd": conv["total_cost_usd"],
"resolution_type": conv["resolution_type"],
"ended_at": ended_at,
},
)
print(f" Inserted conversation {conv['thread_id']}")
print("Seeding analytics events...")
for event in SAMPLE_EVENTS:
await conn.execute(
_INSERT_EVENT,
{
"thread_id": event["thread_id"],
"event_type": event["event_type"],
"agent_name": event.get("agent_name"),
"tool_name": event.get("tool_name"),
"tokens_used": event.get("tokens_used", 0),
"cost_usd": event.get("cost_usd", 0.0),
"success": event.get("success"),
},
)
print(f" Inserted event {event['event_type']} for {event['thread_id']}")
await conn.commit()
print("Done. Demo data seeded successfully.")
if __name__ == "__main__":
asyncio.run(seed())

View File

@@ -0,0 +1,238 @@
openapi: "3.0.3"
info:
title: "E-Commerce API"
description: "Sample e-commerce API for Smart Support demo."
version: "1.0.0"
servers:
- url: "https://api.example-shop.com/v1"
description: "Production server"
paths:
/orders/{order_id}:
get:
operationId: getOrder
summary: "Get order details"
description: "Retrieve the full details of a specific order."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Order details"
content:
application/json:
schema:
$ref: "#/components/schemas/Order"
/orders/{order_id}/cancel:
post:
operationId: cancelOrder
summary: "Cancel an order"
description: "Cancel an order that has not yet been shipped."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
reason:
type: string
responses:
"200":
description: "Order cancelled"
"400":
description: "Order cannot be cancelled (already shipped)"
/orders/{order_id}/refund:
post:
operationId: refundOrder
summary: "Request a refund"
description: "Submit a refund request for a completed order."
parameters:
- name: order_id
in: path
required: true
schema:
type: string
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
amount:
type: number
description: "Refund amount in USD. Leave null for full refund."
reason:
type: string
responses:
"200":
description: "Refund submitted"
"400":
description: "Invalid refund request"
/customers/{customer_id}:
get:
operationId: getCustomer
summary: "Get customer profile"
description: "Retrieve customer profile and account information."
parameters:
- name: customer_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Customer profile"
content:
application/json:
schema:
$ref: "#/components/schemas/Customer"
/customers/{customer_id}/orders:
get:
operationId: listCustomerOrders
summary: "List customer orders"
description: "Get a paginated list of orders for a customer."
parameters:
- name: customer_id
in: path
required: true
schema:
type: string
- name: page
in: query
schema:
type: integer
default: 1
- name: per_page
in: query
schema:
type: integer
default: 20
responses:
"200":
description: "List of orders"
/products/{product_id}:
get:
operationId: getProduct
summary: "Get product details"
description: "Retrieve product information including inventory status."
parameters:
- name: product_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Product details"
/support/tickets:
post:
operationId: createSupportTicket
summary: "Create support ticket"
description: "Open a new support ticket for a customer issue."
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CreateTicketRequest"
responses:
"201":
description: "Ticket created"
/support/tickets/{ticket_id}:
get:
operationId: getSupportTicket
summary: "Get support ticket"
description: "Retrieve a support ticket and its conversation history."
parameters:
- name: ticket_id
in: path
required: true
schema:
type: string
responses:
"200":
description: "Ticket details"
components:
schemas:
Order:
type: object
properties:
order_id:
type: string
customer_id:
type: string
status:
type: string
enum: [pending, processing, shipped, delivered, cancelled, refunded]
items:
type: array
items:
$ref: "#/components/schemas/OrderItem"
total_usd:
type: number
created_at:
type: string
format: date-time
OrderItem:
type: object
properties:
product_id:
type: string
name:
type: string
quantity:
type: integer
unit_price_usd:
type: number
Customer:
type: object
properties:
customer_id:
type: string
email:
type: string
name:
type: string
tier:
type: string
enum: [standard, premium, vip]
created_at:
type: string
format: date-time
CreateTicketRequest:
type: object
required: [customer_id, subject, description]
properties:
customer_id:
type: string
subject:
type: string
description:
type: string
priority:
type: string
enum: [low, medium, high, urgent]
default: medium

View File

@@ -6,18 +6,23 @@ requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115,<1.0",
"uvicorn[standard]>=0.34,<1.0",
"langgraph>=0.4,<1.0",
"langgraph-supervisor>=0.0.12,<1.0",
"langgraph>=1.0,<2.0",
"langgraph-supervisor>=0.0.30,<1.0",
"langgraph-checkpoint-postgres>=3.0,<4.0",
"langchain-core>=0.3,<1.0",
"langchain-anthropic>=0.3,<2.0",
"langchain-openai>=0.3,<1.0",
"langchain>=1.0,<2.0",
"langchain-core>=1.0,<2.0",
"langchain-anthropic>=1.0,<2.0",
"langchain-openai>=1.0,<2.0",
"langchain-google-genai>=2.1,<3.0",
"psycopg[binary,pool]>=3.2,<4.0",
"pydantic>=2.10,<3.0",
"pydantic-settings>=2.7,<3.0",
"pyyaml>=6.0,<7.0",
"python-dotenv>=1.0,<2.0",
"httpx>=0.28,<1.0",
"openapi-spec-validator>=0.7,<1.0",
"alembic>=1.13,<2.0",
"structlog>=24.0,<26.0",
]
[project.optional-dependencies]
@@ -27,6 +32,7 @@ dev = [
"pytest-cov>=6.0,<7.0",
"httpx>=0.28,<1.0",
"ruff>=0.9,<1.0",
"pytest-httpx>=0.35,<1.0",
]
[build-system]

View File

@@ -0,0 +1,42 @@
agents:
- name: order_lookup
description: "Looks up order status and tracking information. Use for queries about order status, shipping, and delivery."
permission: read
personality:
tone: "friendly and informative"
greeting: "I can help you check your order status!"
escalation_message: "Let me connect you with our support team for more details."
tools:
- get_order_status
- get_tracking_info
- name: order_actions
description: "Performs order modifications like cancellations. Use when the customer wants to cancel, modify, or change an order."
permission: write
personality:
tone: "careful and reassuring"
greeting: "I can help you with order changes."
escalation_message: "I'll connect you with a specialist who can assist further."
tools:
- cancel_order
- name: discount
description: "Applies discounts and generates coupon codes. Use when the customer asks for discounts, promotions, or coupon codes."
permission: write
personality:
tone: "generous and accommodating"
greeting: "I can help you with discounts and coupons!"
escalation_message: "Let me connect you with our promotions team."
tools:
- apply_discount
- generate_coupon
- name: fallback
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
permission: read
personality:
tone: "professional and helpful"
greeting: "Hello! How can I help you today?"
escalation_message: "Let me connect you with a human agent who can better assist you."
tools:
- fallback_respond

View File

@@ -0,0 +1,31 @@
agents:
- name: transaction_lookup
description: "Looks up transaction history, balances, and payment details. Use for queries about transactions and account activity."
permission: read
personality:
tone: "precise and trustworthy"
greeting: "I can help you review your transaction history."
escalation_message: "Let me connect you with our financial support team."
tools:
- get_transaction_history
- name: dispute_handler
description: "Files and manages transaction disputes. Use when the customer wants to dispute a charge or check dispute status."
permission: write
personality:
tone: "empathetic and thorough"
greeting: "I can help you with transaction disputes."
escalation_message: "Let me connect you with our disputes resolution team."
tools:
- file_dispute
- check_dispute_status
- name: fallback
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
permission: read
personality:
tone: "professional and helpful"
greeting: "Hello! How can I help you today?"
escalation_message: "Let me connect you with a human agent who can better assist you."
tools:
- fallback_respond

View File

@@ -0,0 +1,31 @@
agents:
- name: account_lookup
description: "Looks up account status, subscription details, and billing history. Use for queries about account information."
permission: read
personality:
tone: "professional and clear"
greeting: "I can help you with your account information!"
escalation_message: "Let me connect you with our account support team."
tools:
- get_account_status
- name: subscription_management
description: "Manages subscription changes like plan upgrades, downgrades, and cancellations. Use when the customer wants to change their subscription."
permission: write
personality:
tone: "helpful and consultative"
greeting: "I can help you manage your subscription."
escalation_message: "Let me connect you with our billing specialist."
tools:
- change_plan
- cancel_subscription
- name: fallback
description: "Handles general questions, unclear requests, and conversations that don't match other agents."
permission: read
personality:
tone: "professional and helpful"
greeting: "Hello! How can I help you today?"
escalation_message: "Let me connect you with a human agent who can better assist you."
tools:
- fallback_respond

View File

@@ -15,6 +15,16 @@ if TYPE_CHECKING:
from pathlib import Path
from collections.abc import Iterator

@pytest.fixture(autouse=True)
def clear_rate_limit_state() -> Iterator[None]:
"""Clear module-level rate limit state between tests to prevent leakage."""
import app.ws_handler as ws_handler
ws_handler._thread_timestamps.clear()
yield
ws_handler._thread_timestamps.clear()
@pytest.fixture
def test_settings() -> Settings:
return Settings(

View File

@@ -0,0 +1,230 @@
"""E2E test fixtures -- full FastAPI app with mocked LLM and database."""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from httpx import ASGITransport, AsyncClient
from app.analytics.api import router as analytics_router
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.openapi.review_api import _job_store, router as openapi_router
from app.replay.api import router as replay_router
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Graph helpers -- simulate LangGraph streaming behaviour
# ---------------------------------------------------------------------------
class AsyncIterHelper:
"""Make a list behave as an async iterator."""
def __init__(self, items: list) -> None:
self._items = list(items)
def __aiter__(self):
return self
async def __anext__(self):
if not self._items:
raise StopAsyncIteration
return self._items.pop(0)
def make_chunk(content: str, node: str = "order_lookup") -> tuple:
c = MagicMock()
c.content = content
c.tool_calls = []
return (c, {"langgraph_node": node})
def make_tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
c = MagicMock()
c.content = ""
c.tool_calls = [{"name": name, "args": args}]
return (c, {"langgraph_node": node})
def make_state(*, interrupt: bool = False, data: dict | None = None) -> Any:
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def make_graph(
chunks: list | None = None,
state: Any = None,
resume_chunks: list | None = None,
) -> MagicMock:
"""Build a mock LangGraph CompiledStateGraph."""
g = MagicMock()
if state is None:
state = make_state()
streams = [chunks or [], resume_chunks or []]
idx = {"n": 0}
def astream_side_effect(*a, **kw):
i = min(idx["n"], len(streams) - 1)
idx["n"] += 1
return AsyncIterHelper(list(streams[i]))
g.astream = MagicMock(side_effect=astream_side_effect)
g.aget_state = AsyncMock(return_value=state)
return g
def make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
"""Build a GraphContext wrapping a mock graph."""
g = graph or make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
# ---------------------------------------------------------------------------
# Fake database pool
# ---------------------------------------------------------------------------
class FakeCursor:
"""Minimal async cursor returning pre-configured rows."""
def __init__(self, rows: list[dict]) -> None:
self._rows = rows
async def fetchall(self) -> list[dict]:
return self._rows
async def fetchone(self) -> tuple | dict | None:
return self._rows[0] if self._rows else None
class FakeConnection:
"""Fake async connection that returns a FakeCursor."""
def __init__(self, rows: list[dict]) -> None:
self._rows = rows
async def execute(self, query: str, params: dict | None = None) -> FakeCursor:
return FakeCursor(self._rows)
class FakePool:
"""Minimal pool that yields a fake connection."""
def __init__(self, rows: list[dict] | None = None) -> None:
self._rows = rows or []
@asynccontextmanager
async def connection(self):
yield FakeConnection(self._rows)
async def close(self) -> None:
pass
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_e2e_app(
graph: MagicMock | None = None,
pool: FakePool | None = None,
session_ttl: int = 3600,
interrupt_ttl: int = 1800,
) -> FastAPI:
"""Create a FastAPI app wired with mocked dependencies for E2E testing."""
g = graph or make_graph()
graph_ctx = make_graph_ctx(g)
p = pool or FakePool()
sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl)
app = FastAPI(title="Smart Support E2E Test")
app.include_router(openapi_router)
app.include_router(replay_router)
app.include_router(analytics_router)
app.state.graph_ctx = graph_ctx
app.state.session_manager = sm
app.state.interrupt_manager = im
app.state.pool = p
app.state.settings = MagicMock(llm_model="test-model")
app.state.analytics_recorder = AsyncMock()
app.state.conversation_tracker = AsyncMock()
@app.get("/api/v1/health")
def health_check() -> dict:
return {"status": "ok", "version": "test"}
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket) -> None:
await ws.accept()
try:
while True:
raw_data = await ws.receive_text()
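# Build a fresh WebSocketContext per incoming message so each dispatch gets explicit, stateless dependencies.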
ws_ctx = WebSocketContext(
graph_ctx=app.state.graph_ctx,
session_manager=app.state.session_manager,
callback_handler=TokenUsageCallbackHandler(model_name="test-model"),
interrupt_manager=app.state.interrupt_manager,
analytics_recorder=app.state.analytics_recorder,
conversation_tracker=app.state.conversation_tracker,
pool=app.state.pool,
)
await dispatch_message(ws, ws_ctx, raw_data)
except WebSocketDisconnect:
pass
return app
@pytest.fixture
def e2e_graph():
"""Default graph fixture -- returns tokens and message_complete."""
return make_graph(
chunks=[make_chunk("Order 1042 is "), make_chunk("shipped.")]
)
@pytest.fixture
def e2e_app(e2e_graph):
"""Default E2E app fixture."""
return create_e2e_app(graph=e2e_graph)
@pytest.fixture
async def e2e_client(e2e_app):
"""Async HTTP client for E2E tests."""
transport = ASGITransport(app=e2e_app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
@pytest.fixture(autouse=True)
def clear_openapi_job_store():
"""Clear the in-memory job store between tests."""
_job_store.clear()
yield
_job_store.clear()

View File

@@ -0,0 +1,384 @@
"""E2E tests for critical chat user flows (flows 1-4).
Flow 1: Happy path -- query order, get answer
Flow 2: Approval flow -- write operation, interrupt, approve, execute
Flow 3: Rejection flow -- write operation, interrupt, reject, no execution
Flow 4: Multi-turn context -- sequential messages in same session
"""
from __future__ import annotations
import pytest
from starlette.testclient import TestClient
from tests.e2e.conftest import (
create_e2e_app,
make_chunk,
make_graph,
make_state,
make_tool_chunk,
)
pytestmark = pytest.mark.e2e
class TestFlow1HappyPath:
"""Flow 1: query order -> get answer with streaming tokens."""
def test_websocket_happy_path_order_query(self) -> None:
graph = make_graph(
chunks=[
make_tool_chunk("get_order_status", {"order_id": "1042"}),
make_chunk("Order 1042 has been shipped and is on its way."),
],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_json({
"type": "message",
"thread_id": "e2e-happy-1",
"content": "What is the status of order 1042?",
})
messages = []
while True:
msg = ws.receive_json()
messages.append(msg)
if msg["type"] in ("message_complete", "error"):
break
tool_calls = [m for m in messages if m["type"] == "tool_call"]
assert len(tool_calls) == 1
assert tool_calls[0]["tool"] == "get_order_status"
assert tool_calls[0]["args"] == {"order_id": "1042"}
tokens = [m for m in messages if m["type"] == "token"]
assert len(tokens) == 1
assert "shipped" in tokens[0]["content"]
completes = [m for m in messages if m["type"] == "message_complete"]
assert len(completes) == 1
assert completes[0]["thread_id"] == "e2e-happy-1"
def test_websocket_multiple_token_stream(self) -> None:
"""Verify streaming returns multiple token chunks."""
graph = make_graph(
chunks=[
make_chunk("Your order "),
make_chunk("1042 "),
make_chunk("was delivered "),
make_chunk("yesterday."),
],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_json({
"type": "message",
"thread_id": "e2e-stream-1",
"content": "Where is my order?",
})
messages = _collect_until_complete(ws)
tokens = [m for m in messages if m["type"] == "token"]
assert len(tokens) == 4
full_text = "".join(t["content"] for t in tokens)
assert "1042" in full_text
assert "delivered" in full_text
class TestFlow2ApprovalFlow:
"""Flow 2: write operation -> interrupt -> approve -> execute."""
def test_interrupt_approve_executes_action(self) -> None:
interrupt_state = make_state(
interrupt=True,
data={"action": "cancel_order", "order_id": "1042"},
)
graph = make_graph(
chunks=[],
state=interrupt_state,
resume_chunks=[
make_chunk("Order 1042 has been cancelled successfully.", "order_actions"),
],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
# Step 1: Send cancel request
ws.send_json({
"type": "message",
"thread_id": "e2e-approve-1",
"content": "Cancel order 1042",
})
messages = _collect_until_type(ws, "interrupt")
interrupts = [m for m in messages if m["type"] == "interrupt"]
assert len(interrupts) == 1
assert interrupts[0]["action"] == "cancel_order"
assert interrupts[0]["thread_id"] == "e2e-approve-1"
# Step 2: Approve the interrupt
ws.send_json({
"type": "interrupt_response",
"thread_id": "e2e-approve-1",
"approved": True,
})
resume_messages = _collect_until_complete(ws)
tokens = [m for m in resume_messages if m["type"] == "token"]
assert len(tokens) == 1
assert "cancelled" in tokens[0]["content"]
assert tokens[0]["agent"] == "order_actions"
completes = [m for m in resume_messages if m["type"] == "message_complete"]
assert len(completes) == 1
class TestFlow3RejectionFlow:
"""Flow 3: write operation -> interrupt -> reject -> no execution."""
def test_interrupt_reject_does_not_execute(self) -> None:
interrupt_state = make_state(
interrupt=True,
data={"action": "cancel_order", "order_id": "1042"},
)
graph = make_graph(
chunks=[],
state=interrupt_state,
resume_chunks=[
make_chunk("Understood. Order 1042 will remain active.", "order_actions"),
],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
# Step 1: Trigger interrupt
ws.send_json({
"type": "message",
"thread_id": "e2e-reject-1",
"content": "Cancel order 1042",
})
messages = _collect_until_type(ws, "interrupt")
assert any(m["type"] == "interrupt" for m in messages)
# Step 2: Reject
ws.send_json({
"type": "interrupt_response",
"thread_id": "e2e-reject-1",
"approved": False,
})
resume_messages = _collect_until_complete(ws)
tokens = [m for m in resume_messages if m["type"] == "token"]
assert len(tokens) == 1
assert "remain active" in tokens[0]["content"]
# Verify graph.astream was called with resume=False
resume_call = graph.astream.call_args_list[-1]
command = resume_call[0][0]
assert command.resume is False
class TestFlow4MultiTurnContext:
"""Flow 4: multi-turn conversation in the same session."""
def test_multi_turn_messages_share_session(self) -> None:
"""Multiple messages in the same thread_id maintain session context."""
graph = make_graph(
chunks=[make_chunk("Order 1042 status: shipped.")],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
# Turn 1: Query order
ws.send_json({
"type": "message",
"thread_id": "e2e-multi-1",
"content": "What is the status of order 1042?",
})
turn1 = _collect_until_complete(ws)
assert any(m["type"] == "message_complete" for m in turn1)
# Turn 2: Follow-up in same thread
ws.send_json({
"type": "message",
"thread_id": "e2e-multi-1",
"content": "When will it arrive?",
})
turn2 = _collect_until_complete(ws)
assert any(m["type"] == "message_complete" for m in turn2)
# Turn 3: Another follow-up
ws.send_json({
"type": "message",
"thread_id": "e2e-multi-1",
"content": "Can you track it?",
})
turn3 = _collect_until_complete(ws)
assert any(m["type"] == "message_complete" for m in turn3)
# Verify all turns used the same thread_id in graph calls
for call in graph.astream.call_args_list:
args, kwargs = call
config = kwargs.get("config", args[1] if len(args) > 1 else {})
assert config["configurable"]["thread_id"] == "e2e-multi-1"
def test_separate_threads_are_independent(self) -> None:
"""Different thread_ids have independent sessions."""
graph = make_graph(
chunks=[make_chunk("Response.")],
)
app = create_e2e_app(graph=graph)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
# Thread A
ws.send_json({
"type": "message",
"thread_id": "e2e-thread-a",
"content": "Hello from thread A",
})
_collect_until_complete(ws)
# Thread B
ws.send_json({
"type": "message",
"thread_id": "e2e-thread-b",
"content": "Hello from thread B",
})
_collect_until_complete(ws)
# Both threads should exist as separate sessions
sm = app.state.session_manager
assert sm.get_state("e2e-thread-a") is not None
assert sm.get_state("e2e-thread-b") is not None
class TestChatEdgeCases:
"""Edge cases and error handling for the chat WebSocket."""
def test_invalid_json_returns_error(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_text("not valid json")
msg = ws.receive_json()
assert msg["type"] == "error"
assert "Invalid JSON" in msg["message"]
def test_missing_thread_id_returns_error(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_json({"type": "message", "content": "hello"})
msg = ws.receive_json()
assert msg["type"] == "error"
assert "thread_id" in msg["message"]
def test_empty_content_returns_error(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_json({
"type": "message",
"thread_id": "e2e-err-1",
"content": "",
})
msg = ws.receive_json()
assert msg["type"] == "error"
def test_expired_session_returns_error(self) -> None:
graph = make_graph(chunks=[make_chunk("Response.")])
app = create_e2e_app(graph=graph, session_ttl=0)
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
# First message creates the session (TTL=0)
ws.send_json({
"type": "message",
"thread_id": "e2e-expired-1",
"content": "hello",
})
_collect_until_complete_or_error(ws)
# Second message finds the session expired (TTL=0)
ws.send_json({
"type": "message",
"thread_id": "e2e-expired-1",
"content": "hello again",
})
messages = _collect_until_complete_or_error(ws)
errors = [m for m in messages if m["type"] == "error"]
assert len(errors) >= 1
assert "expired" in errors[0]["message"].lower()
def test_oversized_message_returns_error(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
with client.websocket_connect("/ws") as ws:
ws.send_text("x" * 40_000)
msg = ws.receive_json()
assert msg["type"] == "error"
assert "too large" in msg["message"].lower()
def test_health_endpoint(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/v1/health")
assert resp.status_code == 200
assert resp.json()["status"] == "ok"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_until_complete(ws, *, max_messages: int = 50) -> list[dict]:
"""Receive WebSocket messages until message_complete or error."""
messages = []
for _ in range(max_messages):
msg = ws.receive_json()
messages.append(msg)
if msg["type"] in ("message_complete", "error"):
break
return messages
def _collect_until_type(ws, msg_type: str, *, max_messages: int = 50) -> list[dict]:
"""Receive until a specific message type is received."""
messages = []
for _ in range(max_messages):
msg = ws.receive_json()
messages.append(msg)
if msg["type"] == msg_type:
break
return messages
# Same stop condition as _collect_until_complete; kept as an alias under the
# more explicit name used by the expired-session test.
_collect_until_complete_or_error = _collect_until_complete

View File
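Taken together, flows 1-4 pin down the WebSocket protocol surface. A compact summary of the message shapes asserted above (informational, not an exhaustive spec):

# Client -> server shapes exercised by the flow tests above.
CLIENT_MESSAGES = [
    {"type": "message", "thread_id": "t-1", "content": "Cancel order 1042"},
    {"type": "interrupt_response", "thread_id": "t-1", "approved": True},
]

# Server -> client message types asserted in this file.
SERVER_MESSAGE_TYPES = ["token", "tool_call", "interrupt", "message_complete", "error"]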

@@ -0,0 +1,201 @@
"""E2E tests for OpenAPI import flow (flow 5).
Flow 5: paste OpenAPI spec URL -> import job -> classify endpoints ->
review classifications -> approve -> tool generation.
"""
from __future__ import annotations
import pytest
from starlette.testclient import TestClient
from app.openapi.models import ClassificationResult, EndpointInfo
from app.openapi.review_api import _job_store
from tests.e2e.conftest import create_e2e_app
pytestmark = pytest.mark.e2e
def _fake_endpoint(
path: str = "/orders/{id}",
method: str = "GET",
operation_id: str = "getOrder",
summary: str = "Get order details",
) -> EndpointInfo:
return EndpointInfo(
path=path,
method=method,
operation_id=operation_id,
summary=summary,
description="",
parameters=(),
request_body_schema=None,
response_schema=None,
)
def _fake_classification(
endpoint: EndpointInfo | None = None,
access_type: str = "read",
needs_interrupt: bool = False,
agent_group: str = "order_lookup",
) -> ClassificationResult:
return ClassificationResult(
endpoint=endpoint or _fake_endpoint(),
access_type=access_type,
customer_params=["order_id"],
agent_group=agent_group,
confidence=0.95,
needs_interrupt=needs_interrupt,
)
class TestFlow5OpenAPIImport:
"""Flow 5: full OpenAPI import lifecycle."""
def test_import_job_lifecycle(self) -> None:
"""Start import -> check status -> review classifications -> approve."""
app = create_e2e_app()
with TestClient(app) as client:
# Step 1: Start import job
resp = client.post(
"/api/v1/openapi/import",
json={"url": "https://api.example.com/openapi.json"},
)
assert resp.status_code == 202
body = resp.json()
assert body["status"] == "pending"
job_id = body["job_id"]
# Step 2: Check job status (still pending since background task hasn't run)
resp = client.get(f"/api/v1/openapi/jobs/{job_id}")
assert resp.status_code == 200
assert resp.json()["job_id"] == job_id
def test_import_job_with_classifications(self) -> None:
"""Simulate completed import and review classified endpoints."""
app = create_e2e_app()
# Seed a completed job directly
ep_read = _fake_endpoint("/orders/{id}", "GET", "getOrder", "Get order")
ep_write = _fake_endpoint("/orders/{id}/cancel", "POST", "cancelOrder", "Cancel order")
clf_read = _fake_classification(ep_read, "read", False, "order_lookup")
clf_write = _fake_classification(ep_write, "write", True, "order_actions")
job_id = "test-job-001"
_job_store[job_id] = {
"job_id": job_id,
"status": "done",
"spec_url": "https://api.example.com/openapi.json",
"total_endpoints": 2,
"classified_count": 2,
"error_message": None,
"classifications": [clf_read, clf_write],
}
with TestClient(app) as client:
# Step 1: Get classifications
resp = client.get(f"/api/v1/openapi/jobs/{job_id}/classifications")
assert resp.status_code == 200
classifications = resp.json()
assert len(classifications) == 2
# Verify read endpoint
read_clf = classifications[0]
assert read_clf["access_type"] == "read"
assert read_clf["needs_interrupt"] is False
assert read_clf["endpoint"]["path"] == "/orders/{id}"
# Verify write endpoint
write_clf = classifications[1]
assert write_clf["access_type"] == "write"
assert write_clf["needs_interrupt"] is True
assert write_clf["endpoint"]["path"] == "/orders/{id}/cancel"
# Step 2: Update a classification
resp = client.put(
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
json={
"access_type": "write",
"needs_interrupt": True,
"agent_group": "order_actions",
},
)
assert resp.status_code == 200
updated = resp.json()
assert updated["access_type"] == "write"
assert updated["needs_interrupt"] is True
assert updated["agent_group"] == "order_actions"
# Step 3: Approve the job
resp = client.post(f"/api/v1/openapi/jobs/{job_id}/approve")
assert resp.status_code == 200
assert resp.json()["status"] == "approved"
def test_import_nonexistent_job_returns_404(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/v1/openapi/jobs/nonexistent")
assert resp.status_code == 404
def test_import_invalid_url_returns_422(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.post("/api/v1/openapi/import", json={"url": "not-a-url"})
assert resp.status_code == 422
def test_classification_index_out_of_range(self) -> None:
app = create_e2e_app()
job_id = "test-job-range"
_job_store[job_id] = {
"job_id": job_id,
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 1,
"classified_count": 1,
"error_message": None,
"classifications": [_fake_classification()],
}
with TestClient(app) as client:
resp = client.put(
f"/api/v1/openapi/jobs/{job_id}/classifications/99",
json={
"access_type": "read",
"needs_interrupt": False,
"agent_group": "order_lookup",
},
)
assert resp.status_code == 404
def test_update_classification_invalid_agent_group(self) -> None:
app = create_e2e_app()
job_id = "test-job-invalid"
_job_store[job_id] = {
"job_id": job_id,
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 1,
"classified_count": 1,
"error_message": None,
"classifications": [_fake_classification()],
}
with TestClient(app) as client:
resp = client.put(
f"/api/v1/openapi/jobs/{job_id}/classifications/0",
json={
"access_type": "read",
"needs_interrupt": False,
"agent_group": "invalid group!", # spaces and special chars
},
)
assert resp.status_code == 422

View File

@@ -0,0 +1,230 @@
"""E2E tests for replay and analytics flows (flow 6).
Flow 6: list conversations -> select one -> step-by-step replay.
Also tests the analytics dashboard endpoint.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from starlette.testclient import TestClient
from tests.e2e.conftest import FakePool, create_e2e_app
pytestmark = pytest.mark.e2e
# ---------------------------------------------------------------------------
# Custom pool that returns specific data per query
# ---------------------------------------------------------------------------
class ReplayPool(FakePool):
"""Pool that returns different data depending on the SQL query."""
def __init__(
self,
conversations: list[dict] | None = None,
checkpoints: list[dict] | None = None,
analytics_rows: list[dict] | None = None,
) -> None:
super().__init__()
self._conversations = conversations or []
self._checkpoints = checkpoints or []
self._analytics = analytics_rows or []
class _Conn:
def __init__(self, convos, checkpoints, analytics):
self._convos = convos
self._checkpoints = checkpoints
self._analytics = analytics
async def execute(self, query: str, params=None):
from tests.e2e.conftest import FakeCursor
if "COUNT" in query and "conversations" in query:
return FakeCursor([(len(self._convos),)])
if "conversations" in query and "SELECT" in query:
# Respect LIMIT/OFFSET from params if provided
rows = self._convos
if params:
offset = params.get("offset", 0)
limit = params.get("limit", len(rows))
rows = rows[offset : offset + limit]
return FakeCursor(rows)
if "checkpoints" in query:
return FakeCursor(self._checkpoints)
# Analytics queries
return FakeCursor(self._analytics)
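# Override the @asynccontextmanager-based FakePool.connection with a query-aware variant.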
def connection(self):
from contextlib import asynccontextmanager
conn = self._Conn(self._conversations, self._checkpoints, self._analytics)
@asynccontextmanager
async def _ctx():
yield conn
return _ctx()
class TestFlow6ReplayConversation:
"""Flow 6: list conversations -> select one -> step replay."""
def test_list_conversations(self) -> None:
now = datetime.now(tz=timezone.utc).isoformat()
conversations = [
{
"thread_id": "conv-001",
"created_at": now,
"last_activity": now,
"status": "active",
"total_tokens": 150,
"total_cost_usd": 0.003,
},
{
"thread_id": "conv-002",
"created_at": now,
"last_activity": now,
"status": "completed",
"total_tokens": 300,
"total_cost_usd": 0.006,
},
]
pool = ReplayPool(conversations=conversations)
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
data = body["data"]
assert len(data["conversations"]) == 2
assert data["conversations"][0]["thread_id"] == "conv-001"
assert data["conversations"][1]["thread_id"] == "conv-002"
assert data["total"] == 2
def test_list_conversations_pagination(self) -> None:
conversations = [
{
"thread_id": f"conv-{i:03d}",
"created_at": "2026-04-01T00:00:00Z",
"last_activity": "2026-04-01T00:00:00Z",
"status": "active",
"total_tokens": 100,
"total_cost_usd": 0.001,
}
for i in range(5)
]
pool = ReplayPool(conversations=conversations)
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations", params={"page": 1, "per_page": 2})
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
data = body["data"]
assert data["total"] == 5
assert data["page"] == 1
assert data["per_page"] == 2
assert len(data["conversations"]) == 2
def test_replay_thread_not_found(self) -> None:
pool = ReplayPool(checkpoints=[])
app = create_e2e_app(pool=pool)
with TestClient(app) as client:
resp = client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
def test_replay_invalid_thread_id_format(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
# Thread ID with special chars fails regex validation
resp = client.get("/api/v1/replay/invalid%20thread%21%40")
assert resp.status_code == 400
class TestAnalyticsDashboard:
"""Analytics endpoint tests."""
def test_analytics_invalid_range_format(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/v1/analytics", params={"range": "invalid"})
assert resp.status_code == 400
def test_analytics_range_too_large(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/v1/analytics", params={"range": "999d"})
assert resp.status_code == 400
def test_analytics_range_zero_rejected(self) -> None:
app = create_e2e_app()
with TestClient(app) as client:
resp = client.get("/api/v1/analytics", params={"range": "0d"})
assert resp.status_code == 400
class TestFullUserJourney:
"""End-to-end journey: chat -> then check replay list shows the conversation."""
def test_chat_then_check_conversations_endpoint(self) -> None:
"""After chatting via WebSocket, the conversations endpoint is reachable."""
from tests.e2e.conftest import make_chunk, make_graph
graph = make_graph(chunks=[make_chunk("Your order is shipped.")])
now = datetime.now(tz=timezone.utc).isoformat()
pool = ReplayPool(
conversations=[
{
"thread_id": "e2e-journey-1",
"created_at": now,
"last_activity": now,
"status": "active",
"total_tokens": 50,
"total_cost_usd": 0.001,
},
],
)
app = create_e2e_app(graph=graph, pool=pool)
with TestClient(app) as client:
# Step 1: Chat via WebSocket
with client.websocket_connect("/ws") as ws:
ws.send_json({
"type": "message",
"thread_id": "e2e-journey-1",
"content": "Where is my order?",
})
messages = []
for _ in range(20):
msg = ws.receive_json()
messages.append(msg)
if msg["type"] in ("message_complete", "error"):
break
assert any(m["type"] == "message_complete" for m in messages)
# Step 2: Check conversations endpoint
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert any(
c["thread_id"] == "e2e-journey-1"
for c in body["data"]["conversations"]
)
# Step 3: Health check still works
resp = client.get("/api/v1/health")
assert resp.status_code == 200

View File

@@ -0,0 +1,183 @@
"""Integration tests for the /api/v1/analytics endpoint.
Tests the full API layer (routing, parameter validation, serialization,
error handling) with a mocked database pool.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.analytics.models import AnalyticsResult, InterruptStats
pytestmark = pytest.mark.integration
_SAMPLE_RESULT = AnalyticsResult(
range="7d",
total_conversations=42,
resolution_rate=0.85,
escalation_rate=0.05,
avg_turns_per_conversation=3.2,
avg_cost_per_conversation_usd=0.012,
agent_usage=(),
interrupt_stats=InterruptStats(total=10, approved=7, rejected=2, expired=1),
)
def _build_app():
"""Build a minimal FastAPI app with the analytics router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
test_app = FastAPI()
test_app.include_router(analytics_router)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
# No admin_api_key set -> auth is skipped
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestAnalyticsValidRange:
"""Test analytics endpoint with valid range parameters."""
async def test_valid_range_7d_returns_envelope(self) -> None:
"""GET /api/v1/analytics?range=7d returns success envelope with data."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "7d"})
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["error"] is None
assert body["data"]["total_conversations"] == 42
assert body["data"]["resolution_rate"] == 0.85
async def test_default_range_returns_success(self) -> None:
"""GET /api/v1/analytics with no range param defaults to 7d."""
test_app = _build_app()
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=_SAMPLE_RESULT,
) as mock_get:
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics")
assert resp.status_code == 200
# Verify the default range of 7 days was passed (positionally or by keyword)
mock_get.assert_called_once()
args, kwargs = mock_get.call_args
range_days = kwargs.get("range_days", args[1] if len(args) > 1 else None)
assert range_days == 7
async def test_large_range_365d_works(self) -> None:
"""GET /api/v1/analytics?range=365d is accepted (max boundary)."""
test_app = _build_app()
result = AnalyticsResult(
range="365d",
total_conversations=1000,
resolution_rate=0.9,
escalation_rate=0.02,
avg_turns_per_conversation=4.0,
avg_cost_per_conversation_usd=0.01,
agent_usage=(),
interrupt_stats=InterruptStats(),
)
with patch(
"app.analytics.api.get_analytics",
new_callable=AsyncMock,
return_value=result,
):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "365d"})
assert resp.status_code == 200
assert resp.json()["success"] is True
class TestAnalyticsInvalidRange:
"""Test analytics endpoint with invalid range parameters."""
async def test_invalid_range_format_returns_400(self) -> None:
"""GET /api/v1/analytics?range=abc returns 400 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "abc"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert "Invalid range format" in body["error"]
async def test_zero_day_range_returns_400(self) -> None:
"""GET /api/v1/analytics?range=0d returns 400 because 0 is below minimum."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "0d"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "between 1 and 365" in body["error"]
async def test_range_exceeding_max_returns_400(self) -> None:
"""GET /api/v1/analytics?range=999d returns 400 because it exceeds 365."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "999d"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "between 1 and 365" in body["error"]

View File
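A hypothetical sketch of the range validation these tests imply -- the helper name and wiring are invented here; the real logic lives in app.analytics.api, and only the error strings are taken from the assertions above:

import re

from fastapi import HTTPException

def parse_range_days(raw: str) -> int:
    # Accept "<n>d" (e.g. "7d", "365d"); reject anything else with the 400s asserted above.
    m = re.fullmatch(r"(\d+)d", raw)
    if m is None:
        raise HTTPException(status_code=400, detail=f"Invalid range format: {raw!r}")
    days = int(m.group(1))
    if not 1 <= days <= 365:
        raise HTTPException(status_code=400, detail="range must be between 1 and 365 days")
    return days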

@@ -0,0 +1,128 @@
"""Integration tests for global error handling and envelope format consistency.
Tests that all error responses from the FastAPI app conform to the
standard envelope: {"success": false, "data": null, "error": "..."}.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build the actual FastAPI app with exception handlers but mocked state."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.analytics.api import router as analytics_router
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(analytics_router)
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@test_app.exception_handler(Exception)
async def _catch_all(request, exc):
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
@test_app.get("/api/v1/health")
def health_check():
return {"status": "ok", "version": "0.6.0"}
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = MagicMock()
return test_app
class TestEnvelopeFormat:
"""Tests that error responses consistently follow envelope format."""
async def test_http_400_produces_envelope(self) -> None:
"""A 400 error returns standard envelope with success=false."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "invalid"})
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
async def test_validation_error_produces_422_envelope(self) -> None:
"""Invalid query param type returns 422 with envelope format."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
# page must be >= 1; passing 0 triggers validation error
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
async def test_all_error_fields_present(self) -> None:
"""Error envelope contains exactly success, data, and error keys."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/analytics", params={"range": "bad"})
body = resp.json()
assert set(body.keys()) == {"success", "data", "error"}
async def test_health_endpoint_returns_200(self) -> None:
"""Health check returns 200 with status ok."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/health")
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "ok"
assert "version" in body
async def test_unknown_endpoint_returns_404(self) -> None:
"""Requesting a non-existent path returns 404."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/nonexistent-path")
# FastAPI returns 404 for unknown routes; may or may not be wrapped
assert resp.status_code == 404

View File
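For reference, the envelope helper these handlers call presumably reduces to a sketch like this -- the shape is inferred from the key assertions above, and the real app.api_utils.envelope may differ:

from typing import Any

def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict:
    # Exactly the three keys test_all_error_fields_present asserts on.
    return {"success": success, "data": data, "error": error}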

@@ -0,0 +1,203 @@
"""Integration tests for the OpenAPI import pipeline orchestrator.
Tests the full pipeline: fetch -> validate -> parse -> classify.
Uses mocked HTTP and a deterministic heuristic classifier (no LLM).
RED phase: written before implementation.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.openapi.models import ImportJob
pytestmark = pytest.mark.integration
_VALID_SPEC_JSON = """{
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {
"/orders": {
"get": {
"operationId": "list_orders",
"summary": "List orders",
"description": "Returns all orders",
"responses": {"200": {"description": "OK"}}
}
},
"/orders/{id}": {
"delete": {
"operationId": "delete_order",
"summary": "Delete order",
"description": "Deletes an order",
"parameters": [
{"name": "id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {"204": {"description": "Deleted"}}
}
}
}
}"""
_INVALID_SPEC_JSON = '{"not": "a valid openapi spec"}'
_PUBLIC_IP = "93.184.216.34"
@pytest.fixture
def mock_classifier():
"""A mock classifier that classifies using heuristics."""
from app.openapi.classifier import HeuristicClassifier
return HeuristicClassifier()
@pytest.fixture
def orchestrator(mock_classifier):
"""Create an ImportOrchestrator with the mock classifier."""
from app.openapi.importer import ImportOrchestrator
return ImportOrchestrator(classifier=mock_classifier)
class TestImportOrchestratorSuccess:
"""Tests for successful import pipeline flows."""
@pytest.mark.usefixtures("_mock_public_dns")
async def test_full_pipeline_succeeds(self, orchestrator, httpx_mock) -> None:
"""Full pipeline with valid spec and mocked HTTP succeeds."""
httpx_mock.add_response(
url="https://example.com/api/spec.json",
text=_VALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
job = await orchestrator.start_import(
url="https://example.com/api/spec.json",
job_id="test-job-1",
on_progress=None,
)
assert isinstance(job, ImportJob)
assert job.status == "done"
assert job.job_id == "test-job-1"
assert job.total_endpoints == 2
assert job.classified_count == 2
assert job.error_message is None
@pytest.mark.usefixtures("_mock_public_dns")
async def test_progress_callback_called_at_stages(self, orchestrator, httpx_mock) -> None:
"""on_progress callback is called at each pipeline stage."""
httpx_mock.add_response(
url="https://example.com/api/spec.json",
text=_VALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
stages_seen: list[str] = []
def on_progress(stage: str, job: ImportJob) -> None:
stages_seen.append(stage)
await orchestrator.start_import(
url="https://example.com/api/spec.json",
job_id="test-job-2",
on_progress=on_progress,
)
assert "fetching" in stages_seen
assert "validating" in stages_seen
assert "parsing" in stages_seen
assert "classifying" in stages_seen
@pytest.mark.usefixtures("_mock_public_dns")
async def test_none_progress_callback_does_not_raise(
self, orchestrator, httpx_mock
) -> None:
"""Passing None as on_progress does not raise."""
httpx_mock.add_response(
url="https://example.com/api/spec.json",
text=_VALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
job = await orchestrator.start_import(
url="https://example.com/api/spec.json",
job_id="test-job-3",
on_progress=None,
)
assert job.status == "done"
class TestImportOrchestratorFailures:
"""Tests for error handling in the import pipeline."""
async def test_fetch_failure_sets_failed_status(self, orchestrator) -> None:
"""When HTTP fetch fails, job status is 'failed'."""
with patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]):
job = await orchestrator.start_import(
url="http://internal.corp/spec.json",
job_id="test-job-fail-1",
on_progress=None,
)
assert job.status == "failed"
assert job.error_message is not None
@pytest.mark.usefixtures("_mock_public_dns")
async def test_validation_failure_sets_failed_status(
self, orchestrator, httpx_mock
) -> None:
"""When spec validation fails, job status is 'failed'."""
httpx_mock.add_response(
url="https://example.com/api/bad.json",
text=_INVALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
job = await orchestrator.start_import(
url="https://example.com/api/bad.json",
job_id="test-job-fail-2",
on_progress=None,
)
assert job.status == "failed"
assert job.error_message is not None
@pytest.mark.usefixtures("_mock_public_dns")
async def test_error_message_is_descriptive(self, orchestrator, httpx_mock) -> None:
"""Error message contains useful context."""
httpx_mock.add_response(
url="https://example.com/api/bad.json",
text=_INVALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
job = await orchestrator.start_import(
url="https://example.com/api/bad.json",
job_id="test-job-fail-3",
on_progress=None,
)
assert len(job.error_message) > 0
@pytest.mark.usefixtures("_mock_public_dns")
async def test_failed_status_progress_called_with_failed(
self, orchestrator, httpx_mock
) -> None:
"""When pipeline fails, on_progress is called with 'failed' stage."""
httpx_mock.add_response(
url="https://example.com/api/bad.json",
text=_INVALID_SPEC_JSON,
headers={"content-type": "application/json"},
)
stages_seen: list[str] = []
def on_progress(stage: str, job: ImportJob) -> None:
stages_seen.append(stage)
await orchestrator.start_import(
url="https://example.com/api/bad.json",
job_id="test-job-fail-4",
on_progress=on_progress,
)
assert "failed" in stages_seen
@pytest.fixture
def _mock_public_dns():
with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
yield

View File

@@ -0,0 +1,164 @@
"""Integration tests for /api/v1/openapi/ endpoints.
Tests the full API layer for the OpenAPI import review workflow,
including job creation, status retrieval, classification updates,
and approval triggering.
"""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _build_app():
"""Build a minimal FastAPI app with the openapi router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.openapi.review_api import router as openapi_router
test_app = FastAPI()
test_app.include_router(openapi_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
return test_app
@pytest.fixture(autouse=True)
def _clear_job_store():
"""Clear the in-memory job store between tests."""
from app.openapi.review_api import _job_store
_job_store.clear()
yield
_job_store.clear()
class TestImportEndpoint:
"""Tests for POST /api/v1/openapi/import."""
async def test_import_returns_202_with_job_id(self) -> None:
"""Starting an import returns 202 with a job_id."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "https://example.com/api/spec.json"},
)
assert resp.status_code == 202
body = resp.json()
assert "job_id" in body
assert body["status"] == "pending"
assert body["spec_url"] == "https://example.com/api/spec.json"
async def test_import_invalid_url_returns_422(self) -> None:
"""POST with invalid URL (no http/https) returns 422."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/openapi/import",
json={"url": "ftp://example.com/spec.json"},
)
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestJobStatusEndpoint:
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
async def test_get_existing_job_returns_status(self) -> None:
"""Retrieving an existing job returns its status."""
from app.openapi.review_api import _job_store
_job_store["test-job-1"] = {
"job_id": "test-job-1",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 5,
"classified_count": 5,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/test-job-1")
assert resp.status_code == 200
body = resp.json()
assert body["job_id"] == "test-job-1"
assert body["status"] == "done"
assert body["total_endpoints"] == 5
async def test_get_unknown_job_returns_404(self) -> None:
"""Retrieving a non-existent job returns 404 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/openapi/jobs/unknown-id-999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()
class TestApproveEndpoint:
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
async def test_approve_with_no_classifications_returns_400(self) -> None:
"""Approving a job with no classifications returns 400."""
from app.openapi.review_api import _job_store
_job_store["empty-job"] = {
"job_id": "empty-job",
"status": "done",
"spec_url": "https://example.com/spec.json",
"total_endpoints": 0,
"classified_count": 0,
"error_message": None,
"classifications": [],
}
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.post("/api/v1/openapi/jobs/empty-job/approve")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert "no classifications" in body["error"].lower()

View File

@@ -0,0 +1,512 @@
"""Phase 2 checkpoint acceptance tests.
Each test maps to one checkpoint criterion from DEVELOPMENT-PLAN.md:
1. "查询订单 1042" -> routes to order_lookup agent
2. "取消订单 1042 并给我一个 10% 折扣" -> sequential multi-agent execution
3. Ambiguous message -> fallback asks for clarification
4. Interrupt > 30 min TTL -> auto-cancel + retry prompt
5. Agent escalation -> Webhook POST succeeds (or logs after retries)
6. E-commerce template -> 4 pre-configured agents work
7. pytest --cov >= 80% (verified separately)
"""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.escalation import EscalationPayload, NoOpEscalator, WebhookEscalator
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig, AgentRegistry
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
TEMPLATES_DIR = Path(__file__).parent.parent.parent / "templates"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
def __init__(self, items: list) -> None:
self._items = list(items)
def __aiter__(self):
return self
async def __anext__(self):
if not self._items:
raise StopAsyncIteration
return self._items.pop(0)
class FakeWS:
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_json(self, data: dict) -> None:
self.sent.append(data)
def _chunk(content: str, node: str) -> tuple:
c = MagicMock()
c.content = content
c.tool_calls = []
return (c, {"langgraph_node": node})
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
c = MagicMock()
c.content = ""
c.tool_calls = [{"name": name, "args": args}]
return (c, {"langgraph_node": node})
def _state(*, interrupt: bool = False, data: dict | None = None):
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig:
return AgentConfig(name=name, description=desc, permission=perm, tools=["fallback_respond"])
# ---------------------------------------------------------------------------
# Checkpoint 1: "查询订单 1042" -> 路由到订单查询 Agent
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint1OrderQueryRouting:
"""Verify intent classifier routes order queries to order_lookup."""
@pytest.mark.asyncio
async def test_order_query_classified_to_order_lookup(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="order query"),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (
_agent("order_lookup", "Looks up order status and tracking"),
_agent("order_actions", "Modifies orders", "write"),
_agent("discount", "Applies discounts", "write"),
_agent("fallback", "Handles unclear requests"),
)
result = await classifier.classify("查询订单 1042", agents)
assert len(result.intents) == 1
assert result.intents[0].agent_name == "order_lookup"
assert result.intents[0].confidence >= 0.9
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_order_query_streams_from_order_lookup_agent(self) -> None:
"""Full dispatch: classify -> route -> stream from order_lookup."""
graph = MagicMock()
# Classifier returns order_lookup
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
))
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
# Graph streams order_lookup response
graph.astream = MagicMock(return_value=AsyncIterHelper([
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
_chunk("Order 1042 is shipped.", "order_lookup"),
]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, ws_ctx, raw)
tool_msgs = [m for m in ws.sent if m["type"] == "tool_call"]
assert any(m["tool"] == "get_order_status" for m in tool_msgs)
token_msgs = [m for m in ws.sent if m["type"] == "token"]
assert any(m["agent"] == "order_lookup" for m in token_msgs)
# ---------------------------------------------------------------------------
# Checkpoint 2: Multi-intent -> sequential execution
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint2MultiIntentSequential:
"""Verify multi-intent classified and hint injected for sequential execution."""
@pytest.mark.asyncio
async def test_multi_intent_classification(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (
_agent("order_actions", "Modifies orders", "write"),
_agent("discount", "Applies discounts", "write"),
_agent("fallback", "Handles unclear requests"),
)
result = await classifier.classify("取消订单 1042 并给我一个 10% 折扣", agents)
assert len(result.intents) == 2
assert result.intents[0].agent_name == "order_actions"
assert result.intents[1].agent_name == "discount"
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_multi_intent_injects_routing_hint(self) -> None:
"""When multi-intent detected, a [System: ...] hint is appended to the message."""
graph = MagicMock()
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
))
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
await dispatch_message(ws, ws_ctx, raw)
# Verify the graph was called with the routing hint in the message
call_args = graph.astream.call_args
input_msg = call_args[0][0]
msg_content = input_msg["messages"][0].content
assert "[System:" in msg_content
assert "order_actions" in msg_content
assert "discount" in msg_content
# ---------------------------------------------------------------------------
# Checkpoint 3: Ambiguous message -> clarification
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint3AmbiguousClarification:
"""Verify ambiguous messages trigger clarification prompt."""
@pytest.mark.asyncio
async def test_ambiguous_intent_returns_clarification(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
is_ambiguous=False, # low confidence will trigger ambiguity threshold
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_agent("order_lookup", "Orders"), _agent("fallback", "Fallback"))
result = await classifier.classify("嗯...", agents)
assert result.is_ambiguous
assert result.clarification_question is not None
@pytest.mark.asyncio
async def test_ambiguous_sends_clarification_via_websocket(self) -> None:
graph = MagicMock()
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question=(
"Could you please provide more details about what you need help with?"
),
))
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=_state())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=mock_classifier,
)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "嗯..."})
await dispatch_message(ws, ws_ctx, raw)
clarifications = [m for m in ws.sent if m["type"] == "clarification"]
assert len(clarifications) == 1
assert "more details" in clarifications[0]["message"]
# Should NOT call graph.astream since we returned early
graph.astream.assert_not_called()
# ---------------------------------------------------------------------------
# Checkpoint 4: Interrupt > 30 min -> auto-cancel + retry
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint4InterruptTTLAutoCancel:
"""Verify interrupt TTL expiration triggers auto-cancel and retry prompt."""
@pytest.mark.asyncio
async def test_30min_expired_interrupt_auto_cancels(self) -> None:
st = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
graph = MagicMock()
graph.astream = MagicMock(return_value=AsyncIterHelper([]))
graph.aget_state = AsyncMock(return_value=st)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=mock_registry, intent_classifier=None)
sm = SessionManager()
sm.touch("t1")
im = InterruptManager(ttl_seconds=1800) # 30 minutes
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
# Trigger interrupt
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "Cancel order 1042"})
await dispatch_message(ws, ws_ctx, raw)
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
# Simulate 31 minutes passing
record = im._interrupts["t1"]
ws.sent.clear()
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 1860 # 31 min
raw = json.dumps({
"type": "interrupt_response",
"thread_id": "t1",
"approved": True,
})
await dispatch_message(ws, ws_ctx, raw)
# Should get retry prompt, NOT resume the graph
expired_msgs = [m for m in ws.sent if m["type"] == "interrupt_expired"]
assert len(expired_msgs) == 1
assert "30 minutes" in expired_msgs[0]["message"]
assert expired_msgs[0]["action"] == "cancel_order"
assert expired_msgs[0]["thread_id"] == "t1"
def test_cleanup_expired_returns_records(self) -> None:
im = InterruptManager(ttl_seconds=1800)
im.register("t1", "cancel_order", {"order_id": "1042"})
im.register("t2", "apply_discount", {"order_id": "1043"})
with patch("app.interrupt_manager.time") as mock_time:
record = im._interrupts["t1"]
mock_time.time.return_value = record.created_at + 1860
expired = im.cleanup_expired()
assert len(expired) == 2
actions = {r.action for r in expired}
assert "cancel_order" in actions
assert "apply_discount" in actions
# ---------------------------------------------------------------------------
# Checkpoint 5: Agent escalation -> Webhook POST
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint5WebhookEscalation:
"""Verify webhook escalation sends POST and retries on failure."""
@pytest.mark.asyncio
async def test_webhook_post_success(self) -> None:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
escalator = WebhookEscalator(url="https://support.example.com/escalate")
payload = EscalationPayload(
thread_id="t1",
reason="Agent cannot resolve customer issue",
conversation_summary="Customer asked about refund policy",
metadata={"customer_id": "C-123"},
)
result = await escalator.escalate(payload)
assert result.success
assert result.status_code == 200
assert result.attempts == 1
# Verify POST was called with correct payload
call_args = mock_client.post.call_args
assert call_args[0][0] == "https://support.example.com/escalate"
posted_data = call_args[1]["json"]
assert posted_data["thread_id"] == "t1"
assert posted_data["reason"] == "Agent cannot resolve customer issue"
@pytest.mark.asyncio
async def test_webhook_retries_then_logs(self) -> None:
fail_response = AsyncMock()
fail_response.status_code = 503
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=fail_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(
url="https://support.example.com/escalate",
max_retries=3,
)
payload = EscalationPayload(
thread_id="t1",
reason="Escalation needed",
conversation_summary="Summary",
)
result = await escalator.escalate(payload)
assert not result.success
assert result.attempts == 3
assert "503" in result.error
@pytest.mark.asyncio
async def test_noop_escalator_when_disabled(self) -> None:
escalator = NoOpEscalator()
payload = EscalationPayload(
thread_id="t1",
reason="Test",
conversation_summary="Test",
)
result = await escalator.escalate(payload)
assert not result.success
assert "disabled" in result.error.lower()
# ---------------------------------------------------------------------------
# Checkpoint 6: E-commerce template -> pre-configured agents
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestCheckpoint6EcommerceTemplate:
"""Verify e-commerce template loads with correct agents."""
def test_ecommerce_template_loads_4_agents(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
assert len(registry) == 4
def test_ecommerce_template_has_correct_agents(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
agents = registry.list_agents()
names = {a.name for a in agents}
assert names == {"order_lookup", "order_actions", "discount", "fallback"}
def test_ecommerce_order_lookup_is_read(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
agent = registry.get_agent("order_lookup")
assert agent.permission == "read"
assert "get_order_status" in agent.tools
assert "get_tracking_info" in agent.tools
def test_ecommerce_order_actions_is_write(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
agent = registry.get_agent("order_actions")
assert agent.permission == "write"
assert "cancel_order" in agent.tools
def test_ecommerce_discount_is_write(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
agent = registry.get_agent("discount")
assert agent.permission == "write"
assert "apply_discount" in agent.tools
assert "generate_coupon" in agent.tools
def test_ecommerce_fallback_is_read(self) -> None:
registry = AgentRegistry.load_template("e-commerce", TEMPLATES_DIR)
agent = registry.get_agent("fallback")
assert agent.permission == "read"
def test_all_three_templates_available(self) -> None:
templates = AgentRegistry.list_templates(TEMPLATES_DIR)
assert "e-commerce" in templates
assert "saas" in templates
assert "fintech" in templates


@@ -0,0 +1,213 @@
"""Integration tests for /api/v1/conversations and /api/v1/replay/{thread_id}.
Tests the full API layer with a mocked database pool, verifying routing,
serialization, pagination, and error handling in envelope format.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
pytestmark = pytest.mark.integration
def _make_fake_cursor(rows, *, fetchone_value=None):
"""Build a fake async cursor returning the given rows on fetchall."""
cursor = AsyncMock()
cursor.fetchall = AsyncMock(return_value=rows)
if fetchone_value is not None:
cursor.fetchone = AsyncMock(return_value=fetchone_value)
return cursor
class _FakeConnection:
"""Fake async connection that returns pre-configured cursors in order."""
def __init__(self, cursors: list) -> None:
self._cursors = list(cursors)
self._idx = 0
async def execute(self, sql, params=None):
cursor = self._cursors[self._idx]
self._idx += 1
return cursor
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class _FakePool:
"""Fake connection pool that yields a fake connection."""
def __init__(self, conn: _FakeConnection) -> None:
self._conn = conn
def connection(self):
return self._conn
def _build_app(pool=None):
"""Build a minimal FastAPI app with the replay router and mocked deps."""
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.api_utils import envelope
from app.replay.api import router as replay_router
test_app = FastAPI()
test_app.include_router(replay_router)
@test_app.exception_handler(HTTPException)
async def _http_exc(request, exc):
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@test_app.exception_handler(RequestValidationError)
async def _validation_exc(request, exc):
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
test_app.state.settings = MagicMock(admin_api_key="")
test_app.state.pool = pool or MagicMock()
return test_app
class TestListConversations:
"""Tests for GET /api/v1/conversations endpoint."""
async def test_returns_paginated_envelope(self) -> None:
"""Conversations list returns envelope with pagination metadata."""
count_cursor = _make_fake_cursor([], fetchone_value=(3,))
rows = [
{"thread_id": "t1", "created_at": "2026-01-01", "last_activity": "2026-01-01",
"status": "active", "total_tokens": 100, "total_cost_usd": 0.01},
{"thread_id": "t2", "created_at": "2026-01-02", "last_activity": "2026-01-02",
"status": "resolved", "total_tokens": 200, "total_cost_usd": 0.02},
]
list_cursor = _make_fake_cursor(rows)
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["total"] == 3
assert len(body["data"]["conversations"]) == 2
assert body["data"]["page"] == 1
assert body["data"]["per_page"] == 20
async def test_custom_page_and_per_page(self) -> None:
"""Custom page/per_page params are reflected in the response."""
count_cursor = _make_fake_cursor([], fetchone_value=(50,))
list_cursor = _make_fake_cursor([])
conn = _FakeConnection([count_cursor, list_cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 3, "per_page": 10})
assert resp.status_code == 200
body = resp.json()
assert body["data"]["page"] == 3
assert body["data"]["per_page"] == 10
async def test_invalid_page_returns_422(self) -> None:
"""page=0 violates ge=1 constraint and returns 422 error envelope."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/conversations", params={"page": 0})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
class TestReplayEndpoint:
"""Tests for GET /api/v1/replay/{thread_id} endpoint."""
async def test_valid_thread_returns_timeline(self) -> None:
"""Replay with valid thread_id returns steps in envelope format."""
checkpoint_rows = [
{
"thread_id": "abc123",
"checkpoint_id": "cp1",
"checkpoint": {
"channel_values": {
"messages": [
{"type": "human", "content": "Hello", "created_at": "2026-01-01T00:00:00Z"},
{"type": "ai", "content": "Hi there!", "created_at": "2026-01-01T00:00:01Z"},
]
}
},
"metadata": {},
}
]
cursor = _make_fake_cursor(checkpoint_rows)
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/abc123")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["thread_id"] == "abc123"
assert body["data"]["total_steps"] == 2
assert len(body["data"]["steps"]) == 2
assert body["data"]["steps"][0]["type"] == "user_message"
assert body["data"]["steps"][1]["type"] == "agent_response"
async def test_invalid_thread_id_format_returns_400(self) -> None:
"""Thread IDs with path traversal characters are rejected with 400."""
test_app = _build_app()
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/../../etc/passwd")
        # Our handler returns 400; depending on routing, FastAPI may yield 404 or 422 instead
assert resp.status_code in (400, 404, 422)
async def test_nonexistent_thread_returns_404(self) -> None:
"""Replay with a thread_id that has no checkpoints returns 404."""
cursor = _make_fake_cursor([])
conn = _FakeConnection([cursor])
pool = _FakePool(conn)
test_app = _build_app(pool)
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
resp = await client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert "not found" in body["error"].lower()


@@ -0,0 +1,371 @@
"""Integration tests for multi-agent routing flow.
Tests the full pipeline: intent classification -> supervisor routing ->
agent execution -> response streaming, exercising cross-module integration
with mocked LLM.
Required by Phase 2 test plan:
- Unit: intent classification accuracy
- Unit: multi-intent sequential execution
- Integration: complete multi-agent routing flow
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.intent import ClassificationResult, IntentTarget, LLMIntentClassifier
from app.interrupt_manager import InterruptManager
from app.registry import AgentConfig
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
def __init__(self, items: list) -> None:
self._items = list(items)
def __aiter__(self):
return self
async def __anext__(self):
if not self._items:
raise StopAsyncIteration
return self._items.pop(0)
class FakeWS:
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_json(self, data: dict) -> None:
self.sent.append(data)
def _chunk(content: str, node: str) -> tuple:
c = MagicMock()
c.content = content
c.tool_calls = []
return (c, {"langgraph_node": node})
def _tool_chunk(name: str, args: dict, node: str) -> tuple:
c = MagicMock()
c.content = ""
c.tool_calls = [{"name": name, "args": args}]
return (c, {"langgraph_node": node})
def _state(*, interrupt: bool = False, data: dict | None = None):
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
AGENTS = (
AgentConfig(
name="order_lookup", description="Looks up orders",
permission="read", tools=["get_order_status", "get_tracking_info"],
),
AgentConfig(
name="order_actions", description="Modifies orders",
permission="write", tools=["cancel_order"],
),
AgentConfig(
name="discount", description="Applies discounts",
permission="write", tools=["apply_discount", "generate_coupon"],
),
AgentConfig(
name="fallback", description="Handles unclear requests",
permission="read", tools=["fallback_respond"],
),
)
def _make_classifier(result: ClassificationResult) -> AsyncMock:
"""Create a mock classifier returning the given result."""
classifier = AsyncMock()
classifier.classify = AsyncMock(return_value=result)
return classifier
def _make_graph_and_ctx(
classifier_result: ClassificationResult | None,
chunks: list,
state=None,
) -> tuple[MagicMock, GraphContext]:
"""Build a graph mock and GraphContext with optional intent classifier."""
graph = MagicMock()
graph.astream = MagicMock(return_value=AsyncIterHelper(list(chunks)))
graph.aget_state = AsyncMock(return_value=state or _state())
if classifier_result is not None:
classifier = _make_classifier(classifier_result)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=AGENTS)
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=classifier,
)
else:
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(
graph=graph, registry=mock_registry, intent_classifier=None,
)
return graph, graph_ctx
async def _dispatch(graph_ctx: GraphContext, content: str, thread_id: str = "t1") -> list[dict]:
sm = SessionManager()
sm.touch(thread_id)
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": thread_id, "content": content})
await dispatch_message(ws, ws_ctx, raw)
return ws.sent
# ---------------------------------------------------------------------------
# Single-intent routing to each agent
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestSingleIntentRouting:
"""Verify single-intent messages route to the correct agent."""
@pytest.mark.asyncio
async def test_routes_to_order_lookup(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(
agent_name="order_lookup", confidence=0.95, reasoning="status query",
),),
)
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"),
_chunk("Order 1042 is shipped.", "order_lookup"),
])
msgs = await _dispatch(graph_ctx, "What is the status of order 1042?")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert len(tools) == 1
assert tools[0]["tool"] == "get_order_status"
assert tools[0]["agent"] == "order_lookup"
tokens = [m for m in msgs if m["type"] == "token"]
assert any("shipped" in t["content"] for t in tokens)
@pytest.mark.asyncio
async def test_routes_to_order_actions(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),),
)
graph, graph_ctx = _make_graph_and_ctx(
result,
[_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions")],
state=_state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}),
)
msgs = await _dispatch(graph_ctx, "Cancel order 1042")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "cancel_order"
assert tools[0]["agent"] == "order_actions"
interrupts = [m for m in msgs if m["type"] == "interrupt"]
assert len(interrupts) == 1
@pytest.mark.asyncio
async def test_routes_to_discount(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="discount", confidence=0.88, reasoning="coupon"),),
)
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("generate_coupon", {"discount_percent": 15, "expiry_days": 30}, "discount"),
_chunk("Here is your coupon: SAVE15-ABC12345", "discount"),
])
msgs = await _dispatch(graph_ctx, "Give me a 15% coupon")
tools = [m for m in msgs if m["type"] == "tool_call"]
assert tools[0]["tool"] == "generate_coupon"
assert tools[0]["agent"] == "discount"
@pytest.mark.asyncio
async def test_routes_to_fallback(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.7, reasoning="general"),),
)
graph, graph_ctx = _make_graph_and_ctx(result, [
_chunk("I can help with order inquiries.", "fallback"),
])
msgs = await _dispatch(graph_ctx, "What can you do?")
tokens = [m for m in msgs if m["type"] == "token"]
assert tokens[0]["agent"] == "fallback"
# ---------------------------------------------------------------------------
# Multi-intent routing
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestMultiIntentRouting:
"""Verify multi-intent triggers sequential execution hint."""
@pytest.mark.asyncio
async def test_two_intents_inject_routing_hint(self) -> None:
result = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
graph, graph_ctx = _make_graph_and_ctx(result, [
_tool_chunk("cancel_order", {"order_id": "1042"}, "order_actions"),
_tool_chunk("apply_discount", {"order_id": "1042", "discount_percent": 10}, "discount"),
])
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({
"type": "message",
"thread_id": "t1",
"content": "取消订单 1042 并给我一个 10% 折扣",
})
await dispatch_message(ws, ws_ctx, raw)
# Verify routing hint was injected
call_args = graph.astream.call_args[0][0]
msg_content = call_args["messages"][0].content
assert "[System:" in msg_content
assert "order_actions" in msg_content
assert "discount" in msg_content
# Both tool calls should appear
tools = [m for m in ws.sent if m["type"] == "tool_call"]
tool_names = {t["tool"] for t in tools}
assert "cancel_order" in tool_names
assert "apply_discount" in tool_names
@pytest.mark.asyncio
async def test_single_intent_no_routing_hint(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)
graph, graph_ctx = _make_graph_and_ctx(result, [_chunk("Order shipped.", "order_lookup")])
sm = SessionManager()
sm.touch("t1")
im = InterruptManager()
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "查询订单 1042"})
await dispatch_message(ws, ws_ctx, raw)
msg_content = graph.astream.call_args[0][0]["messages"][0].content
assert "[System:" not in msg_content
# ---------------------------------------------------------------------------
# Ambiguity routing
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestAmbiguityRouting:
"""Verify ambiguous intents produce clarification, not agent calls."""
@pytest.mark.asyncio
async def test_ambiguous_skips_graph_returns_clarification(self) -> None:
result = ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="Could you please clarify what you need?",
)
graph, graph_ctx = _make_graph_and_ctx(result, [])
        msgs = await _dispatch(graph_ctx, "嗯...")  # "Hmm..."
clarifications = [m for m in msgs if m["type"] == "clarification"]
assert len(clarifications) == 1
assert "clarify" in clarifications[0]["message"]
# Graph should NOT have been called
graph.astream.assert_not_called()
@pytest.mark.asyncio
async def test_low_confidence_triggers_ambiguity(self) -> None:
"""LLMIntentClassifier applies threshold -- low confidence -> ambiguous."""
raw_result = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.2, reasoning="unclear"),),
is_ambiguous=False,
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=raw_result)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
result = await classifier.classify("hmm", AGENTS)
assert result.is_ambiguous
assert result.clarification_question is not None
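# ---------------------------------------------------------------------------
# Illustrative sketch: the confidence floor LLMIntentClassifier.classify is
# expected to apply per these tests -- a low-confidence result is downgraded
# to ambiguous with a generated clarification question. The 0.5 threshold and
# the question wording are assumptions; the tests only require is_ambiguous
# and a non-None clarification_question.
# ---------------------------------------------------------------------------
def _example_apply_threshold(result: ClassificationResult,
                             threshold: float = 0.5) -> ClassificationResult:
    """Hypothetical: downgrade low-confidence classifications to ambiguous."""
    if result.intents and all(i.confidence < threshold for i in result.intents):
        return ClassificationResult(
            intents=(),
            is_ambiguous=True,
            clarification_question="Could you clarify what you need help with?",
        )
    return result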
# ---------------------------------------------------------------------------
# No classifier fallback
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestNoClassifierFallback:
"""Verify system works without intent classifier (falls back to supervisor prompt)."""
@pytest.mark.asyncio
async def test_no_classifier_routes_via_supervisor(self) -> None:
graph, graph_ctx = _make_graph_and_ctx(
classifier_result=None,
chunks=[_chunk("Order 1042 is shipped.", "order_lookup")],
)
msgs = await _dispatch(graph_ctx, "What is order 1042 status?")
tokens = [m for m in msgs if m["type"] == "token"]
assert len(tokens) == 1
completes = [m for m in msgs if m["type"] == "message_complete"]
assert len(completes) == 1


@@ -0,0 +1,159 @@
"""Integration tests for SessionManager + InterruptManager lifecycle.
These tests exercise the in-memory managers together, verifying the full
lifecycle of sessions and interrupts: creation, TTL sliding, interrupt
registration/resolution, and expired-interrupt cleanup.
No database required -- both managers are in-memory.
"""
from __future__ import annotations
import time
from unittest.mock import patch
import pytest
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
pytestmark = pytest.mark.integration
class TestSessionInterruptLifecycle:
"""Tests for the combined session + interrupt lifecycle."""
def test_create_session_register_interrupt_check_status(self) -> None:
"""Full lifecycle: create session, register interrupt, verify both states."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
# Create a session
state = sm.touch("thread-1")
assert state.thread_id == "thread-1"
assert not state.has_pending_interrupt
assert not sm.is_expired("thread-1")
# Register an interrupt
record = im.register("thread-1", "cancel_order", {"order_id": "1042"})
sm.extend_for_interrupt("thread-1")
assert im.has_pending("thread-1")
session_state = sm.get_state("thread-1")
assert session_state is not None
assert session_state.has_pending_interrupt
# Session should not expire while interrupt is pending
assert not sm.is_expired("thread-1")
def test_interrupt_expiry_after_ttl(self) -> None:
"""Interrupt expires when TTL elapses, even if session is alive."""
im = InterruptManager(ttl_seconds=5)
record = im.register("thread-2", "refund", {"amount": 50})
assert im.has_pending("thread-2")
# Simulate time passing beyond TTL
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
assert not im.has_pending("thread-2")
status = im.check_status("thread-2")
assert status is not None
assert status.is_expired
assert status.remaining_seconds == 0.0
def test_interrupt_resolve_flow(self) -> None:
"""Resolving an interrupt removes it from pending and resets session."""
sm = SessionManager(session_ttl_seconds=3600)
im = InterruptManager(ttl_seconds=300)
sm.touch("thread-3")
im.register("thread-3", "delete_account", {"user_id": "u1"})
sm.extend_for_interrupt("thread-3")
# Verify pending state
assert im.has_pending("thread-3")
assert sm.get_state("thread-3").has_pending_interrupt
# Resolve
im.resolve("thread-3")
sm.resolve_interrupt("thread-3")
assert not im.has_pending("thread-3")
session_state = sm.get_state("thread-3")
assert session_state is not None
assert not session_state.has_pending_interrupt
def test_cleanup_expired_removes_old_interrupts(self) -> None:
"""cleanup_expired removes only expired interrupts, keeping active ones."""
im = InterruptManager(ttl_seconds=10)
# Register two interrupts at different times
old_record = im.register("thread-old", "action_old", {})
        im.register("thread-new", "action_new", {})
# Simulate time where only old one expired
with patch("app.interrupt_manager.time") as mock_time:
# Move old record's creation to the past
im._interrupts["thread-old"] = old_record.__class__(
interrupt_id=old_record.interrupt_id,
thread_id=old_record.thread_id,
action=old_record.action,
params=old_record.params,
created_at=time.time() - 20,
ttl_seconds=old_record.ttl_seconds,
)
mock_time.time.return_value = time.time()
expired = im.cleanup_expired()
assert len(expired) == 1
assert expired[0].thread_id == "thread-old"
# New one should still be pending
assert im.has_pending("thread-new")
assert not im.has_pending("thread-old")
def test_session_ttl_sliding_window(self) -> None:
"""Touching a session resets the sliding window TTL."""
sm = SessionManager(session_ttl_seconds=3600)
state1 = sm.touch("thread-5")
first_activity = state1.last_activity
time.sleep(0.01)
state2 = sm.touch("thread-5")
second_activity = state2.last_activity
assert second_activity > first_activity
assert not sm.is_expired("thread-5")
def test_session_expires_after_ttl_without_activity(self) -> None:
"""Session expires when TTL passes without a touch or interrupt."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-6")
# TTL is 0 so session is immediately expired
assert sm.is_expired("thread-6")
def test_pending_interrupt_prevents_session_expiry(self) -> None:
"""A session with pending interrupt does not expire even with TTL=0."""
sm = SessionManager(session_ttl_seconds=0)
sm.touch("thread-7")
sm.extend_for_interrupt("thread-7")
# Even with TTL=0, session should not expire because of pending interrupt
assert not sm.is_expired("thread-7")
def test_retry_prompt_for_expired_interrupt(self) -> None:
"""InterruptManager generates a retry prompt for expired interrupts."""
im = InterruptManager(ttl_seconds=300)
record = im.register("thread-8", "cancel_order", {"order_id": "1042"})
prompt = im.generate_retry_prompt(record)
assert prompt["type"] == "interrupt_expired"
assert prompt["thread_id"] == "thread-8"
assert "cancel_order" in prompt["action"]
assert "cancel_order" in prompt["message"]
assert "expired" in prompt["message"].lower()


@@ -0,0 +1,360 @@
"""Integration tests for WebSocket message flow.
These tests exercise dispatch_message end-to-end with a mocked LangGraph
graph, verifying streaming, interrupt approval/rejection, session TTL,
and interrupt TTL expiration through the full message handling pipeline.
"""
from __future__ import annotations
import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.interrupt_manager import InterruptManager
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class AsyncIterHelper:
"""Make a list behave as an async iterator."""
def __init__(self, items: list) -> None:
self._items = list(items)
def __aiter__(self):
return self
async def __anext__(self):
if not self._items:
raise StopAsyncIteration
return self._items.pop(0)
class FakeWS:
"""Fake WebSocket that records sent messages."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_json(self, data: dict) -> None:
self.sent.append(data)
def _chunk(content: str, node: str = "order_lookup") -> tuple:
c = MagicMock()
c.content = content
c.tool_calls = []
return (c, {"langgraph_node": node})
def _tool_chunk(name: str, args: dict, node: str = "order_lookup") -> tuple:
c = MagicMock()
c.content = ""
c.tool_calls = [{"name": name, "args": args}]
return (c, {"langgraph_node": node})
def _state(*, interrupt: bool = False, data: dict | None = None) -> Any:
s = MagicMock()
if interrupt:
obj = MagicMock()
obj.value = data or {"action": "cancel_order", "order_id": "1042"}
t = MagicMock()
t.interrupts = (obj,)
s.tasks = (t,)
else:
s.tasks = ()
return s
def _graph(
chunks: list | None = None,
st: Any = None,
resume_chunks: list | None = None,
) -> MagicMock:
g = MagicMock()
if st is None:
st = _state()
streams = [chunks or [], resume_chunks or []]
idx = {"n": 0}
def make_stream(*a, **kw):
i = min(idx["n"], len(streams) - 1)
idx["n"] += 1
return AsyncIterHelper(list(streams[i]))
g.astream = MagicMock(side_effect=make_stream)
g.aget_state = AsyncMock(return_value=st)
return g
def _make_graph_ctx(graph: MagicMock | None = None) -> GraphContext:
g = graph or _graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
return GraphContext(graph=g, registry=registry, intent_classifier=None)
def _setup(
graph=None,
session_ttl: int = 1800,
interrupt_ttl: int = 1800,
thread_id: str = "t1",
touch: bool = True,
):
"""Create test dependencies. Pre-touches session by default."""
g = graph or _graph()
graph_ctx = _make_graph_ctx(g)
sm = SessionManager(session_ttl_seconds=session_ttl)
im = InterruptManager(ttl_seconds=interrupt_ttl)
cb = TokenUsageCallbackHandler()
ws = FakeWS()
ws_ctx = WebSocketContext(
graph_ctx=graph_ctx, session_manager=sm,
callback_handler=cb, interrupt_manager=im,
)
if touch:
sm.touch(thread_id)
return g, sm, im, cb, ws, ws_ctx
async def _send(ws, ws_ctx, *, thread_id="t1", content="hello", msg_type="message"):
raw = json.dumps({"type": msg_type, "thread_id": thread_id, "content": content})
await dispatch_message(ws, ws_ctx, raw)
async def _respond(ws, ws_ctx, *, thread_id="t1", approved=True):
raw = json.dumps({"type": "interrupt_response", "thread_id": thread_id, "approved": approved})
await dispatch_message(ws, ws_ctx, raw)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.integration
class TestWebSocketHappyPath:
@pytest.mark.asyncio
async def test_send_message_receives_tokens_and_complete(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[_chunk("Order 1042 is "), _chunk("shipped.")])
)
await _send(ws, ws_ctx, content="What is the status of order 1042?")
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 2
assert tokens[0]["content"] == "Order 1042 is "
assert tokens[0]["agent"] == "order_lookup"
assert tokens[1]["content"] == "shipped."
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1
assert completes[0]["thread_id"] == "t1"
@pytest.mark.asyncio
async def test_tool_call_streamed(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup(
graph=_graph(chunks=[
_tool_chunk("get_order_status", {"order_id": "1042"}),
_chunk("Order shipped."),
])
)
await _send(ws, ws_ctx, content="Check order 1042")
tools = [m for m in ws.sent if m["type"] == "tool_call"]
assert len(tools) == 1
assert tools[0]["tool"] == "get_order_status"
assert tools[0]["args"] == {"order_id": "1042"}
@pytest.mark.asyncio
async def test_multiple_messages_same_session(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
for i in range(3):
await _send(ws, ws_ctx, content=f"msg {i}")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 3
@pytest.mark.integration
class TestWebSocketInterruptApproval:
@pytest.mark.asyncio
async def test_interrupt_then_approve(self) -> None:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
resume = [_chunk("Order 1042 cancelled.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
# Send message -> triggers interrupt
await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
assert interrupts[0]["action"] == "cancel_order"
assert interrupts[0]["thread_id"] == "t1"
assert im.has_pending("t1")
# Approve
ws.sent.clear()
await _respond(ws, ws_ctx, approved=True)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert len(tokens) == 1
assert "cancelled" in tokens[0]["content"]
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1
assert not im.has_pending("t1")
@pytest.mark.asyncio
async def test_interrupt_then_reject(self) -> None:
st_int = _state(interrupt=True)
resume = [_chunk("Order remains active.", "order_actions")]
g = _graph(chunks=[], st=st_int, resume_chunks=resume)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g)
await _send(ws, ws_ctx, content="Cancel order 1042")
ws.sent.clear()
await _respond(ws, ws_ctx, approved=False)
tokens = [m for m in ws.sent if m["type"] == "token"]
assert "remains active" in tokens[0]["content"]
@pytest.mark.integration
class TestWebSocketSessionTTL:
@pytest.mark.asyncio
async def test_expired_session_returns_error(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=0)
# Session was touched in _setup, but TTL is 0 so it's already expired
await _send(ws, ws_ctx, content="hello")
assert ws.sent[0]["type"] == "error"
assert "expired" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_new_session_not_expired(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, ws_ctx, content="hello")
completes = [m for m in ws.sent if m["type"] == "message_complete"]
assert len(completes) == 1
@pytest.mark.asyncio
async def test_sliding_window_resets_on_message(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup(session_ttl=3600)
await _send(ws, ws_ctx, content="hello")
first_activity = sm.get_state("t1").last_activity
time.sleep(0.01)
await _send(ws, ws_ctx, content="hello again")
second_activity = sm.get_state("t1").last_activity
assert second_activity > first_activity
@pytest.mark.asyncio
async def test_interrupt_extends_session_ttl(self) -> None:
st_int = _state(interrupt=True)
g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, session_ttl=3600)
await _send(ws, ws_ctx, content="cancel order")
state = sm.get_state("t1")
assert state is not None
assert state.has_pending_interrupt
assert not sm.is_expired("t1")
@pytest.mark.integration
class TestWebSocketValidation:
@pytest.mark.asyncio
async def test_invalid_json(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, ws_ctx, "not json")
assert ws.sent[0]["type"] == "error"
assert "Invalid JSON" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_missing_thread_id(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "content": "hi"})
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "thread_id" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_invalid_thread_id_format(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "../../etc/passwd", "content": "hi"})
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_missing_content(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1"})
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
@pytest.mark.asyncio
async def test_unknown_message_type(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "foobar", "thread_id": "t1"})
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "Unknown" in ws.sent[0]["message"]
@pytest.mark.asyncio
async def test_message_too_large(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
await dispatch_message(ws, ws_ctx, "x" * 40_000)
assert ws.sent[0]["type"] == "error"
assert "too large" in ws.sent[0]["message"].lower()
@pytest.mark.asyncio
async def test_content_too_long(self) -> None:
g, sm, im, cb, ws, ws_ctx = _setup()
raw = json.dumps({"type": "message", "thread_id": "t1", "content": "x" * 10001})
await dispatch_message(ws, ws_ctx, raw)
assert ws.sent[0]["type"] == "error"
assert "too long" in ws.sent[0]["message"].lower()
@pytest.mark.integration
class TestWebSocketInterruptTTL:
@pytest.mark.asyncio
async def test_expired_interrupt_sends_retry_prompt(self) -> None:
st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"})
g = _graph(chunks=[], st=st_int)
g_, sm, im, cb, ws, ws_ctx = _setup(graph=g, interrupt_ttl=5)
# Trigger interrupt
await _send(ws, ws_ctx, content="Cancel order 1042")
interrupts = [m for m in ws.sent if m["type"] == "interrupt"]
assert len(interrupts) == 1
# Expire the interrupt
record = im._interrupts["t1"]
ws.sent.clear()
with patch("app.interrupt_manager.time") as mock_time:
mock_time.time.return_value = record.created_at + 10
await _respond(ws, ws_ctx, approved=True)
assert ws.sent[0]["type"] == "interrupt_expired"
assert "cancel_order" in ws.sent[0]["message"]
assert ws.sent[0]["thread_id"] == "t1"


@@ -0,0 +1 @@
"""Unit tests for app.analytics module."""


@@ -0,0 +1,149 @@
"""Unit tests for app.analytics.api."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
pytestmark = pytest.mark.unit
def _build_app() -> FastAPI:
from app.analytics.api import router
app = FastAPI()
app.include_router(router)
return app
def _make_mock_pool() -> MagicMock:
mock_conn = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
return mock_pool
def _make_analytics_result() -> object:
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
return AnalyticsResult(
range="7d",
total_conversations=50,
resolution_rate=0.8,
escalation_rate=0.1,
avg_turns_per_conversation=3.5,
avg_cost_per_conversation_usd=0.02,
agent_usage=(AgentUsage(agent="order_agent", count=30, percentage=60.0),),
interrupt_stats=InterruptStats(total=5, approved=4, rejected=1, expired=0),
)
def _get_analytics(app: FastAPI, path: str = "/api/v1/analytics") -> tuple:
"""Helper: patch get_analytics, make request, return (response, mock)."""
analytics_result = _make_analytics_result()
with (
patch("app.analytics.api.get_analytics", return_value=analytics_result) as mock_ga,
TestClient(app) as client,
):
resp = client.get(path)
return resp, mock_ga
class TestAnalyticsEndpoint:
def test_returns_200_with_default_range(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, _ = _get_analytics(app)
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["error"] is None
assert body["data"]["range"] == "7d"
def test_returns_correct_analytics_structure(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, _ = _get_analytics(app)
data = resp.json()["data"]
assert "total_conversations" in data
assert "resolution_rate" in data
assert "escalation_rate" in data
assert "avg_turns_per_conversation" in data
assert "avg_cost_per_conversation_usd" in data
assert "agent_usage" in data
assert "interrupt_stats" in data
def test_custom_range_7d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=7d")
assert resp.status_code == 200
mock_ga.assert_called_once()
call_kwargs = mock_ga.call_args
assert call_kwargs[1]["range_days"] == 7 or call_kwargs[0][1] == 7
def test_custom_range_30d(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, mock_ga = _get_analytics(app, "/api/v1/analytics?range=30d")
assert resp.status_code == 200
call_kwargs = mock_ga.call_args
assert call_kwargs[1].get("range_days") == 30 or (
len(call_kwargs[0]) > 1 and call_kwargs[0][1] == 30
)
def test_invalid_range_format_returns_400(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
resp = client.get("/api/v1/analytics?range=invalid")
assert resp.status_code == 400
def test_range_without_d_suffix_returns_400(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
with TestClient(app) as client:
resp = client.get("/api/v1/analytics?range=7")
assert resp.status_code == 400
def test_agent_usage_in_response(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, _ = _get_analytics(app)
data = resp.json()["data"]
assert len(data["agent_usage"]) == 1
assert data["agent_usage"][0]["agent"] == "order_agent"
def test_interrupt_stats_in_response(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, _ = _get_analytics(app)
data = resp.json()["data"]
assert data["interrupt_stats"]["total"] == 5
assert data["interrupt_stats"]["approved"] == 4
def test_envelope_format(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool()
resp, _ = _get_analytics(app)
body = resp.json()
assert "success" in body
assert "data" in body
assert "error" in body


@@ -0,0 +1,155 @@
"""Unit tests for app.analytics.event_recorder."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
class TestAnalyticsRecorderProtocol:
def test_postgres_recorder_implements_protocol(self) -> None:
from app.analytics.event_recorder import PostgresAnalyticsRecorder
mock_pool = MagicMock()
recorder = PostgresAnalyticsRecorder(pool=mock_pool)
# Runtime check: has record method
assert hasattr(recorder, "record")
assert callable(recorder.record)
def test_noop_recorder_implements_protocol(self) -> None:
from app.analytics.event_recorder import NoOpAnalyticsRecorder
recorder = NoOpAnalyticsRecorder()
assert hasattr(recorder, "record")
assert callable(recorder.record)
class TestNoOpAnalyticsRecorder:
@pytest.mark.asyncio
async def test_record_does_nothing(self) -> None:
from app.analytics.event_recorder import NoOpAnalyticsRecorder
recorder = NoOpAnalyticsRecorder()
# Should not raise
await recorder.record(
thread_id="t1",
event_type="tool_call",
agent_name="order_agent",
tool_name="get_order",
tokens_used=50,
cost_usd=0.001,
)
@pytest.mark.asyncio
async def test_record_with_all_params(self) -> None:
from app.analytics.event_recorder import NoOpAnalyticsRecorder
recorder = NoOpAnalyticsRecorder()
await recorder.record(
thread_id="t1",
event_type="agent_response",
agent_name="fallback",
tool_name=None,
tokens_used=100,
cost_usd=0.002,
duration_ms=150,
success=True,
error_message=None,
metadata={"extra": "data"},
)
@pytest.mark.asyncio
async def test_record_minimal_params(self) -> None:
from app.analytics.event_recorder import NoOpAnalyticsRecorder
recorder = NoOpAnalyticsRecorder()
# Only required params
await recorder.record(thread_id="t1", event_type="conversation_start")
class TestPostgresAnalyticsRecorder:
@pytest.mark.asyncio
async def test_record_executes_insert(self) -> None:
from app.analytics.event_recorder import PostgresAnalyticsRecorder
mock_conn = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
recorder = PostgresAnalyticsRecorder(pool=mock_pool)
await recorder.record(
thread_id="t1",
event_type="tool_call",
agent_name="order_agent",
tokens_used=50,
cost_usd=0.001,
)
mock_conn.execute.assert_awaited_once()
call_args = mock_conn.execute.call_args
sql = call_args[0][0]
assert "INSERT INTO analytics_events" in sql
@pytest.mark.asyncio
async def test_record_passes_correct_params(self) -> None:
from app.analytics.event_recorder import PostgresAnalyticsRecorder
mock_conn = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
recorder = PostgresAnalyticsRecorder(pool=mock_pool)
await recorder.record(
thread_id="thread-xyz",
event_type="agent_response",
agent_name="discount_agent",
tool_name="apply_discount",
tokens_used=75,
cost_usd=0.002,
duration_ms=300,
success=True,
error_message=None,
metadata={"promo": "10PCT"},
)
call_args = mock_conn.execute.call_args
params = call_args[0][1]
assert params["thread_id"] == "thread-xyz"
assert params["event_type"] == "agent_response"
assert params["agent_name"] == "discount_agent"
assert params["tokens_used"] == 75
@pytest.mark.asyncio
async def test_record_stores_metadata_as_dict(self) -> None:
from app.analytics.event_recorder import PostgresAnalyticsRecorder
mock_conn = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
recorder = PostgresAnalyticsRecorder(pool=mock_pool)
await recorder.record(
thread_id="t1",
event_type="tool_call",
metadata={"key": "val"},
)
call_args = mock_conn.execute.call_args
params = call_args[0][1]
# PostgresAnalyticsRecorder wraps metadata with psycopg Json() adapter.
# Unwrap to compare the inner dict.
from psycopg.types.json import Json
meta = params["metadata"]
if isinstance(meta, Json):
meta = meta.obj
assert meta == {"key": "val"}
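# ---------------------------------------------------------------------------
# Illustrative sketch: why the test above unwraps Json. psycopg 3 does not
# adapt plain dicts for jsonb columns automatically, so the recorder
# presumably wraps metadata in Json() before binding; Json(...).obj exposes
# the original dict for comparison.
# ---------------------------------------------------------------------------
from psycopg.types.json import Json as _Json
_example_params = {"metadata": _Json({"key": "val"})}  # what the recorder binds
assert _example_params["metadata"].obj == {"key": "val"}  # unwrap to compare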


@@ -0,0 +1,106 @@
"""Unit tests for app.analytics.models."""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
class TestAgentUsage:
def test_agent_usage_construction(self) -> None:
from app.analytics.models import AgentUsage
au = AgentUsage(agent="order_agent", count=10, percentage=50.0)
assert au.agent == "order_agent"
assert au.count == 10
assert au.percentage == 50.0
def test_agent_usage_is_frozen(self) -> None:
from app.analytics.models import AgentUsage
au = AgentUsage(agent="a", count=1, percentage=100.0)
with pytest.raises((AttributeError, TypeError)):
au.count = 2 # type: ignore[misc]
class TestInterruptStats:
def test_interrupt_stats_defaults(self) -> None:
from app.analytics.models import InterruptStats
stats = InterruptStats()
assert stats.total == 0
assert stats.approved == 0
assert stats.rejected == 0
assert stats.expired == 0
def test_interrupt_stats_custom_values(self) -> None:
from app.analytics.models import InterruptStats
stats = InterruptStats(total=10, approved=7, rejected=2, expired=1)
assert stats.total == 10
assert stats.approved == 7
assert stats.rejected == 2
assert stats.expired == 1
def test_interrupt_stats_is_frozen(self) -> None:
from app.analytics.models import InterruptStats
stats = InterruptStats()
with pytest.raises((AttributeError, TypeError)):
stats.total = 5 # type: ignore[misc]
class TestAnalyticsResult:
def test_analytics_result_construction(self) -> None:
from app.analytics.models import AgentUsage, AnalyticsResult, InterruptStats
result = AnalyticsResult(
range="7d",
total_conversations=100,
resolution_rate=0.85,
escalation_rate=0.05,
avg_turns_per_conversation=4.2,
avg_cost_per_conversation_usd=0.03,
agent_usage=(AgentUsage(agent="order_agent", count=60, percentage=60.0),),
interrupt_stats=InterruptStats(total=5, approved=4, rejected=1, expired=0),
)
assert result.range == "7d"
assert result.total_conversations == 100
assert result.resolution_rate == 0.85
assert result.escalation_rate == 0.05
assert result.avg_turns_per_conversation == 4.2
assert result.avg_cost_per_conversation_usd == 0.03
assert len(result.agent_usage) == 1
assert result.interrupt_stats.total == 5
def test_analytics_result_is_frozen(self) -> None:
from app.analytics.models import AnalyticsResult, InterruptStats
result = AnalyticsResult(
range="7d",
total_conversations=0,
resolution_rate=0.0,
escalation_rate=0.0,
avg_turns_per_conversation=0.0,
avg_cost_per_conversation_usd=0.0,
agent_usage=(),
interrupt_stats=InterruptStats(),
)
with pytest.raises((AttributeError, TypeError)):
result.range = "30d" # type: ignore[misc]
def test_analytics_result_empty_agent_usage(self) -> None:
from app.analytics.models import AnalyticsResult, InterruptStats
result = AnalyticsResult(
range="7d",
total_conversations=0,
resolution_rate=0.0,
escalation_rate=0.0,
avg_turns_per_conversation=0.0,
avg_cost_per_conversation_usd=0.0,
agent_usage=(),
interrupt_stats=InterruptStats(),
)
assert result.agent_usage == ()
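# ---------------------------------------------------------------------------
# Illustrative sketch: the model shape these tests imply -- immutable value
# objects with zeroed defaults. With @dataclass(frozen=True), assignment
# raises FrozenInstanceError, an AttributeError subclass, which is why the
# tests accept AttributeError or TypeError. The real module may differ.
# ---------------------------------------------------------------------------
from dataclasses import dataclass as _dataclass
@_dataclass(frozen=True)
class _ExampleInterruptStats:
    """Hypothetical mirror of app.analytics.models.InterruptStats."""
    total: int = 0
    approved: int = 0
    rejected: int = 0
    expired: int = 0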


@@ -0,0 +1,249 @@
"""Unit tests for app.analytics.queries."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
def _make_pool_with_fetchone(result: dict | None) -> MagicMock:
mock_cursor = AsyncMock()
mock_cursor.fetchone = AsyncMock(return_value=result)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
return mock_pool
def _make_pool_with_fetchall(result: list[dict]) -> MagicMock:
mock_cursor = AsyncMock()
mock_cursor.fetchall = AsyncMock(return_value=result)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
return mock_pool
class TestResolutionRate:
@pytest.mark.asyncio
async def test_returns_float(self) -> None:
from app.analytics.queries import resolution_rate
pool = _make_pool_with_fetchone({"rate": 0.85})
result = await resolution_rate(pool, range_days=7)
assert isinstance(result, float)
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import resolution_rate
pool = _make_pool_with_fetchone(None)
result = await resolution_rate(pool, range_days=7)
assert result == 0.0
@pytest.mark.asyncio
async def test_returns_correct_value(self) -> None:
from app.analytics.queries import resolution_rate
pool = _make_pool_with_fetchone({"rate": 0.75})
result = await resolution_rate(pool, range_days=7)
assert result == 0.75
class TestAgentUsageQuery:
@pytest.mark.asyncio
async def test_returns_tuple(self) -> None:
from app.analytics.queries import agent_usage
pool = _make_pool_with_fetchall([])
result = await agent_usage(pool, range_days=7)
assert isinstance(result, tuple)
@pytest.mark.asyncio
async def test_empty_state_returns_empty_tuple(self) -> None:
from app.analytics.queries import agent_usage
pool = _make_pool_with_fetchall([])
result = await agent_usage(pool, range_days=7)
assert result == ()
@pytest.mark.asyncio
async def test_maps_rows_to_agent_usage_objects(self) -> None:
from app.analytics.models import AgentUsage
from app.analytics.queries import agent_usage
pool = _make_pool_with_fetchall([
{"agent": "order_agent", "count": 10, "percentage": 66.7},
{"agent": "discount_agent", "count": 5, "percentage": 33.3},
])
result = await agent_usage(pool, range_days=7)
assert len(result) == 2
assert isinstance(result[0], AgentUsage)
assert result[0].agent == "order_agent"
assert result[0].count == 10
class TestEscalationRate:
@pytest.mark.asyncio
async def test_returns_float(self) -> None:
from app.analytics.queries import escalation_rate
pool = _make_pool_with_fetchone({"rate": 0.05})
result = await escalation_rate(pool, range_days=7)
assert isinstance(result, float)
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import escalation_rate
pool = _make_pool_with_fetchone(None)
result = await escalation_rate(pool, range_days=7)
assert result == 0.0
class TestCostPerConversation:
@pytest.mark.asyncio
async def test_returns_float(self) -> None:
from app.analytics.queries import cost_per_conversation
pool = _make_pool_with_fetchone({"avg_cost": 0.03})
result = await cost_per_conversation(pool, range_days=7)
assert isinstance(result, float)
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import cost_per_conversation
pool = _make_pool_with_fetchone(None)
result = await cost_per_conversation(pool, range_days=7)
assert result == 0.0
class TestInterruptStatsQuery:
@pytest.mark.asyncio
async def test_returns_interrupt_stats(self) -> None:
from app.analytics.models import InterruptStats
from app.analytics.queries import interrupt_stats
pool = _make_pool_with_fetchone(
{"total": 10, "approved": 7, "rejected": 2, "expired": 1}
)
result = await interrupt_stats(pool, range_days=7)
assert isinstance(result, InterruptStats)
assert result.total == 10
assert result.approved == 7
@pytest.mark.asyncio
async def test_zero_state_returns_zeros(self) -> None:
from app.analytics.models import InterruptStats
from app.analytics.queries import interrupt_stats
pool = _make_pool_with_fetchone(None)
result = await interrupt_stats(pool, range_days=7)
assert isinstance(result, InterruptStats)
assert result.total == 0
assert result.approved == 0
assert result.rejected == 0
assert result.expired == 0
class TestTotalConversations:
@pytest.mark.asyncio
async def test_returns_count(self) -> None:
from app.analytics.queries import _total_conversations
pool = _make_pool_with_fetchone({"total": 42})
result = await _total_conversations(pool, range_days=7)
assert result == 42
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import _total_conversations
pool = _make_pool_with_fetchone(None)
result = await _total_conversations(pool, range_days=7)
assert result == 0
class TestAvgTurns:
@pytest.mark.asyncio
async def test_returns_float(self) -> None:
from app.analytics.queries import _avg_turns
pool = _make_pool_with_fetchone({"avg_turns": 3.5})
result = await _avg_turns(pool, range_days=7)
assert result == 3.5
@pytest.mark.asyncio
async def test_zero_state_returns_zero(self) -> None:
from app.analytics.queries import _avg_turns
pool = _make_pool_with_fetchone(None)
result = await _avg_turns(pool, range_days=7)
assert result == 0.0
class TestGetAnalytics:
@pytest.mark.asyncio
async def test_returns_analytics_result(self) -> None:
from unittest.mock import patch
from app.analytics.models import AnalyticsResult, InterruptStats
from app.analytics.queries import get_analytics
mock_pool = MagicMock()
with (
patch("app.analytics.queries.resolution_rate", return_value=0.85),
patch("app.analytics.queries.escalation_rate", return_value=0.05),
patch("app.analytics.queries.cost_per_conversation", return_value=0.03),
patch("app.analytics.queries.agent_usage", return_value=()),
patch(
"app.analytics.queries.interrupt_stats",
return_value=InterruptStats(),
),
patch("app.analytics.queries._total_conversations", return_value=100),
patch("app.analytics.queries._avg_turns", return_value=4.2),
):
result = await get_analytics(mock_pool, range_days=7)
assert isinstance(result, AnalyticsResult)
assert result.range == "7d"
assert result.total_conversations == 100
assert result.resolution_rate == 0.85
@pytest.mark.asyncio
async def test_zero_state_returns_zeros(self) -> None:
from unittest.mock import patch
from app.analytics.models import AnalyticsResult, InterruptStats
from app.analytics.queries import get_analytics
mock_pool = MagicMock()
with (
patch("app.analytics.queries.resolution_rate", return_value=0.0),
patch("app.analytics.queries.escalation_rate", return_value=0.0),
patch("app.analytics.queries.cost_per_conversation", return_value=0.0),
patch("app.analytics.queries.agent_usage", return_value=()),
patch("app.analytics.queries.interrupt_stats", return_value=InterruptStats()),
patch("app.analytics.queries._total_conversations", return_value=0),
patch("app.analytics.queries._avg_turns", return_value=0.0),
):
result = await get_analytics(mock_pool, range_days=7)
assert isinstance(result, AnalyticsResult)
assert result.total_conversations == 0
assert result.resolution_rate == 0.0
assert result.agent_usage == ()
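
For orientation, a minimal sketch of how get_analytics could compose the per-metric helpers these tests patch. This is an assumed shape, not the actual implementation: only the range, total_conversations, resolution_rate, and agent_usage fields are pinned down by the assertions above; the remaining AnalyticsResult field names are guesses, and the real code may run the sub-queries concurrently.

from app.analytics.models import AnalyticsResult

async def get_analytics(pool, range_days: int) -> AnalyticsResult:
    # Each helper owns one SQL query; get_analytics only aggregates them.
    return AnalyticsResult(
        range=f"{range_days}d",
        total_conversations=await _total_conversations(pool, range_days=range_days),
        avg_turns=await _avg_turns(pool, range_days=range_days),  # field name assumed
        resolution_rate=await resolution_rate(pool, range_days=range_days),
        escalation_rate=await escalation_rate(pool, range_days=range_days),  # assumed
        cost_per_conversation=await cost_per_conversation(pool, range_days=range_days),  # assumed
        agent_usage=await agent_usage(pool, range_days=range_days),
        interrupts=await interrupt_stats(pool, range_days=range_days),  # assumed
    )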

@@ -0,0 +1,249 @@
"""Tests for OpenAPI endpoint classifier module.
RED phase: written before implementation.
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.openapi.models import EndpointInfo, ParameterInfo
pytestmark = pytest.mark.unit
def _make_endpoint(
path: str = "/items",
method: str = "GET",
operation_id: str = "list_items",
summary: str = "List items",
description: str = "",
parameters: tuple[ParameterInfo, ...] = (),
) -> EndpointInfo:
return EndpointInfo(
path=path,
method=method,
operation_id=operation_id,
summary=summary,
description=description,
parameters=parameters,
)
_ORDER_PARAM = ParameterInfo(
name="order_id", location="path", required=True, schema_type="string"
)
_CUSTOMER_PARAM = ParameterInfo(
name="customer_id", location="query", required=False, schema_type="string"
)
class TestHeuristicClassifier:
"""Tests for the rule-based HeuristicClassifier."""
async def test_get_classified_as_read(self) -> None:
"""GET endpoints are classified as read access."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="GET")
results = await clf.classify((ep,))
assert len(results) == 1
assert results[0].access_type == "read"
async def test_post_classified_as_write(self) -> None:
"""POST endpoints are classified as write access."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="POST")
results = await clf.classify((ep,))
assert results[0].access_type == "write"
async def test_post_needs_interrupt(self) -> None:
"""POST endpoints require interrupt/approval."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="POST")
results = await clf.classify((ep,))
assert results[0].needs_interrupt is True
async def test_put_classified_as_write(self) -> None:
"""PUT endpoints are classified as write access."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="PUT")
results = await clf.classify((ep,))
assert results[0].access_type == "write"
async def test_delete_classified_as_write_with_interrupt(self) -> None:
"""DELETE endpoints are classified as write and require interrupt."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="DELETE")
results = await clf.classify((ep,))
assert results[0].access_type == "write"
assert results[0].needs_interrupt is True
async def test_get_does_not_need_interrupt(self) -> None:
"""GET endpoints do not require interrupt."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="GET")
results = await clf.classify((ep,))
assert results[0].needs_interrupt is False
async def test_empty_endpoints_returns_empty_tuple(self) -> None:
"""Empty input yields empty output."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
results = await clf.classify(())
assert results == ()
async def test_customer_params_detected_order_id(self) -> None:
"""Parameters named order_id are recognized as customer params."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="GET", parameters=(_ORDER_PARAM,))
results = await clf.classify((ep,))
assert "order_id" in results[0].customer_params
async def test_customer_params_detected_customer_id(self) -> None:
"""Parameters named customer_id are recognized as customer params."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="GET", parameters=(_CUSTOMER_PARAM,))
results = await clf.classify((ep,))
assert "customer_id" in results[0].customer_params
async def test_result_is_tuple(self) -> None:
"""classify returns a tuple (immutable)."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint()
results = await clf.classify((ep,))
assert isinstance(results, tuple)
async def test_classification_has_confidence(self) -> None:
"""Heuristic results have a confidence value between 0 and 1."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint()
results = await clf.classify((ep,))
assert 0.0 <= results[0].confidence <= 1.0
async def test_patch_classified_as_write(self) -> None:
"""PATCH endpoints are classified as write access."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
ep = _make_endpoint(method="PATCH")
results = await clf.classify((ep,))
assert results[0].access_type == "write"
class TestLLMClassifier:
"""Tests for the LLM-backed classifier."""
def _make_mock_llm(self, classifications: list[dict]) -> MagicMock:
"""Create a mock LLM that returns structured classification data."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps(classifications)
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
return mock_llm
async def test_llm_classifier_classifies_endpoints(self) -> None:
"""LLM classifier returns ClassificationResult for each endpoint."""
from app.openapi.classifier import LLMClassifier
ep = _make_endpoint(method="GET")
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = (
'[{"access_type": "read", "agent_group": "support",'
' "needs_interrupt": false, "customer_params": [], "confidence": 0.9}]'
)
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
clf = LLMClassifier(llm=mock_llm)
results = await clf.classify((ep,))
assert len(results) == 1
assert results[0].access_type == "read"
async def test_llm_failure_falls_back_to_heuristic(self) -> None:
"""When LLM raises an exception, falls back to heuristic classifier."""
from app.openapi.classifier import LLMClassifier
ep = _make_endpoint(method="GET")
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
clf = LLMClassifier(llm=mock_llm)
results = await clf.classify((ep,))
# Falls back to heuristic: GET = read
assert len(results) == 1
assert results[0].access_type == "read"
async def test_llm_invalid_json_falls_back_to_heuristic(self) -> None:
"""When LLM returns unparseable output, falls back to heuristic."""
from app.openapi.classifier import LLMClassifier
ep = _make_endpoint(method="DELETE")
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = "this is not valid json at all"
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
clf = LLMClassifier(llm=mock_llm)
results = await clf.classify((ep,))
# Fallback: DELETE = write with interrupt
assert results[0].access_type == "write"
assert results[0].needs_interrupt is True
async def test_llm_empty_endpoints_returns_empty(self) -> None:
"""Empty input yields empty output without calling LLM."""
from app.openapi.classifier import LLMClassifier
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock()
clf = LLMClassifier(llm=mock_llm)
results = await clf.classify(())
assert results == ()
mock_llm.ainvoke.assert_not_called()
class TestClassifierProtocol:
"""Verify both classifiers conform to ClassifierProtocol."""
def test_heuristic_has_classify_method(self) -> None:
"""HeuristicClassifier exposes classify method."""
from app.openapi.classifier import HeuristicClassifier
clf = HeuristicClassifier()
assert hasattr(clf, "classify")
assert callable(clf.classify)
def test_llm_has_classify_method(self) -> None:
"""LLMClassifier exposes classify method."""
from app.openapi.classifier import LLMClassifier
mock_llm = MagicMock()
clf = LLMClassifier(llm=mock_llm)
assert hasattr(clf, "classify")
assert callable(clf.classify)
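
A heuristic classifier consistent with every assertion above can be very small. The sketch below is an assumed shape, not the real app.openapi.classifier: the customer-param name set and the 0.7 confidence are illustrative choices.

from app.openapi.models import ClassificationResult

_WRITE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
_CUSTOMER_PARAM_NAMES = {"order_id", "customer_id"}  # assumed allow-list

class HeuristicClassifier:
    async def classify(self, endpoints):
        results = []
        for ep in endpoints:
            is_write = ep.method.upper() in _WRITE_METHODS
            results.append(ClassificationResult(
                endpoint=ep,
                access_type="write" if is_write else "read",
                customer_params=tuple(
                    p.name for p in ep.parameters if p.name in _CUSTOMER_PARAM_NAMES
                ),
                agent_group="write_agent" if is_write else "read_agent",
                confidence=0.7,  # rules are deterministic but coarse
                needs_interrupt=is_write,  # every mutation requires approval
            ))
        return tuple(results)  # immutable, as test_result_is_tuple expects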

@@ -0,0 +1,120 @@
"""Tests for OpenAPI spec fetcher module.
RED phase: written before implementation.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.openapi.ssrf import SSRFError
pytestmark = pytest.mark.unit
_SAMPLE_JSON = '{"openapi": "3.0.0", "info": {"title": "Test", "version": "1.0"}, "paths": {}}'
_SAMPLE_YAML = "openapi: '3.0.0'\ninfo:\n title: Test\n version: '1.0'\npaths: {}\n"
_PUBLIC_IP = "93.184.216.34"
@pytest.fixture
def _mock_public_dns():
with patch("app.openapi.ssrf.resolve_hostname", return_value=[_PUBLIC_IP]):
yield
class TestFetchSpec:
"""Tests for fetch_spec function."""
@pytest.mark.usefixtures("_mock_public_dns")
async def test_fetch_json_spec_succeeds(self, httpx_mock) -> None:
"""Fetch a JSON spec and return parsed dict."""
from app.openapi.fetcher import fetch_spec
httpx_mock.add_response(
url="https://example.com/spec.json",
text=_SAMPLE_JSON,
headers={"content-type": "application/json"},
)
result = await fetch_spec("https://example.com/spec.json")
assert isinstance(result, dict)
assert result["openapi"] == "3.0.0"
@pytest.mark.usefixtures("_mock_public_dns")
async def test_fetch_yaml_spec_succeeds(self, httpx_mock) -> None:
"""Fetch a YAML spec and return parsed dict."""
from app.openapi.fetcher import fetch_spec
httpx_mock.add_response(
url="https://example.com/spec.yaml",
text=_SAMPLE_YAML,
headers={"content-type": "application/x-yaml"},
)
result = await fetch_spec("https://example.com/spec.yaml")
assert isinstance(result, dict)
assert result["openapi"] == "3.0.0"
@pytest.mark.usefixtures("_mock_public_dns")
async def test_fetch_yaml_by_url_extension(self, httpx_mock) -> None:
"""Auto-detect YAML format from .yaml URL extension."""
from app.openapi.fetcher import fetch_spec
httpx_mock.add_response(
url="https://example.com/api.yaml",
text=_SAMPLE_YAML,
headers={"content-type": "text/plain"},
)
result = await fetch_spec("https://example.com/api.yaml")
assert isinstance(result, dict)
assert result["openapi"] == "3.0.0"
async def test_ssrf_blocked_url_raises(self) -> None:
"""SSRF-blocked URL raises SSRFError."""
from app.openapi.fetcher import fetch_spec
with (
patch("app.openapi.ssrf.resolve_hostname", return_value=["10.0.0.1"]),
pytest.raises(SSRFError),
):
await fetch_spec("http://internal.corp/spec.json")
@pytest.mark.usefixtures("_mock_public_dns")
async def test_oversized_response_raises(self, httpx_mock) -> None:
"""Response exceeding 10MB raises ValueError."""
from app.openapi.fetcher import fetch_spec
big_content = "x" * (10 * 1024 * 1024 + 1)
httpx_mock.add_response(
url="https://example.com/huge.json",
text=big_content,
headers={"content-type": "application/json"},
)
with pytest.raises(ValueError, match="too large"):
await fetch_spec("https://example.com/huge.json")
@pytest.mark.usefixtures("_mock_public_dns")
async def test_invalid_json_raises(self, httpx_mock) -> None:
"""Non-parseable JSON raises ValueError."""
from app.openapi.fetcher import fetch_spec
httpx_mock.add_response(
url="https://example.com/bad.json",
text="not valid json {{{",
headers={"content-type": "application/json"},
)
with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Dd]ecode"):
await fetch_spec("https://example.com/bad.json")
@pytest.mark.usefixtures("_mock_public_dns")
async def test_invalid_yaml_raises(self, httpx_mock) -> None:
"""Non-parseable YAML raises ValueError."""
from app.openapi.fetcher import fetch_spec
httpx_mock.add_response(
url="https://example.com/bad.yaml",
text=": invalid: yaml: {\n",
headers={"content-type": "application/x-yaml"},
)
with pytest.raises(ValueError, match="[Pp]arse|[Ii]nvalid|[Yy]AML"):
await fetch_spec("https://example.com/bad.yaml")

@@ -0,0 +1,258 @@
"""Tests for OpenAPI tool generator module.
RED phase: written before implementation.
"""
from __future__ import annotations
import pytest
from app.openapi.models import ClassificationResult, EndpointInfo, ParameterInfo
pytestmark = pytest.mark.unit
_BASE_URL = "https://api.example.com"
def _make_endpoint(
path: str = "/items",
method: str = "GET",
operation_id: str = "list_items",
summary: str = "List items",
description: str = "Returns all items",
parameters: tuple[ParameterInfo, ...] = (),
request_body_schema: dict | None = None,
) -> EndpointInfo:
return EndpointInfo(
path=path,
method=method,
operation_id=operation_id,
summary=summary,
description=description,
parameters=parameters,
request_body_schema=request_body_schema,
)
def _make_classification(
endpoint: EndpointInfo,
access_type: str = "read",
needs_interrupt: bool = False,
agent_group: str = "read_agent",
) -> ClassificationResult:
return ClassificationResult(
endpoint=endpoint,
access_type=access_type,
customer_params=(),
agent_group=agent_group,
confidence=0.9,
needs_interrupt=needs_interrupt,
)
_PATH_PARAM = ParameterInfo(
name="item_id", location="path", required=True, schema_type="string"
)
_QUERY_PARAM = ParameterInfo(
name="filter", location="query", required=False, schema_type="string"
)
class TestGenerateToolCode:
"""Tests for generate_tool_code function."""
def test_generate_tool_for_get_endpoint(self) -> None:
"""Generated tool for GET endpoint is a GeneratedTool with non-empty code."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(method="GET")
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
assert tool.function_name == "list_items"
assert tool.code != ""
assert "@tool" in tool.code
def test_generate_tool_contains_function_name(self) -> None:
"""Generated code contains the function name."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(operation_id="get_order", method="GET")
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
assert "get_order" in tool.code
def test_generate_tool_contains_base_url(self) -> None:
"""Generated code contains the base URL."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint()
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
assert _BASE_URL in tool.code
def test_generate_tool_contains_http_method(self) -> None:
"""Generated code uses the correct HTTP method."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(method="POST")
clf = _make_classification(ep, access_type="write")
tool = generate_tool_code(clf, _BASE_URL)
assert "post" in tool.code.lower()
def test_generate_tool_for_post_with_body(self) -> None:
"""Generated tool for POST includes body parameter."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(
method="POST",
request_body_schema={"type": "object", "properties": {"name": {"type": "string"}}},
)
clf = _make_classification(ep, access_type="write")
tool = generate_tool_code(clf, _BASE_URL)
assert tool.code != ""
assert "POST" in tool.code or "post" in tool.code
def test_generate_tool_with_path_params(self) -> None:
"""Generated tool includes path parameter in function signature."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(
path="/items/{item_id}",
operation_id="get_item",
parameters=(_PATH_PARAM,),
)
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
assert "item_id" in tool.code
def test_write_tool_includes_interrupt_marker(self) -> None:
"""Write tools that need interrupt include a marker comment."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(method="DELETE", operation_id="delete_item")
clf = _make_classification(ep, access_type="write", needs_interrupt=True)
tool = generate_tool_code(clf, _BASE_URL)
assert "interrupt" in tool.code.lower() or "approval" in tool.code.lower()
def test_generated_code_is_executable(self) -> None:
"""Generated code can be exec'd without syntax errors."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint(
path="/items/{item_id}",
operation_id="fetch_item",
parameters=(_PATH_PARAM,),
)
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
# Must be valid Python syntax
compile(tool.code, "<generated>", "exec")
def test_generated_tool_code_exec_imports(self) -> None:
"""Generated code exec'd with required imports does not raise."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint()
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
namespace: dict = {}
try:
import httpx
from langchain_core.tools import tool as lc_tool
namespace = {"httpx": httpx, "tool": lc_tool}
exec(tool.code, namespace) # noqa: S102
except ImportError:
pytest.skip("langchain_core not available for exec test")
def test_returns_generated_tool_instance(self) -> None:
"""generate_tool_code returns a GeneratedTool instance."""
from app.openapi.generator import generate_tool_code
from app.openapi.models import GeneratedTool
ep = _make_endpoint()
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
assert isinstance(tool, GeneratedTool)
def test_generated_tool_is_frozen(self) -> None:
"""GeneratedTool instance is immutable."""
from app.openapi.generator import generate_tool_code
ep = _make_endpoint()
clf = _make_classification(ep)
tool = generate_tool_code(clf, _BASE_URL)
with pytest.raises((AttributeError, TypeError)):
tool.code = "new code" # type: ignore[misc]
class TestGenerateAgentYaml:
"""Tests for generate_agent_yaml function."""
def test_generate_yaml_is_valid_string(self) -> None:
"""generate_agent_yaml returns a non-empty string."""
from app.openapi.generator import generate_agent_yaml
ep = _make_endpoint()
clf = _make_classification(ep)
result = generate_agent_yaml((clf,), _BASE_URL)
assert isinstance(result, str)
assert len(result) > 0
def test_generated_yaml_is_parseable(self) -> None:
"""Output can be parsed as YAML."""
import yaml
from app.openapi.generator import generate_agent_yaml
ep = _make_endpoint()
clf = _make_classification(ep)
result = generate_agent_yaml((clf,), _BASE_URL)
parsed = yaml.safe_load(result)
assert isinstance(parsed, dict)
def test_generated_yaml_contains_agents_key(self) -> None:
"""Generated YAML has an 'agents' key matching AgentConfig format."""
import yaml
from app.openapi.generator import generate_agent_yaml
ep = _make_endpoint()
clf = _make_classification(ep)
result = generate_agent_yaml((clf,), _BASE_URL)
parsed = yaml.safe_load(result)
assert "agents" in parsed
def test_generated_yaml_contains_tool_name(self) -> None:
"""Generated YAML references the tool function name."""
from app.openapi.generator import generate_agent_yaml
ep = _make_endpoint(operation_id="list_orders")
clf = _make_classification(ep)
result = generate_agent_yaml((clf,), _BASE_URL)
assert "list_orders" in result
def test_empty_classifications_returns_empty_agents(self) -> None:
"""No classifications yields YAML with empty agents list."""
import yaml
from app.openapi.generator import generate_agent_yaml
result = generate_agent_yaml((), _BASE_URL)
parsed = yaml.safe_load(result)
assert parsed.get("agents") == [] or parsed.get("agents") is None

@@ -0,0 +1,290 @@
"""Tests for OpenAPI endpoint parser module.
RED phase: written before implementation.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
_MINIMAL_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {},
}
_GET_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Orders API", "version": "1.0.0"},
"paths": {
"/orders/{order_id}": {
"get": {
"operationId": "get_order",
"summary": "Get an order",
"description": "Retrieves a single order by ID",
"parameters": [
{
"name": "order_id",
"in": "path",
"required": True,
"schema": {"type": "string"},
"description": "The order identifier",
}
],
"responses": {
"200": {
"description": "Order found",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {"id": {"type": "string"}},
}
}
},
}
},
}
}
},
}
_POST_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Orders API", "version": "1.0.0"},
"paths": {
"/orders": {
"post": {
"operationId": "create_order",
"summary": "Create an order",
"description": "Creates a new order",
"requestBody": {
"required": True,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"item": {"type": "string"},
"quantity": {"type": "integer"},
},
}
}
},
},
"responses": {"201": {"description": "Created"}},
}
}
},
}
_MULTI_PARAM_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Items API", "version": "1.0.0"},
"paths": {
"/items/{item_id}": {
"get": {
"operationId": "get_item",
"summary": "Get item",
"description": "",
"parameters": [
{
"name": "item_id",
"in": "path",
"required": True,
"schema": {"type": "integer"},
},
{
"name": "include_details",
"in": "query",
"required": False,
"schema": {"type": "boolean"},
},
],
"responses": {"200": {"description": "OK"}},
}
}
},
}
_REF_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Ref API", "version": "1.0.0"},
"components": {
"schemas": {
"Item": {
"type": "object",
"properties": {"id": {"type": "string"}},
}
}
},
"paths": {
"/items": {
"get": {
"operationId": "list_items",
"summary": "List items",
"description": "",
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/Item"}
}
},
}
},
}
}
},
}
_MULTI_ENDPOINT_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Multi API", "version": "1.0.0"},
"paths": {
"/users": {
"get": {
"operationId": "list_users",
"summary": "List users",
"description": "",
"responses": {"200": {"description": "OK"}},
},
"post": {
"operationId": "create_user",
"summary": "Create user",
"description": "",
"responses": {"201": {"description": "Created"}},
},
},
"/users/{id}": {
"delete": {
"operationId": "delete_user",
"summary": "Delete user",
"description": "",
"parameters": [
{"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}
],
"responses": {"204": {"description": "Deleted"}},
}
},
},
}
class TestParseEndpoints:
"""Tests for parse_endpoints function."""
def test_empty_paths_returns_empty_tuple(self) -> None:
"""Spec with no paths yields no endpoints."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_MINIMAL_SPEC)
assert result == ()
def test_parse_get_endpoint(self) -> None:
"""Parse a GET endpoint with path parameter."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_GET_SPEC)
assert len(result) == 1
ep = result[0]
assert ep.path == "/orders/{order_id}"
assert ep.method == "GET"
assert ep.operation_id == "get_order"
assert ep.summary == "Get an order"
def test_parse_get_endpoint_parameters(self) -> None:
"""Path parameters are extracted correctly."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_GET_SPEC)
ep = result[0]
assert len(ep.parameters) == 1
param = ep.parameters[0]
assert param.name == "order_id"
assert param.location == "path"
assert param.required is True
assert param.schema_type == "string"
def test_parse_post_with_request_body(self) -> None:
"""POST endpoint with request body is extracted."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_POST_SPEC)
assert len(result) == 1
ep = result[0]
assert ep.method == "POST"
assert ep.request_body_schema is not None
assert ep.request_body_schema["type"] == "object"
def test_parse_path_and_query_params(self) -> None:
"""Both path and query parameters are extracted."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_MULTI_PARAM_SPEC)
ep = result[0]
locations = {p.location for p in ep.parameters}
assert "path" in locations
assert "query" in locations
def test_autogenerate_operation_id_when_missing(self) -> None:
"""Auto-generate operation_id when not provided in spec."""
from app.openapi.parser import parse_endpoints
spec = {
"openapi": "3.0.0",
"info": {"title": "Test", "version": "1.0"},
"paths": {
"/things/{id}": {
"get": {
"summary": "Get thing",
"description": "",
"responses": {"200": {"description": "OK"}},
}
}
},
}
result = parse_endpoints(spec)
ep = result[0]
assert ep.operation_id != ""
assert len(ep.operation_id) > 0
def test_multiple_endpoints_extracted(self) -> None:
"""Multiple path+method combinations are all extracted."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_MULTI_ENDPOINT_SPEC)
assert len(result) == 3
methods = {ep.method for ep in result}
assert "GET" in methods
assert "POST" in methods
assert "DELETE" in methods
def test_ref_in_response_schema_resolved(self) -> None:
"""$ref in response schema is resolved to the target schema."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_REF_SPEC)
ep = result[0]
assert ep.response_schema is not None
# Resolved ref should contain the properties
assert "properties" in ep.response_schema or "$ref" not in ep.response_schema
def test_result_is_tuple(self) -> None:
"""parse_endpoints returns a tuple (immutable)."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_GET_SPEC)
assert isinstance(result, tuple)
def test_endpoint_info_is_frozen(self) -> None:
"""EndpointInfo instances are frozen/immutable."""
from app.openapi.parser import parse_endpoints
result = parse_endpoints(_GET_SPEC)
ep = result[0]
with pytest.raises((AttributeError, TypeError)):
ep.method = "POST" # type: ignore[misc]

@@ -0,0 +1,219 @@
"""Tests for OpenAPI review API endpoints.
RED phase: written before implementation.
"""
from __future__ import annotations
import pytest
from fastapi.testclient import TestClient
pytestmark = pytest.mark.unit
_SAMPLE_URL = "https://example.com/api/spec.json"
@pytest.fixture
def client():
"""Create TestClient for the review API app."""
from fastapi import FastAPI
from app.openapi.review_api import router
app = FastAPI()
app.include_router(router)
return TestClient(app)
@pytest.fixture
def job_id(client):
"""Create a job and return its ID."""
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
return response.json()["job_id"]
@pytest.fixture
def job_with_classifications(client, job_id):
"""Return job_id for a job that has mock classifications injected."""
from app.openapi.models import ClassificationResult, EndpointInfo
from app.openapi.review_api import _job_store
ep = EndpointInfo(
path="/orders",
method="GET",
operation_id="list_orders",
summary="List orders",
description="",
)
clf = ClassificationResult(
endpoint=ep,
access_type="read",
customer_params=(),
agent_group="read_agent",
confidence=0.9,
needs_interrupt=False,
)
# Inject classifications directly into the store
job = _job_store[job_id]
_job_store[job_id] = {**job, "classifications": [clf]}
return job_id
class TestImportEndpoint:
"""Tests for POST /api/v1/openapi/import."""
def test_post_import_returns_job_id(self, client) -> None:
"""POST /import returns 202 with a job_id."""
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
assert response.status_code == 202
data = response.json()
assert "job_id" in data
assert len(data["job_id"]) > 0
def test_post_import_empty_url_returns_422(self, client) -> None:
"""POST /import with empty URL returns 422 validation error."""
response = client.post("/api/v1/openapi/import", json={"url": ""})
assert response.status_code == 422
def test_post_import_missing_url_returns_422(self, client) -> None:
"""POST /import with missing URL field returns 422."""
response = client.post("/api/v1/openapi/import", json={})
assert response.status_code == 422
def test_post_import_invalid_scheme_returns_422(self, client) -> None:
"""POST /import with non-http URL returns 422."""
response = client.post("/api/v1/openapi/import", json={"url": "ftp://evil.com/spec"})
assert response.status_code == 422
def test_post_import_returns_pending_status(self, client) -> None:
"""Newly created job has pending status."""
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["status"] == "pending"
def test_post_import_returns_spec_url(self, client) -> None:
"""Response includes the original spec URL."""
response = client.post("/api/v1/openapi/import", json={"url": _SAMPLE_URL})
data = response.json()
assert data["spec_url"] == _SAMPLE_URL
class TestGetJobEndpoint:
"""Tests for GET /api/v1/openapi/jobs/{job_id}."""
def test_get_job_returns_status(self, client, job_id) -> None:
"""GET /jobs/{id} returns job status."""
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "job_id" in data
def test_get_unknown_job_returns_404(self, client) -> None:
"""GET /jobs/nonexistent returns 404."""
response = client.get("/api/v1/openapi/jobs/nonexistent-id")
assert response.status_code == 404
def test_get_job_includes_spec_url(self, client, job_id) -> None:
"""Job response includes the spec URL."""
response = client.get(f"/api/v1/openapi/jobs/{job_id}")
data = response.json()
assert data["spec_url"] == _SAMPLE_URL
class TestGetClassificationsEndpoint:
"""Tests for GET /api/v1/openapi/jobs/{job_id}/classifications."""
def test_get_classifications_returns_list(self, client, job_with_classifications) -> None:
"""GET /classifications returns a list."""
response = client.get(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 1
def test_get_classifications_unknown_job_returns_404(self, client) -> None:
"""GET /classifications for unknown job returns 404."""
response = client.get("/api/v1/openapi/jobs/unknown/classifications")
assert response.status_code == 404
def test_classification_has_expected_fields(self, client, job_with_classifications) -> None:
"""Each classification item has access_type and endpoint fields."""
response = client.get(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications"
)
item = response.json()[0]
assert "access_type" in item
assert "endpoint" in item
assert "needs_interrupt" in item
class TestUpdateClassificationEndpoint:
"""Tests for PUT /api/v1/openapi/jobs/{job_id}/classifications/{idx}."""
def test_update_classification_succeeds(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 updates the classification."""
response = client.put(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
)
assert response.status_code == 200
def test_update_unknown_job_returns_404(self, client) -> None:
"""PUT /classifications/0 for unknown job returns 404."""
response = client.put(
"/api/v1/openapi/jobs/unknown/classifications/0",
json={"access_type": "write", "needs_interrupt": True, "agent_group": "write_agent"},
)
assert response.status_code == 404
def test_update_invalid_access_type_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid access_type returns 422."""
response = client.put(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "admin", "needs_interrupt": True, "agent_group": "x"},
)
assert response.status_code == 422
def test_update_invalid_agent_group_returns_422(self, client, job_with_classifications) -> None:
"""PUT /classifications/0 with invalid agent_group returns 422."""
response = client.put(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/0",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "evil group!"},
)
assert response.status_code == 422
def test_update_out_of_range_index_returns_404(self, client, job_with_classifications) -> None:
"""PUT /classifications/999 returns 404 for out-of-range index."""
response = client.put(
f"/api/v1/openapi/jobs/{job_with_classifications}/classifications/999",
json={"access_type": "read", "needs_interrupt": False, "agent_group": "read_agent"},
)
assert response.status_code == 404
class TestApproveEndpoint:
"""Tests for POST /api/v1/openapi/jobs/{job_id}/approve."""
def test_approve_job_succeeds(self, client, job_with_classifications) -> None:
"""POST /approve transitions job to approved status."""
response = client.post(
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
)
assert response.status_code == 200
def test_approve_unknown_job_returns_404(self, client) -> None:
"""POST /approve for unknown job returns 404."""
response = client.post("/api/v1/openapi/jobs/unknown/approve")
assert response.status_code == 404
def test_approve_returns_job_status(self, client, job_with_classifications) -> None:
"""POST /approve returns updated job status."""
response = client.post(
f"/api/v1/openapi/jobs/{job_with_classifications}/approve"
)
data = response.json()
assert "status" in data

@@ -0,0 +1,93 @@
"""Tests for OpenAPI spec validator module.
RED phase: written before implementation.
"""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
_VALID_SPEC = {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
"paths": {
"/items": {
"get": {
"summary": "List items",
"responses": {"200": {"description": "OK"}},
}
}
},
}
class TestValidateSpec:
"""Tests for validate_spec function."""
def test_valid_minimal_spec_passes(self) -> None:
"""A valid minimal spec returns empty error list."""
from app.openapi.validator import validate_spec
errors = validate_spec(_VALID_SPEC)
assert errors == []
def test_missing_openapi_key_returns_error(self) -> None:
"""Missing 'openapi' field returns an error."""
from app.openapi.validator import validate_spec
spec = {k: v for k, v in _VALID_SPEC.items() if k != "openapi"}
errors = validate_spec(spec)
assert len(errors) > 0
assert any("openapi" in e.lower() for e in errors)
def test_missing_info_returns_error(self) -> None:
"""Missing 'info' field returns an error."""
from app.openapi.validator import validate_spec
spec = {k: v for k, v in _VALID_SPEC.items() if k != "info"}
errors = validate_spec(spec)
assert len(errors) > 0
assert any("info" in e.lower() for e in errors)
def test_missing_paths_returns_error(self) -> None:
"""Missing 'paths' field returns an error."""
from app.openapi.validator import validate_spec
spec = {k: v for k, v in _VALID_SPEC.items() if k != "paths"}
errors = validate_spec(spec)
assert len(errors) > 0
assert any("paths" in e.lower() for e in errors)
def test_non_dict_input_returns_error(self) -> None:
"""Non-dict input returns an error without raising."""
from app.openapi.validator import validate_spec
errors = validate_spec("not a dict") # type: ignore[arg-type]
assert len(errors) > 0
def test_empty_dict_returns_multiple_errors(self) -> None:
"""Empty dict returns errors for all required fields."""
from app.openapi.validator import validate_spec
errors = validate_spec({})
# Should have at least one error for each required field
assert len(errors) >= 3
def test_invalid_openapi_version_returns_error(self) -> None:
"""Unsupported openapi version string returns an error."""
from app.openapi.validator import validate_spec
spec = {**_VALID_SPEC, "openapi": "1.0.0"}
errors = validate_spec(spec)
assert len(errors) > 0
def test_errors_are_descriptive_strings(self) -> None:
"""All returned errors are non-empty strings."""
from app.openapi.validator import validate_spec
errors = validate_spec({})
for e in errors:
assert isinstance(e, str)
assert len(e) > 0
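
Up to wording, validate_spec is fully pinned down by these tests; a minimal sketch that satisfies all of them:

def validate_spec(spec: object) -> list[str]:
    if not isinstance(spec, dict):
        return ["spec must be a mapping (parsed JSON/YAML object)"]
    errors: list[str] = []
    for field in ("openapi", "info", "paths"):
        if field not in spec:
            errors.append(f"missing required field: {field}")
    if "openapi" in spec and not str(spec["openapi"]).startswith("3."):
        errors.append(f"unsupported openapi version: {spec['openapi']}")
    return errors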

@@ -0,0 +1 @@
"""Unit tests for app.replay module."""

@@ -0,0 +1,217 @@
"""Unit tests for app.replay.api."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from app.api_utils import envelope
pytestmark = pytest.mark.unit
def _build_app() -> FastAPI:
from app.replay.api import router
app = FastAPI()
app.include_router(router)
@app.exception_handler(HTTPException)
async def _http_exc(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
return app
def _make_mock_pool(
fetchall_result: list[dict],
*,
count: int | None = None,
) -> MagicMock:
"""Build a mock pool that returns the given rows from fetchall.
When *count* is provided, the first execute() call returns a cursor
whose fetchone() yields ``(count,)`` (for the COUNT query) and the
second call returns the rows via fetchall(). When *count* is None
(the default), a single cursor backed by *fetchall_result* is used
for all calls.
"""
if count is not None:
count_cursor = AsyncMock()
count_cursor.fetchone = AsyncMock(return_value=(count,))
rows_cursor = AsyncMock()
rows_cursor.fetchall = AsyncMock(return_value=fetchall_result)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(side_effect=[count_cursor, rows_cursor])
else:
mock_cursor = AsyncMock()
mock_cursor.fetchall = AsyncMock(return_value=fetchall_result)
mock_cursor.fetchone = AsyncMock(return_value=None)
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(return_value=mock_cursor)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
return mock_pool
class TestListConversations:
def test_returns_200_with_empty_list(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
data = body["data"]
assert isinstance(data["conversations"], list)
assert data["total"] == 0
assert data["page"] == 1
assert body["error"] is None
def test_returns_conversations_list(self) -> None:
app = _build_app()
mock_rows = [
{
"thread_id": "t1",
"created_at": "2026-01-01T00:00:00",
"last_activity": "2026-01-01T00:01:00",
"status": "active",
"total_tokens": 100,
"total_cost_usd": 0.01,
}
]
app.state.pool = _make_mock_pool(mock_rows, count=1)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations")
body = resp.json()
assert resp.status_code == 200
data = body["data"]
assert len(data["conversations"]) == 1
assert data["conversations"][0]["thread_id"] == "t1"
assert data["total"] == 1
def test_pagination_defaults(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations")
assert resp.status_code == 200
def test_pagination_custom_params(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations?page=2&per_page=10")
assert resp.status_code == 200
def test_per_page_max_capped_at_100(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([], count=0)
with TestClient(app) as client:
resp = client.get("/api/v1/conversations?per_page=200")
# FastAPI Query(le=100) rejects values > 100
assert resp.status_code == 422
class TestGetReplay:
def test_thread_not_found_returns_404(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/v1/replay/nonexistent-thread")
assert resp.status_code == 404
def test_returns_replay_page_for_existing_thread(self) -> None:
app = _build_app()
mock_rows = [
{
"thread_id": "thread-123",
"checkpoint_id": "cp-001",
"checkpoint": {
"channel_values": {
"messages": [{"type": "human", "content": "Hello"}]
}
},
"metadata": {},
}
]
app.state.pool = _make_mock_pool(mock_rows)
with TestClient(app) as client:
resp = client.get("/api/v1/replay/thread-123")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["thread_id"] == "thread-123"
assert "steps" in body["data"]
assert "total_steps" in body["data"]
assert "page" in body["data"]
assert "per_page" in body["data"]
def test_replay_pagination_params(self) -> None:
app = _build_app()
mock_rows = [
{
"thread_id": "t1",
"checkpoint_id": "cp-001",
"checkpoint": {
"channel_values": {"messages": [{"type": "human", "content": "Hi"}]}
},
"metadata": {},
}
]
app.state.pool = _make_mock_pool(mock_rows)
with TestClient(app) as client:
resp = client.get("/api/v1/replay/t1?page=1&per_page=5")
assert resp.status_code == 200
def test_error_response_has_envelope(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/v1/replay/missing")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] is not None
def test_invalid_thread_id_returns_400(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/v1/replay/id%20with%20spaces")
assert resp.status_code == 400
def test_thread_id_special_chars_returns_400(self) -> None:
app = _build_app()
app.state.pool = _make_mock_pool([])
with TestClient(app) as client:
resp = client.get("/api/v1/replay/id;DROP TABLE")
assert resp.status_code == 400
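
The two 400 tests imply the replay route validates thread_id against a conservative allow-list before touching the database; a sketch of that guard (the exact pattern is an assumption):

import re

from fastapi import HTTPException

_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_-]+$")  # assumed allow-list

def _validate_thread_id(thread_id: str) -> None:
    # "id with spaces" and "id;DROP TABLE" both fail this check.
    if not _THREAD_ID_RE.match(thread_id):
        raise HTTPException(status_code=400, detail="invalid thread_id")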

@@ -0,0 +1,134 @@
"""Unit tests for app.replay.models."""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
class TestStepType:
def test_all_step_types_exist(self) -> None:
from app.replay.models import StepType
assert StepType.user_message
assert StepType.supervisor_routing
assert StepType.tool_call
assert StepType.tool_result
assert StepType.agent_response
assert StepType.interrupt
def test_step_type_values(self) -> None:
from app.replay.models import StepType
assert StepType.user_message.value == "user_message"
assert StepType.tool_call.value == "tool_call"
assert StepType.agent_response.value == "agent_response"
class TestReplayStep:
def test_minimal_replay_step(self) -> None:
from app.replay.models import ReplayStep, StepType
step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
assert step.step == 1
assert step.type == StepType.user_message
assert step.timestamp == "2026-01-01T00:00:00Z"
assert step.content == ""
assert step.agent is None
assert step.tool is None
assert step.params is None
assert step.result is None
assert step.reasoning is None
assert step.tokens is None
assert step.duration_ms is None
def test_full_replay_step(self) -> None:
from app.replay.models import ReplayStep, StepType
step = ReplayStep(
step=2,
type=StepType.tool_call,
timestamp="2026-01-01T00:00:01Z",
content="calling get_order",
agent="order_agent",
tool="get_order_status",
params={"order_id": "ORD-123"},
result={"status": "shipped"},
reasoning="user asked about order",
tokens=50,
duration_ms=200,
)
assert step.step == 2
assert step.agent == "order_agent"
assert step.tool == "get_order_status"
assert step.params == {"order_id": "ORD-123"}
assert step.tokens == 50
def test_replay_step_is_frozen(self) -> None:
from app.replay.models import ReplayStep, StepType
step = ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z")
with pytest.raises((AttributeError, TypeError)):
step.step = 99 # type: ignore[misc]
def test_replay_step_params_is_immutable_copy(self) -> None:
from app.replay.models import ReplayStep, StepType
params = {"key": "value"}
step = ReplayStep(
step=1,
type=StepType.tool_call,
timestamp="2026-01-01T00:00:00Z",
params=params,
)
# Modifying original dict should not affect step
params["new_key"] = "new_value"
assert "new_key" not in (step.params or {})
class TestReplayPage:
def test_replay_page_construction(self) -> None:
from app.replay.models import ReplayPage, ReplayStep, StepType
steps = (
ReplayStep(step=1, type=StepType.user_message, timestamp="2026-01-01T00:00:00Z"),
ReplayStep(step=2, type=StepType.agent_response, timestamp="2026-01-01T00:00:01Z"),
)
page = ReplayPage(
thread_id="thread-123",
total_steps=2,
page=1,
per_page=20,
steps=steps,
)
assert page.thread_id == "thread-123"
assert page.total_steps == 2
assert page.page == 1
assert page.per_page == 20
assert len(page.steps) == 2
def test_replay_page_is_frozen(self) -> None:
from app.replay.models import ReplayPage
page = ReplayPage(
thread_id="t1",
total_steps=0,
page=1,
per_page=20,
steps=(),
)
with pytest.raises((AttributeError, TypeError)):
page.page = 2 # type: ignore[misc]
def test_replay_page_empty_steps(self) -> None:
from app.replay.models import ReplayPage
page = ReplayPage(
thread_id="t1",
total_steps=0,
page=1,
per_page=20,
steps=(),
)
assert page.steps == ()
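
The behavior asserted above (assignment raising an AttributeError or TypeError, and params surviving mutation of the caller's dict) is consistent with frozen dataclasses that defensively copy mutable arguments, since FrozenInstanceError derives from AttributeError. That is an inference about app.replay.models, not a quote of it.

from dataclasses import dataclass
from enum import Enum

class StepType(str, Enum):
    user_message = "user_message"
    supervisor_routing = "supervisor_routing"
    tool_call = "tool_call"
    tool_result = "tool_result"
    agent_response = "agent_response"
    interrupt = "interrupt"

@dataclass(frozen=True)
class ReplayStep:
    step: int
    type: StepType
    timestamp: str
    content: str = ""
    agent: str | None = None
    tool: str | None = None
    params: dict | None = None
    result: dict | None = None
    reasoning: str | None = None
    tokens: int | None = None
    duration_ms: int | None = None

    def __post_init__(self) -> None:
        # Defensive copy so mutating the caller's dict cannot leak in;
        # frozen=True blocks normal assignment, hence object.__setattr__.
        if self.params is not None:
            object.__setattr__(self, "params", dict(self.params))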

@@ -0,0 +1,257 @@
"""Unit tests for app.replay.transformer."""
from __future__ import annotations
import pytest
pytestmark = pytest.mark.unit
def _make_row(messages: list[dict], metadata: dict | None = None) -> dict:
"""Helper to build a checkpoint row with the given messages."""
return {
"thread_id": "thread-abc",
"checkpoint_id": "cp-001",
"checkpoint": {"channel_values": {"messages": messages}},
"metadata": metadata or {},
}
class TestTransformCheckpoints:
def test_empty_rows_returns_empty_list(self) -> None:
from app.replay.transformer import transform_checkpoints
result = transform_checkpoints([])
assert result == []
def test_human_message_produces_user_message_step(self) -> None:
from app.replay.models import StepType
from app.replay.transformer import transform_checkpoints
rows = [_make_row([{"type": "human", "content": "Hello, I need help"}])]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].type == StepType.user_message
assert steps[0].content == "Hello, I need help"
assert steps[0].step == 1
def test_ai_message_with_content_produces_agent_response(self) -> None:
from app.replay.models import StepType
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[{"type": "ai", "content": "I can help you with that.", "tool_calls": []}],
metadata={"writes": {"some_agent": "response"}},
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].type == StepType.agent_response
assert steps[0].content == "I can help you with that."
def test_ai_message_with_tool_calls_produces_tool_call_step(self) -> None:
from app.replay.models import StepType
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{
"type": "ai",
"content": "",
"tool_calls": [
{
"name": "get_order_status",
"args": {"order_id": "ORD-123"},
"id": "call_abc",
}
],
}
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].type == StepType.tool_call
assert steps[0].tool == "get_order_status"
assert steps[0].params == {"order_id": "ORD-123"}
def test_tool_message_produces_tool_result_step(self) -> None:
from app.replay.models import StepType
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{
"type": "tool",
"content": '{"status": "shipped"}',
"name": "get_order_status",
}
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].type == StepType.tool_result
assert steps[0].tool == "get_order_status"
def test_multiple_messages_sequential_steps(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{"type": "human", "content": "Help"},
{"type": "ai", "content": "Sure!", "tool_calls": []},
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 2
assert steps[0].step == 1
assert steps[1].step == 2
def test_unknown_message_type_skipped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [_make_row([{"type": "unknown_type", "content": "test"}])]
steps = transform_checkpoints(rows)
# Should not crash; unknown types may be skipped
assert isinstance(steps, list)
def test_row_missing_checkpoint_skipped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": None, "metadata": {}}]
steps = transform_checkpoints(rows)
assert isinstance(steps, list)
def test_row_missing_messages_key_skipped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [{"thread_id": "t1", "checkpoint_id": "cp1", "checkpoint": {}, "metadata": {}}]
steps = transform_checkpoints(rows)
assert isinstance(steps, list)
def test_multiple_rows_steps_are_continuous(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row([{"type": "human", "content": "Q1"}]),
_make_row([{"type": "ai", "content": "A1", "tool_calls": []}]),
]
steps = transform_checkpoints(rows)
assert len(steps) == 2
assert steps[0].step == 1
assert steps[1].step == 2
def test_timestamps_are_strings(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [_make_row([{"type": "human", "content": "Hi"}])]
steps = transform_checkpoints(rows)
assert isinstance(steps[0].timestamp, str)
def test_list_content_joined_to_string(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{
"type": "human",
"content": [
{"text": "Hello"},
{"text": " world"},
],
}
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].content == "Hello world"
def test_checkpoint_as_string_skipped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
{
"thread_id": "t1",
"checkpoint_id": "cp1",
"checkpoint": "not-a-dict",
"metadata": {},
}
]
steps = transform_checkpoints(rows)
assert steps == []
def test_channel_values_not_dict_skipped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
{
"thread_id": "t1",
"checkpoint_id": "cp1",
"checkpoint": {"channel_values": "bad"},
"metadata": {},
}
]
steps = transform_checkpoints(rows)
assert steps == []
def test_tool_result_valid_json_parsed(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{
"type": "tool",
"content": '{"order_id": "123", "status": "shipped"}',
"name": "get_order_status",
}
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].result == {"order_id": "123", "status": "shipped"}
def test_tool_result_invalid_json_wrapped(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{
"type": "tool",
"content": "not valid json",
"name": "some_tool",
}
]
)
]
steps = transform_checkpoints(rows)
assert len(steps) == 1
assert steps[0].result == {"raw": "not valid json"}
def test_malformed_message_skipped_gracefully(self) -> None:
from app.replay.transformer import transform_checkpoints
rows = [
_make_row(
[
{"type": "human", "content": "Good message"},
42, # not a dict -- will raise in _step_from_message
{"type": "ai", "content": "Response", "tool_calls": []},
]
)
]
steps = transform_checkpoints(rows)
# The malformed message is skipped; the other two produce steps.
assert len(steps) == 2
assert steps[0].step == 1
assert steps[1].step == 2
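
Per the comment in the last test, the transformer routes each raw message through a _step_from_message helper and the outer loop skips messages that raise. A sketch of that dispatch, with timestamps and checkpoint metadata handling elided (the real module fills those from the rows):

import json

from app.replay.models import ReplayStep, StepType

def _text(content: object) -> str:
    if isinstance(content, list):  # content blocks are joined into one string
        return "".join(p.get("text", "") for p in content if isinstance(p, dict))
    return str(content)

def _step_from_message(msg: dict, step_no: int, ts: str) -> ReplayStep | None:
    mtype = msg["type"]  # raises for non-dict input; the caller skips those
    if mtype == "human":
        return ReplayStep(step=step_no, type=StepType.user_message,
                          timestamp=ts, content=_text(msg.get("content", "")))
    if mtype == "ai" and msg.get("tool_calls"):
        call = msg["tool_calls"][0]
        return ReplayStep(step=step_no, type=StepType.tool_call, timestamp=ts,
                          tool=call["name"], params=call.get("args"))
    if mtype == "ai":
        return ReplayStep(step=step_no, type=StepType.agent_response,
                          timestamp=ts, content=_text(msg.get("content", "")))
    if mtype == "tool":
        try:
            result = json.loads(msg.get("content", ""))
        except json.JSONDecodeError:
            result = {"raw": msg.get("content", "")}  # wrap non-JSON tool output
        return ReplayStep(step=step_no, type=StepType.tool_result,
                          timestamp=ts, tool=msg.get("name"), result=result)
    return None  # unknown message types are skipped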

@@ -7,10 +7,41 @@ import pytest
from app.config import Settings
def _isolated_settings(**kwargs: object) -> Settings:
"""Create a Settings instance that ignores .env files and process env vars.
pydantic-settings reads from env_file and environment by default, which
causes test results to depend on the machine they run on. We override
model_config at the class level temporarily so that every test gets
deterministic results.
"""
# Build a throwaway subclass that disables env-file and env-var loading.
class _IsolatedSettings(Settings):
model_config = Settings.model_config.copy()
model_config["env_file"] = None # type: ignore[assignment]
model_config["env_ignore_empty"] = True
# Required fields must see their env vars as truly absent (not merely
# empty strings) so that validation raises; strip them from the process
# environment below.
import os
env_backup = os.environ.copy()
# Strip all env vars that Settings knows about so they can't leak in.
settings_fields = set(Settings.model_fields)
for key in list(os.environ):
if key.lower() in settings_fields:
del os.environ[key]
try:
return _IsolatedSettings(**kwargs) # type: ignore[return-value]
finally:
os.environ.clear()
os.environ.update(env_backup)
@pytest.mark.unit
class TestSettings:
def test_default_values(self) -> None:
-settings = Settings(
+settings = _isolated_settings(
database_url="postgresql://x:x@localhost/db",
anthropic_api_key="key",
)
@@ -20,7 +51,7 @@ class TestSettings:
assert settings.interrupt_ttl_minutes == 30
def test_custom_values(self) -> None:
-settings = Settings(
+settings = _isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="openai",
llm_model="gpt-4o",
@@ -33,18 +64,18 @@ class TestSettings:
def test_invalid_provider_rejected(self) -> None:
with pytest.raises(Exception):
-Settings(
+_isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="invalid",
)
def test_missing_database_url_rejected(self) -> None:
with pytest.raises(Exception):
Settings(anthropic_api_key="key")
_isolated_settings(anthropic_api_key="key")
def test_empty_api_key_for_provider_rejected(self) -> None:
with pytest.raises(ValueError, match="API key"):
-Settings(
+_isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="anthropic",
anthropic_api_key="",
@@ -52,9 +83,27 @@ class TestSettings:
def test_wrong_provider_key_rejected(self) -> None:
with pytest.raises(ValueError, match="API key"):
-Settings(
+_isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="openai",
anthropic_api_key="key",
openai_api_key="",
)
def test_azure_openai_missing_endpoint_rejected(self) -> None:
with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
_isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="azure_openai",
azure_openai_api_key="key",
azure_openai_deployment="my-deploy",
)
def test_azure_openai_missing_deployment_rejected(self) -> None:
with pytest.raises(ValueError, match="AZURE_OPENAI_DEPLOYMENT"):
_isolated_settings(
database_url="postgresql://x:x@localhost/db",
llm_provider="azure_openai",
azure_openai_api_key="key",
azure_openai_endpoint="https://example.openai.azure.com",
)

@@ -0,0 +1,156 @@
"""Tests for app.conversation_tracker module."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.conversation_tracker import (
ConversationTrackerProtocol,
NoOpConversationTracker,
PostgresConversationTracker,
)
pytestmark = pytest.mark.unit
def _make_pool() -> tuple[AsyncMock, AsyncMock]:
"""Create a mock async connection pool and its connection."""
pool = AsyncMock()
conn = AsyncMock()
conn.execute = AsyncMock()
pool.connection = MagicMock(return_value=_AsyncContextManager(conn))
return pool, conn
class _AsyncContextManager:
"""Async context manager helper."""
def __init__(self, value: object) -> None:
self._value = value
async def __aenter__(self) -> object:
return self._value
async def __aexit__(self, *args: object) -> None:
pass
class TestConversationTrackerProtocol:
def test_noop_satisfies_protocol(self) -> None:
tracker = NoOpConversationTracker()
assert isinstance(tracker, ConversationTrackerProtocol)
def test_postgres_satisfies_protocol(self) -> None:
tracker = PostgresConversationTracker()
assert isinstance(tracker, ConversationTrackerProtocol)
class TestNoOpConversationTracker:
@pytest.mark.asyncio
async def test_ensure_conversation_does_nothing(self) -> None:
tracker = NoOpConversationTracker()
pool = AsyncMock()
# Should not raise
await tracker.ensure_conversation(pool, "thread-1")
@pytest.mark.asyncio
async def test_record_turn_does_nothing(self) -> None:
tracker = NoOpConversationTracker()
pool = AsyncMock()
await tracker.record_turn(pool, "thread-1", "agent_a", 100, 0.05)
@pytest.mark.asyncio
async def test_resolve_does_nothing(self) -> None:
tracker = NoOpConversationTracker()
pool = AsyncMock()
await tracker.resolve(pool, "thread-1", "resolved")
@pytest.mark.asyncio
async def test_accepts_none_agent_name(self) -> None:
tracker = NoOpConversationTracker()
pool = AsyncMock()
await tracker.record_turn(pool, "thread-1", None, 0, 0.0)
class TestPostgresConversationTracker:
@pytest.mark.asyncio
async def test_ensure_conversation_executes_insert(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.ensure_conversation(pool, "thread-abc")
conn.execute.assert_awaited_once()
sql, params = conn.execute.call_args[0]
assert "INSERT" in sql
assert "ON CONFLICT" in sql
assert params["thread_id"] == "thread-abc"
@pytest.mark.asyncio
async def test_record_turn_executes_update(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.record_turn(pool, "thread-abc", "order_agent", 250, 0.12)
conn.execute.assert_awaited_once()
sql, params = conn.execute.call_args[0]
assert "UPDATE" in sql
assert params["thread_id"] == "thread-abc"
assert params["agent_name"] == "order_agent"
assert params["tokens"] == 250
assert params["cost"] == 0.12
@pytest.mark.asyncio
async def test_record_turn_accepts_none_agent_name(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.record_turn(pool, "thread-abc", None, 0, 0.0)
conn.execute.assert_awaited_once()
sql, params = conn.execute.call_args[0]
assert params["agent_name"] is None
@pytest.mark.asyncio
async def test_resolve_executes_update(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.resolve(pool, "thread-abc", "resolved")
conn.execute.assert_awaited_once()
sql, params = conn.execute.call_args[0]
assert "UPDATE" in sql
assert params["thread_id"] == "thread-abc"
assert params["resolution_type"] == "resolved"
@pytest.mark.asyncio
async def test_resolve_sets_ended_at(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.resolve(pool, "thread-abc", "escalated")
sql, params = conn.execute.call_args[0]
assert "ended_at" in sql.lower()
@pytest.mark.asyncio
async def test_ensure_conversation_with_special_thread_id(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.ensure_conversation(pool, "thread-123-abc-XYZ")
conn.execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_record_turn_with_zero_cost(self) -> None:
tracker = PostgresConversationTracker()
pool, conn = _make_pool()
await tracker.record_turn(pool, "t1", "agent", 0, 0.0)
conn.execute.assert_awaited_once()
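The two isinstance checks in TestConversationTrackerProtocol only pass if the protocol is declared `@runtime_checkable`; a minimal sketch of the interface, with method names and signatures taken from the calls in these tests:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ConversationTrackerProtocol(Protocol):
    """Structural interface both tracker implementations satisfy."""

    async def ensure_conversation(self, pool: object, thread_id: str) -> None: ...

    async def record_turn(
        self,
        pool: object,
        thread_id: str,
        agent_name: str | None,
        tokens: int,
        cost: float,
    ) -> None: ...

    async def resolve(self, pool: object, thread_id: str, resolution_type: str) -> None: ...
```

Note that `isinstance` against a runtime-checkable Protocol only verifies method presence, not signatures, which is all these protocol tests need.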


@@ -55,7 +55,7 @@ class TestDbModule:
from app.db import setup_app_tables
await setup_app_tables(mock_pool)
- assert mock_conn.execute.await_count == 2
+ assert mock_conn.execute.await_count == 5
def test_ddl_statements_valid(self) -> None:
assert "CREATE TABLE IF NOT EXISTS conversations" in _CONVERSATIONS_DDL


@@ -0,0 +1,55 @@
"""Phase 4 DB migration tests -- analytics_events table and conversation columns."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
class TestAnalyticsEventsDDL:
def test_analytics_events_ddl_exists(self) -> None:
from app.db import _ANALYTICS_EVENTS_DDL
assert "CREATE TABLE IF NOT EXISTS analytics_events" in _ANALYTICS_EVENTS_DDL
def test_analytics_events_ddl_has_required_columns(self) -> None:
from app.db import _ANALYTICS_EVENTS_DDL
assert "thread_id" in _ANALYTICS_EVENTS_DDL
assert "event_type" in _ANALYTICS_EVENTS_DDL
assert "agent_name" in _ANALYTICS_EVENTS_DDL
assert "tool_name" in _ANALYTICS_EVENTS_DDL
assert "tokens_used" in _ANALYTICS_EVENTS_DDL
assert "cost_usd" in _ANALYTICS_EVENTS_DDL
assert "duration_ms" in _ANALYTICS_EVENTS_DDL
assert "success" in _ANALYTICS_EVENTS_DDL
assert "error_message" in _ANALYTICS_EVENTS_DDL
assert "metadata" in _ANALYTICS_EVENTS_DDL
def test_conversations_migration_ddl_exists(self) -> None:
from app.db import _CONVERSATIONS_MIGRATION_DDL
assert "ALTER TABLE" in _CONVERSATIONS_MIGRATION_DDL
assert "resolution_type" in _CONVERSATIONS_MIGRATION_DDL
assert "agents_used" in _CONVERSATIONS_MIGRATION_DDL
assert "turn_count" in _CONVERSATIONS_MIGRATION_DDL
assert "ended_at" in _CONVERSATIONS_MIGRATION_DDL
assert "IF NOT EXISTS" in _CONVERSATIONS_MIGRATION_DDL
@pytest.mark.asyncio
async def test_setup_app_tables_executes_analytics_ddl(self) -> None:
mock_conn = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_ctx.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.connection.return_value = mock_ctx
from app.db import setup_app_tables
await setup_app_tables(mock_pool)
# Now expects 5 statements: conversations, interrupts, sessions, analytics_events, migrations
assert mock_conn.execute.await_count == 5
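The column names asserted above determine the table's shape; a hypothetical rendering of `_ANALYTICS_EVENTS_DDL` (types, defaults, and the primary key are assumptions -- only the column names come from the tests):

```python
_ANALYTICS_EVENTS_DDL = """
CREATE TABLE IF NOT EXISTS analytics_events (
    id BIGSERIAL PRIMARY KEY,
    thread_id TEXT NOT NULL,
    event_type TEXT NOT NULL,
    agent_name TEXT,
    tool_name TEXT,
    tokens_used INTEGER,
    cost_usd NUMERIC(12, 6),
    duration_ms INTEGER,
    success BOOLEAN,
    error_message TEXT,
    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
```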


@@ -0,0 +1,79 @@
"""Tests for app.agents.discount module."""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.agents.discount import apply_discount, generate_coupon
@pytest.mark.unit
class TestApplyDiscount:
def test_invalid_discount_zero(self) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 0})
assert result["status"] == "error"
assert "Invalid" in result["message"]
def test_invalid_discount_over_100(self) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 101})
assert result["status"] == "error"
def test_invalid_discount_negative(self) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": -5})
assert result["status"] == "error"
@patch("app.agents.discount.interrupt", return_value=True)
def test_approved_discount(self, mock_interrupt) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
assert result["status"] == "applied"
assert result["discount_percent"] == 10
assert "1042" in result["message"]
@patch("app.agents.discount.interrupt", return_value=False)
def test_rejected_discount(self, mock_interrupt) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 10})
assert result["status"] == "declined"
@patch("app.agents.discount.interrupt", return_value={"approved": True})
def test_approved_via_dict(self, mock_interrupt) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
assert result["status"] == "applied"
@patch("app.agents.discount.interrupt", return_value={"approved": False})
def test_rejected_via_dict(self, mock_interrupt) -> None:
result = apply_discount.invoke({"order_id": "1042", "discount_percent": 50})
assert result["status"] == "declined"
@pytest.mark.unit
class TestGenerateCoupon:
def test_valid_coupon(self) -> None:
result = generate_coupon.invoke({"discount_percent": 15, "expiry_days": 7})
assert result["status"] == "generated"
assert result["discount_percent"] == 15
assert result["expiry_days"] == 7
assert result["coupon_code"].startswith("SAVE15-")
def test_default_expiry(self) -> None:
result = generate_coupon.invoke({"discount_percent": 20})
assert result["status"] == "generated"
assert result["expiry_days"] == 30
def test_invalid_discount_zero(self) -> None:
result = generate_coupon.invoke({"discount_percent": 0})
assert result["status"] == "error"
def test_invalid_discount_over_100(self) -> None:
result = generate_coupon.invoke({"discount_percent": 101})
assert result["status"] == "error"
def test_invalid_expiry(self) -> None:
result = generate_coupon.invoke({"discount_percent": 10, "expiry_days": 0})
assert result["status"] == "error"
def test_coupon_codes_unique(self) -> None:
r1 = generate_coupon.invoke({"discount_percent": 10})
r2 = generate_coupon.invoke({"discount_percent": 10})
assert r1["coupon_code"] != r2["coupon_code"]
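These tests pin down the approval contract: `interrupt` may resume with a bare bool or an `{"approved": ...}` dict, and out-of-range percentages short-circuit before any interrupt fires. A sketch of a tool that satisfies them (the import locations and message wording are assumptions):

```python
from langchain_core.tools import tool
from langgraph.types import interrupt


@tool
def apply_discount(order_id: str, discount_percent: int) -> dict:
    """Apply a discount to an order after human approval."""
    if not 1 <= discount_percent <= 100:
        return {"status": "error", "message": f"Invalid discount_percent: {discount_percent}"}
    # interrupt() pauses the graph until a human resumes it; the resume value
    # may be a bare bool or a dict like {"approved": True} -- handle both.
    decision = interrupt({"order_id": order_id, "discount_percent": discount_percent})
    approved = decision.get("approved", False) if isinstance(decision, dict) else bool(decision)
    if not approved:
        return {"status": "declined", "order_id": order_id}
    return {
        "status": "applied",
        "discount_percent": discount_percent,
        "message": f"Applied {discount_percent}% discount to order {order_id}",
    }
```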


@@ -0,0 +1,210 @@
"""Edge case tests for ws_handler input validation and rate limiting."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.callbacks import TokenUsageCallbackHandler
from app.graph_context import GraphContext
from app.session_manager import SessionManager
from app.ws_context import WebSocketContext
from app.ws_handler import dispatch_message
pytestmark = pytest.mark.unit
def _make_ws() -> AsyncMock:
ws = AsyncMock()
ws.send_json = AsyncMock()
return ws
def _make_graph() -> AsyncMock:
graph = AsyncMock()
class AsyncIterHelper:
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
graph.astream = MagicMock(return_value=AsyncIterHelper())
state = MagicMock()
state.tasks = ()
graph.aget_state = AsyncMock(return_value=state)
return graph
def _make_ws_ctx(sm: SessionManager | None = None) -> WebSocketContext:
graph = _make_graph()
registry = MagicMock()
registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=graph, registry=registry, intent_classifier=None)
return WebSocketContext(
graph_ctx=graph_ctx,
session_manager=sm or SessionManager(),
callback_handler=TokenUsageCallbackHandler(),
)
@pytest.mark.unit
class TestEmptyMessageHandling:
@pytest.mark.asyncio
async def test_empty_message_content_returns_error(self) -> None:
ws = _make_ws()
sm = SessionManager()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": ""})
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
msg_lower = call_data["message"].lower()
assert "content" in msg_lower or "missing" in msg_lower
@pytest.mark.asyncio
async def test_whitespace_only_message_treated_as_empty(self) -> None:
ws = _make_ws()
sm = SessionManager()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
msg = json.dumps({"type": "message", "thread_id": "t1", "content": " "})
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@pytest.mark.unit
class TestOversizedMessageHandling:
@pytest.mark.asyncio
async def test_content_over_10000_chars_returns_error(self) -> None:
ws = _make_ws()
sm = SessionManager()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10001
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, ws_ctx, msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "too long" in call_data["message"].lower()
@pytest.mark.asyncio
async def test_content_exactly_10000_chars_is_accepted(self) -> None:
ws = _make_ws()
sm = SessionManager()
ws_ctx = _make_ws_ctx(sm=sm)
sm.touch("t1")
content = "x" * 10000
msg = json.dumps({"type": "message", "thread_id": "t1", "content": content})
await dispatch_message(ws, ws_ctx, msg)
last_call = ws.send_json.call_args[0][0]
# Should be processed, not an error about length
msg_text = last_call.get("message", "").lower()
assert last_call["type"] != "error" or "too long" not in msg_text
@pytest.mark.asyncio
async def test_raw_message_over_32kb_returns_error(self) -> None:
ws = _make_ws()
ws_ctx = _make_ws_ctx()
large_msg = "x" * 40_000
await dispatch_message(ws, ws_ctx, large_msg)
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "too large" in call_data["message"].lower()
@pytest.mark.unit
class TestInvalidJsonHandling:
@pytest.mark.asyncio
async def test_invalid_json_returns_error(self) -> None:
ws = _make_ws()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, ws_ctx, "not valid json {{")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
assert "invalid json" in call_data["message"].lower()
@pytest.mark.asyncio
async def test_empty_string_returns_json_error(self) -> None:
ws = _make_ws()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, ws_ctx, "")
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@pytest.mark.asyncio
async def test_json_array_not_object_returns_error(self) -> None:
ws = _make_ws()
ws_ctx = _make_ws_ctx()
await dispatch_message(ws, ws_ctx, '["not", "an", "object"]')
call_data = ws.send_json.call_args[0][0]
assert call_data["type"] == "error"
@pytest.mark.unit
class TestRateLimiting:
@pytest.mark.asyncio
async def test_rapid_fire_messages_rate_limited(self) -> None:
ws = _make_ws()
sm = SessionManager()
sm.touch("t1")
# Simulate 11 rapid messages (exceeds 10 per 10 seconds limit)
rate_limit_triggered = False
for i in range(11):
ws_ctx = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx, json.dumps({
"type": "message",
"thread_id": "t1",
"content": f"message {i}",
}))
last_call = ws.send_json.call_args[0][0]
if last_call["type"] == "error" and "rate" in last_call.get("message", "").lower():
rate_limit_triggered = True
break
assert rate_limit_triggered, "Rate limiting should trigger after 10 rapid messages"
@pytest.mark.asyncio
async def test_different_threads_have_separate_rate_limits(self) -> None:
ws = _make_ws()
sm = SessionManager()
sm.touch("t1")
sm.touch("t2")
# Send 5 messages on t1 and 5 on t2 -- neither should be rate limited
for i in range(5):
ws_ctx1 = _make_ws_ctx(sm=sm)
ws_ctx2 = _make_ws_ctx(sm=sm)
await dispatch_message(ws, ws_ctx1, json.dumps({
"type": "message", "thread_id": "t1", "content": f"msg {i}",
}))
await dispatch_message(ws, ws_ctx2, json.dumps({
"type": "message", "thread_id": "t2", "content": f"msg {i}",
}))
last_call = ws.send_json.call_args[0][0]
assert "rate" not in last_call.get("message", "").lower()


@@ -0,0 +1,175 @@
"""Tests for app.tools.error_handler module."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.tools.error_handler import (
ErrorCategory,
classify_error,
with_retry,
)
pytestmark = pytest.mark.unit
class TestErrorClassification:
def test_timeout_exception_is_timeout(self) -> None:
exc = httpx.TimeoutException("timed out")
assert classify_error(exc) == ErrorCategory.TIMEOUT
def test_connect_error_is_network(self) -> None:
exc = httpx.ConnectError("connection refused")
assert classify_error(exc) == ErrorCategory.NETWORK
def test_401_is_auth_failure(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(401, request=request)
exc = httpx.HTTPStatusError("401", request=request, response=response)
assert classify_error(exc) == ErrorCategory.AUTH_FAILURE
def test_403_is_auth_failure(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(403, request=request)
exc = httpx.HTTPStatusError("403", request=request, response=response)
assert classify_error(exc) == ErrorCategory.AUTH_FAILURE
def test_429_is_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(429, request=request)
exc = httpx.HTTPStatusError("429", request=request, response=response)
assert classify_error(exc) == ErrorCategory.RETRYABLE
def test_500_is_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(500, request=request)
exc = httpx.HTTPStatusError("500", request=request, response=response)
assert classify_error(exc) == ErrorCategory.RETRYABLE
def test_502_is_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(502, request=request)
exc = httpx.HTTPStatusError("502", request=request, response=response)
assert classify_error(exc) == ErrorCategory.RETRYABLE
def test_503_is_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(503, request=request)
exc = httpx.HTTPStatusError("503", request=request, response=response)
assert classify_error(exc) == ErrorCategory.RETRYABLE
def test_404_is_non_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(404, request=request)
exc = httpx.HTTPStatusError("404", request=request, response=response)
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
def test_400_is_non_retryable(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(400, request=request)
exc = httpx.HTTPStatusError("400", request=request, response=response)
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
def test_generic_exception_is_non_retryable(self) -> None:
exc = ValueError("bad value")
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
def test_runtime_error_is_non_retryable(self) -> None:
exc = RuntimeError("boom")
assert classify_error(exc) == ErrorCategory.NON_RETRYABLE
class TestWithRetry:
@pytest.mark.asyncio
async def test_succeeds_on_first_try(self) -> None:
fn = AsyncMock(return_value="ok")
result = await with_retry(fn, max_retries=3, base_delay=0.0)
assert result == "ok"
assert fn.call_count == 1
@pytest.mark.asyncio
async def test_retries_on_retryable_error(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(503, request=request)
retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)
fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "success"])
with patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock):
result = await with_retry(fn, max_retries=3, base_delay=0.0)
assert result == "success"
assert fn.call_count == 3
@pytest.mark.asyncio
async def test_does_not_retry_non_retryable_error(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(404, request=request)
non_retryable_exc = httpx.HTTPStatusError("404", request=request, response=response)
fn = AsyncMock(side_effect=non_retryable_exc)
with pytest.raises(httpx.HTTPStatusError):
await with_retry(fn, max_retries=3, base_delay=0.0)
assert fn.call_count == 1
@pytest.mark.asyncio
async def test_does_not_retry_auth_failure(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(401, request=request)
auth_exc = httpx.HTTPStatusError("401", request=request, response=response)
fn = AsyncMock(side_effect=auth_exc)
with pytest.raises(httpx.HTTPStatusError):
await with_retry(fn, max_retries=3, base_delay=0.0)
assert fn.call_count == 1
@pytest.mark.asyncio
async def test_raises_after_max_retries_exhausted(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(500, request=request)
retryable_exc = httpx.HTTPStatusError("500", request=request, response=response)
fn = AsyncMock(side_effect=retryable_exc)
with (
patch("app.tools.error_handler.asyncio.sleep", new_callable=AsyncMock),
pytest.raises(httpx.HTTPStatusError),
):
await with_retry(fn, max_retries=3, base_delay=0.0)
assert fn.call_count == 3
@pytest.mark.asyncio
async def test_does_not_retry_timeout(self) -> None:
"""TimeoutException is TIMEOUT category -- not retried by default."""
fn = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
with pytest.raises(httpx.TimeoutException):
await with_retry(fn, max_retries=3, base_delay=0.0)
assert fn.call_count == 1
@pytest.mark.asyncio
async def test_exponential_backoff_increases_delay(self) -> None:
request = httpx.Request("GET", "http://example.com")
response = httpx.Response(503, request=request)
retryable_exc = httpx.HTTPStatusError("503", request=request, response=response)
fn = AsyncMock(side_effect=[retryable_exc, retryable_exc, "done"])
sleep_delays: list[float] = []
async def capture_sleep(delay: float) -> None:
sleep_delays.append(delay)
with patch("app.tools.error_handler.asyncio.sleep", side_effect=capture_sleep):
await with_retry(fn, max_retries=3, base_delay=1.0)
assert len(sleep_delays) == 2
assert sleep_delays[1] > sleep_delays[0]
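A sketch of `classify_error` and `with_retry` consistent with the behavior pinned above: only RETRYABLE errors (429/5xx) are retried, TIMEOUT raises immediately, and backoff doubles per attempt. The exact signatures are assumptions, and this sketch also surfaces NETWORK errors immediately, which the tests do not exercise:

```python
import asyncio
import enum
from typing import Awaitable, Callable, TypeVar

import httpx

T = TypeVar("T")


class ErrorCategory(enum.Enum):
    TIMEOUT = "timeout"
    NETWORK = "network"
    AUTH_FAILURE = "auth_failure"
    RETRYABLE = "retryable"
    NON_RETRYABLE = "non_retryable"


def classify_error(exc: Exception) -> ErrorCategory:
    """Map an exception to the category the tests assert."""
    if isinstance(exc, httpx.TimeoutException):
        return ErrorCategory.TIMEOUT
    if isinstance(exc, httpx.ConnectError):
        return ErrorCategory.NETWORK
    if isinstance(exc, httpx.HTTPStatusError):
        code = exc.response.status_code
        if code in (401, 403):
            return ErrorCategory.AUTH_FAILURE
        if code == 429 or code >= 500:
            return ErrorCategory.RETRYABLE
        return ErrorCategory.NON_RETRYABLE
    return ErrorCategory.NON_RETRYABLE


async def with_retry(
    fn: Callable[[], Awaitable[T]], *, max_retries: int = 3, base_delay: float = 1.0,
) -> T:
    """Retry only RETRYABLE errors, with exponential backoff between attempts."""
    for attempt in range(max_retries):
        try:
            return await fn()
        except Exception as exc:
            if classify_error(exc) is not ErrorCategory.RETRYABLE:
                raise
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(base_delay * 2 ** attempt)
    raise AssertionError("unreachable")
```

The backoff test follows directly: with `base_delay=1.0`, two retries sleep for 1.0 and then 2.0 seconds.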


@@ -0,0 +1,142 @@
"""Tests for standardized error response envelope format."""
from __future__ import annotations
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field
from app.api_utils import envelope
pytestmark = pytest.mark.unit
def _build_test_app() -> FastAPI:
"""Build a minimal FastAPI app with the standard exception handlers."""
app = FastAPI()
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=exc.status_code,
content=envelope(None, success=False, error=exc.detail),
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=422,
content=envelope(None, success=False, error=str(exc)),
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc): # type: ignore[no-untyped-def]
return JSONResponse(
status_code=500,
content=envelope(None, success=False, error="Internal server error"),
)
class ItemRequest(BaseModel):
name: str = Field(..., min_length=1)
count: int = Field(..., gt=0)
@app.get("/items/{item_id}")
def get_item(item_id: int) -> dict:
if item_id == 0:
raise HTTPException(status_code=400, detail="Invalid item ID")
if item_id == 999:
raise HTTPException(status_code=404, detail="Item not found")
if item_id == 401:
raise HTTPException(status_code=401, detail="Not authenticated")
return envelope({"id": item_id, "name": "test"})
@app.post("/items")
def create_item(item: ItemRequest) -> dict:
return envelope({"id": 1, "name": item.name})
@app.get("/crash")
def crash() -> dict:
msg = "unexpected failure"
raise RuntimeError(msg)
return app
class TestHttpExceptionEnvelope:
"""HTTPException responses use the standard envelope format."""
def test_400_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/0")
assert resp.status_code == 400
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Invalid item ID"
def test_404_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/999")
assert resp.status_code == 404
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Item not found"
def test_401_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/items/401")
assert resp.status_code == 401
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Not authenticated"
class TestValidationErrorEnvelope:
"""Validation errors return 422 with envelope format."""
def test_validation_error_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.post("/items", json={"name": "", "count": -1})
assert resp.status_code == 422
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert isinstance(body["error"], str)
assert len(body["error"]) > 0
class TestGeneralExceptionEnvelope:
"""Unhandled exceptions return 500 with safe envelope."""
def test_unhandled_exception_returns_500_envelope(self) -> None:
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.get("/crash")
assert resp.status_code == 500
body = resp.json()
assert body["success"] is False
assert body["data"] is None
assert body["error"] == "Internal server error"
class TestSuccessResponseUnchanged:
"""Success responses still work normally."""
def test_success_returns_envelope(self) -> None:
app = _build_test_app()
with TestClient(app) as client:
resp = client.get("/items/42")
assert resp.status_code == 200
body = resp.json()
assert body["success"] is True
assert body["data"]["id"] == 42
assert body["error"] is None
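All of the assertions above reduce to a single envelope shape; `app.api_utils.envelope` is presumably as small as this sketch:

```python
from typing import Any


def envelope(data: Any, *, success: bool = True, error: str | None = None) -> dict[str, Any]:
    """Standard response envelope: every API reply carries the same three keys."""
    return {"success": success, "data": data, "error": error}
```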


@@ -0,0 +1,169 @@
"""Tests for app.escalation module."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.escalation import (
EscalationPayload,
EscalationResult,
NoOpEscalator,
WebhookEscalator,
)
def _make_payload(**kwargs) -> EscalationPayload:
defaults = {
"thread_id": "t1",
"reason": "Agent cannot resolve",
"conversation_summary": "User asked about refund policy",
}
defaults.update(kwargs)
return EscalationPayload(**defaults)
@pytest.mark.unit
class TestEscalationPayload:
def test_frozen(self) -> None:
payload = _make_payload()
with pytest.raises(Exception):
payload.thread_id = "t2" # type: ignore[misc]
def test_default_metadata(self) -> None:
payload = _make_payload()
assert payload.metadata == {}
def test_model_dump(self) -> None:
payload = _make_payload(metadata={"key": "val"})
data = payload.model_dump()
assert data["thread_id"] == "t1"
assert data["metadata"] == {"key": "val"}
@pytest.mark.unit
class TestEscalationResult:
def test_frozen(self) -> None:
result = EscalationResult(success=True, status_code=200, attempts=1, error=None)
with pytest.raises(Exception):
result.success = False # type: ignore[misc]
assert result.success
assert result.status_code == 200
@pytest.mark.unit
class TestWebhookEscalator:
@pytest.mark.asyncio
async def test_empty_url_returns_failure(self) -> None:
escalator = WebhookEscalator(url="", max_retries=3)
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 0
assert "not configured" in result.error
@pytest.mark.asyncio
async def test_successful_post(self) -> None:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("app.escalation.httpx.AsyncClient", return_value=mock_client):
escalator = WebhookEscalator(url="https://example.com/hook")
result = await escalator.escalate(_make_payload())
assert result.success
assert result.status_code == 200
assert result.attempts == 1
@pytest.mark.asyncio
async def test_retry_on_server_error(self) -> None:
fail_response = AsyncMock()
fail_response.status_code = 500
success_response = AsyncMock()
success_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=[fail_response, fail_response, success_response])
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
result = await escalator.escalate(_make_payload())
assert result.success
assert result.attempts == 3
@pytest.mark.asyncio
async def test_all_retries_exhausted(self) -> None:
fail_response = AsyncMock()
fail_response.status_code = 500
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=fail_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=3)
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 3
assert "500" in result.error
@pytest.mark.asyncio
async def test_timeout_error(self) -> None:
mock_client = AsyncMock()
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("timeout"))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=2)
result = await escalator.escalate(_make_payload())
assert not result.success
assert "timed out" in result.error
@pytest.mark.asyncio
async def test_request_error(self) -> None:
mock_client = AsyncMock()
mock_client.post = AsyncMock(
side_effect=httpx.RequestError("connection refused")
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with (
patch("app.escalation.httpx.AsyncClient", return_value=mock_client),
patch("app.escalation.asyncio.sleep", new_callable=AsyncMock),
):
escalator = WebhookEscalator(url="https://example.com/hook", max_retries=1)
result = await escalator.escalate(_make_payload())
assert not result.success
@pytest.mark.unit
class TestNoOpEscalator:
@pytest.mark.asyncio
async def test_returns_disabled(self) -> None:
escalator = NoOpEscalator()
result = await escalator.escalate(_make_payload())
assert not result.success
assert result.attempts == 0
assert "disabled" in result.error.lower()


@@ -6,8 +6,11 @@ from typing import TYPE_CHECKING
from unittest.mock import AsyncMock, MagicMock
import pytest
from langgraph.checkpoint.memory import InMemorySaver
- from app.graph import SUPERVISOR_PROMPT, build_agent_nodes, build_graph
+ from app.graph import build_agent_nodes, build_graph
+ from app.graph_context import GraphContext
+ from app.intent import ClassificationResult, IntentTarget
if TYPE_CHECKING:
from app.registry import AgentRegistry
@@ -33,12 +36,59 @@ class TestBuildGraph:
mock_llm = MagicMock()
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
- mock_checkpointer = AsyncMock()
+ checkpointer = InMemorySaver()
- graph = build_graph(sample_registry, mock_llm, mock_checkpointer)
- assert graph is not None
+ graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
+ assert graph_ctx is not None
+ assert graph_ctx.graph is not None
- def test_supervisor_prompt_contains_routing_info(self) -> None:
- assert "order_lookup" in SUPERVISOR_PROMPT
- assert "order_actions" in SUPERVISOR_PROMPT
- assert "fallback" in SUPERVISOR_PROMPT
def test_graph_has_classifier_attached(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver()
mock_classifier = MagicMock()
graph_ctx = build_graph(
sample_registry, mock_llm, checkpointer, intent_classifier=mock_classifier
)
assert graph_ctx.intent_classifier is mock_classifier
assert graph_ctx.registry is sample_registry
def test_graph_without_classifier(self, sample_registry: AgentRegistry) -> None:
mock_llm = MagicMock()
mock_llm.bind_tools = MagicMock(return_value=mock_llm)
mock_llm.with_structured_output = MagicMock(return_value=mock_llm)
checkpointer = InMemorySaver()
graph_ctx = build_graph(sample_registry, mock_llm, checkpointer)
assert graph_ctx.intent_classifier is None
@pytest.mark.unit
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_returns_none_without_classifier(self) -> None:
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(graph=MagicMock(), registry=mock_registry, intent_classifier=None)
result = await graph_ctx.classify_intent("hello")
assert result is None
@pytest.mark.asyncio
async def test_calls_classifier(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="test"),),
)
mock_classifier = AsyncMock()
mock_classifier.classify = AsyncMock(return_value=expected)
mock_registry = MagicMock()
mock_registry.list_agents = MagicMock(return_value=())
graph_ctx = GraphContext(
graph=MagicMock(), registry=mock_registry, intent_classifier=mock_classifier,
)
result = await graph_ctx.classify_intent("check order")
assert result is not None
assert result.intents[0].agent_name == "order_lookup"
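The shape of `GraphContext` these tests rely on is a small typed container (per the commit message, it replaces monkey-patching on the compiled graph); a sketch with assumed field types:

```python
from dataclasses import dataclass
from typing import Any

from app.intent import ClassificationResult


@dataclass(frozen=True)
class GraphContext:
    """Typed bundle passed around instead of attributes patched onto the graph."""

    graph: Any
    registry: Any
    intent_classifier: Any | None = None

    async def classify_intent(self, message: str) -> ClassificationResult | None:
        # Without a classifier configured, intent classification is a no-op.
        if self.intent_classifier is None:
            return None
        return await self.intent_classifier.classify(message, self.registry.list_agents())
```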


@@ -0,0 +1,177 @@
"""Tests for app.intent module."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.intent import (
AMBIGUITY_THRESHOLD,
ClassificationResult,
IntentTarget,
LLMIntentClassifier,
_build_agent_list,
)
from app.registry import AgentConfig
def _make_agent(name: str, desc: str = "test", perm: str = "read") -> AgentConfig:
return AgentConfig(
name=name, description=desc, permission=perm, tools=["fallback_respond"],
)
@pytest.mark.unit
class TestIntentModels:
def test_intent_target_frozen(self) -> None:
target = IntentTarget(agent_name="order_lookup", confidence=0.9, reasoning="order query")
with pytest.raises(Exception):
target.agent_name = "other" # type: ignore[misc]
def test_classification_result_frozen(self) -> None:
result = ClassificationResult(
intents=(IntentTarget(agent_name="a", confidence=0.9, reasoning="r"),),
)
assert not result.is_ambiguous
assert result.clarification_question is None
def test_classification_result_ambiguous(self) -> None:
result = ClassificationResult(
intents=(),
is_ambiguous=True,
clarification_question="What do you mean?",
)
assert result.is_ambiguous
def test_multi_intent(self) -> None:
result = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.85, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.8, reasoning="discount"),
),
)
assert len(result.intents) == 2
@pytest.mark.unit
class TestBuildAgentList:
def test_formats_agents(self) -> None:
agents = (
_make_agent("order_lookup", "Looks up orders", "read"),
_make_agent("order_actions", "Modifies orders", "write"),
)
text = _build_agent_list(agents)
assert "order_lookup" in text
assert "order_actions" in text
assert "read" in text
assert "write" in text
@pytest.mark.unit
class TestLLMIntentClassifier:
@pytest.mark.asyncio
async def test_single_intent_classification(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="query"),),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
result = await classifier.classify("What is order 1042 status?", agents)
assert len(result.intents) == 1
assert result.intents[0].agent_name == "order_lookup"
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_multi_intent_classification(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(agent_name="order_actions", confidence=0.9, reasoning="cancel"),
IntentTarget(agent_name="discount", confidence=0.85, reasoning="discount"),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_actions"), _make_agent("discount"), _make_agent("fallback"))
result = await classifier.classify("Cancel order 1042 and give me 10% off", agents)
assert len(result.intents) == 2
assert not result.is_ambiguous
@pytest.mark.asyncio
async def test_ambiguous_classification(self) -> None:
expected = ClassificationResult(
intents=(IntentTarget(agent_name="fallback", confidence=0.3, reasoning="unclear"),),
is_ambiguous=False,
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"), _make_agent("fallback"))
result = await classifier.classify("hmm", agents)
# Low confidence triggers ambiguity
assert result.is_ambiguous
assert result.clarification_question is not None
@pytest.mark.asyncio
async def test_llm_error_returns_ambiguous(self) -> None:
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(side_effect=RuntimeError("LLM error"))
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("fallback"),)
result = await classifier.classify("test", agents)
assert result.is_ambiguous
assert result.clarification_question is not None
@pytest.mark.asyncio
async def test_non_result_type_returns_ambiguous(self) -> None:
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value="not a ClassificationResult")
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("fallback"),)
result = await classifier.classify("test", agents)
assert result.is_ambiguous
@pytest.mark.asyncio
async def test_high_confidence_not_ambiguous(self) -> None:
expected = ClassificationResult(
intents=(
IntentTarget(
agent_name="order_lookup",
confidence=AMBIGUITY_THRESHOLD + 0.1,
reasoning="clear",
),
),
)
mock_structured = MagicMock()
mock_structured.ainvoke = AsyncMock(return_value=expected)
mock_llm = MagicMock()
mock_llm.with_structured_output = MagicMock(return_value=mock_structured)
classifier = LLMIntentClassifier(mock_llm)
agents = (_make_agent("order_lookup"),)
result = await classifier.classify("order status 1042", agents)
assert not result.is_ambiguous
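test_ambiguous_classification is the interesting case: the mocked LLM returns `is_ambiguous=False`, so the ambiguity must be imposed by the classifier itself whenever the top confidence falls below AMBIGUITY_THRESHOLD. A sketch of that post-processing step (the helper name and the question wording are hypothetical; LLM errors and non-result return types are handled the same way per the tests above):

```python
from app.intent import AMBIGUITY_THRESHOLD, ClassificationResult


def _apply_ambiguity_threshold(result: ClassificationResult) -> ClassificationResult:
    """Demote low-confidence classifications to a clarification request."""
    if result.is_ambiguous:
        return result
    top = max((intent.confidence for intent in result.intents), default=0.0)
    if top >= AMBIGUITY_THRESHOLD:
        return result
    return ClassificationResult(
        intents=result.intents,
        is_ambiguous=True,
        clarification_question="Could you clarify what you'd like help with?",
    )
```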

Some files were not shown because too many files have changed in this diff.