Compare commits
11 Commits
ca8d7099b3
...
ec005c91a9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec005c91a9 | ||
|
|
0f7341b158 | ||
|
|
37c46e76ae | ||
|
|
4915f1bae4 | ||
|
|
9ee3ec9b4e | ||
|
|
5c7a0ee4c0 | ||
|
|
42ba359c48 | ||
|
|
27b131492f | ||
|
|
ea72497587 | ||
|
|
3c725c45fa | ||
|
|
4eb06dd8e5 |
19
README.md
19
README.md
@@ -106,7 +106,7 @@ curl -X POST http://localhost:8000/api/v1/portfolio/analyze \
|
||||
-d '{"holdings":[{"symbol":"AAPL","shares":100,"buy_in_price":150},{"symbol":"VOLV-B.ST","shares":50,"buy_in_price":250}]}'
|
||||
```
|
||||
|
||||
## API Endpoints (99 total)
|
||||
## API Endpoints (102 total)
|
||||
|
||||
### Health
|
||||
|
||||
@@ -130,15 +130,27 @@ curl -X POST http://localhost:8000/api/v1/portfolio/analyze \
|
||||
| GET | `/api/v1/stock/{symbol}/filings?form_type=10-K` | SEC filings (10-K, 10-Q, 8-K) |
|
||||
| GET | `/api/v1/search?query=` | Company search by name (SEC/NASDAQ) |
|
||||
|
||||
### Sentiment & Analyst Data (Finnhub + Alpha Vantage + yfinance)
|
||||
### Sentiment & Analyst Data (Finnhub + Alpha Vantage + yfinance + Reddit)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/stock/{symbol}/sentiment` | Aggregated: news sentiment + recommendations + upgrades |
|
||||
| GET | `/api/v1/stock/{symbol}/sentiment` | Composite sentiment score from all sources (-1 to +1) |
|
||||
| GET | `/api/v1/stock/{symbol}/news-sentiment?limit=30` | News articles with per-ticker sentiment scores (Alpha Vantage) |
|
||||
| GET | `/api/v1/stock/{symbol}/social-sentiment` | Social media sentiment from Reddit + Twitter (Finnhub) |
|
||||
| GET | `/api/v1/stock/{symbol}/reddit-sentiment` | Reddit mentions, upvotes, rank (ApeWisdom, free) |
|
||||
| GET | `/api/v1/stock/{symbol}/insider-trades` | Insider transactions via Finnhub |
|
||||
| GET | `/api/v1/stock/{symbol}/recommendations` | Monthly analyst buy/hold/sell counts (Finnhub) |
|
||||
| GET | `/api/v1/stock/{symbol}/upgrades` | Analyst upgrades/downgrades with price targets (yfinance) |
|
||||
| GET | `/api/v1/discover/reddit-trending` | Top 25 trending stocks on Reddit (free) |
|
||||
|
||||
The `/sentiment` endpoint aggregates 4 sources into a weighted composite score:
|
||||
|
||||
| Source | Weight | Data |
|
||||
|--------|--------|------|
|
||||
| News (Alpha Vantage) | 25% | Article-level bullish/bearish scores |
|
||||
| Analysts (Finnhub) | 30% | Buy/sell recommendation ratio |
|
||||
| Upgrades (yfinance) | 20% | Recent upgrade/downgrade actions |
|
||||
| Reddit (ApeWisdom) | 25% | 24h mention change trend |
|
||||
|
||||
### Technical Analysis (14 indicators, local computation, no key needed)
|
||||
|
||||
@@ -555,6 +567,7 @@ docker run -p 8000:8000 invest-api
|
||||
| **Alpha Vantage** | Free | Yes (free registration) | News sentiment scores (bullish/bearish per ticker per article), 25 req/day |
|
||||
| **FRED** | Free | Yes (free registration) | Fed rate, treasury yields, CPI, PCE, money supply, surveys, 800K+ economic series |
|
||||
| **Federal Reserve** | Free | No | EFFR, SOFR, money measures, central bank holdings, primary dealer positions, FOMC documents |
|
||||
| **ApeWisdom** | Free | No | Reddit stock mentions, upvotes, trending (WSB, r/stocks, r/investing) |
|
||||
| **openbb-technical** | Free | No (local) | ATR, ADX, Stochastic, OBV, Ichimoku, Donchian, Aroon, CCI, Keltner, Fibonacci, A/D, VWAP, Volatility Cones, Relative Rotation |
|
||||
| **openbb-quantitative** | Free | No (local) | Sharpe, Sortino, Omega ratios, CAPM, normality tests, unit root tests, rolling statistics |
|
||||
|
||||
|
||||
185
akshare_service.py
Normal file
185
akshare_service.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""A-share and HK stock data service using AKShare."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import akshare as ak
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Symbol validation patterns ---

# A-share codes: exactly six digits whose first digit is 0, 3, or 6.
_A_SHARE_PATTERN = re.compile(r"^[036]\d{5}$")
# Hong Kong stock codes: exactly five digits (e.g. "00700").
_HK_PATTERN = re.compile(r"^\d{5}$")

# --- Chinese column name mappings ---

# AKShare returns DataFrames with Chinese column headers; these maps
# translate them to the English keys used in API responses.

# Historical (OHLCV) bar columns.
_HIST_COLUMNS: dict[str, str] = {
    "日期": "date",
    "开盘": "open",
    "收盘": "close",
    "最高": "high",
    "最低": "low",
    "成交量": "volume",
    "成交额": "turnover",
    "振幅": "amplitude",
    "涨跌幅": "change_percent",
    "涨跌额": "change",
    "换手率": "turnover_rate",
}

# Real-time spot-quote columns.
_QUOTE_COLUMNS: dict[str, str] = {
    "代码": "symbol",
    "名称": "name",
    "最新价": "price",
    "涨跌幅": "change_percent",
    "涨跌额": "change",
    "成交量": "volume",
    "成交额": "turnover",
    "今开": "open",
    "最高": "high",
    "最低": "low",
    "昨收": "prev_close",
}
|
||||
|
||||
|
||||
# --- Validation helpers ---
|
||||
|
||||
|
||||
def validate_a_share_symbol(symbol: str) -> bool:
    """Return True if symbol matches A-share format (6 digits, starts with 0, 3, or 6)."""
    return re.match(r"^[036]\d{5}$", symbol) is not None
|
||||
|
||||
|
||||
def validate_hk_symbol(symbol: str) -> bool:
    """Return True if symbol matches HK stock format (exactly 5 digits)."""
    return re.match(r"^\d{5}$", symbol) is not None
|
||||
|
||||
|
||||
# --- DataFrame parsers ---
|
||||
|
||||
|
||||
def _parse_hist_df(df: pd.DataFrame) -> list[dict[str, Any]]:
    """Translate a Chinese-column history DataFrame into English-key records."""
    if df.empty:
        return []
    renamed = df.rename(columns=_HIST_COLUMNS)
    if "date" in renamed.columns:
        # Dates must be JSON-serializable, so stringify to ISO form.
        renamed["date"] = renamed["date"].astype(str)
    return renamed.to_dict(orient="records")
|
||||
|
||||
|
||||
def _parse_spot_row(df: pd.DataFrame, symbol: str) -> dict[str, Any] | None:
    """Locate *symbol* in a spot-quote DataFrame by its code column.

    Returns an English-key dict for the first matching row, or None when
    the frame is empty, lacks the code column, or has no matching row.
    """
    if df.empty or "代码" not in df.columns:
        return None
    matches = df[df["代码"] == symbol]
    if matches.empty:
        return None
    first = matches.iloc[0]
    return {en_key: first.get(cn_key) for cn_key, en_key in _QUOTE_COLUMNS.items()}
|
||||
|
||||
|
||||
# --- Date helpers ---
|
||||
|
||||
|
||||
def _date_range(days: int) -> tuple[str, str]:
|
||||
"""Return (start_date, end_date) strings in YYYYMMDD format for the given window."""
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days=days)
|
||||
return start.strftime("%Y%m%d"), end.strftime("%Y%m%d")
|
||||
|
||||
|
||||
# --- A-share service functions ---
|
||||
|
||||
|
||||
async def get_a_share_quote(symbol: str) -> dict[str, Any] | None:
    """Fetch the real-time A-share quote for *symbol*.

    Returns an English-keyed dict, or None when the symbol does not appear
    in the spot market snapshot. AKShare exceptions propagate to the caller.
    """
    # AKShare is synchronous; run the fetch in a worker thread.
    spot_df: pd.DataFrame = await asyncio.to_thread(ak.stock_zh_a_spot_em)
    return _parse_spot_row(spot_df, symbol)
|
||||
|
||||
|
||||
async def get_a_share_historical(
    symbol: str, *, days: int = 365
) -> list[dict[str, Any]]:
    """Fetch daily OHLCV history for an A-share symbol (qfq / 前复权 adjusted).

    Propagates AKShare exceptions to the caller.
    """
    begin, finish = _date_range(days)
    raw: pd.DataFrame = await asyncio.to_thread(
        ak.stock_zh_a_hist,
        symbol=symbol,
        period="daily",
        start_date=begin,
        end_date=finish,
        adjust="qfq",
    )
    return _parse_hist_df(raw)
|
||||
|
||||
|
||||
async def search_a_shares(query: str) -> list[dict[str, Any]]:
    """Search A-share stocks whose name contains *query*.

    Returns a list of {code, name} dicts; an empty query returns all stocks.
    Propagates AKShare exceptions to the caller.
    """
    listing: pd.DataFrame = await asyncio.to_thread(ak.stock_info_a_code_name)
    if query:
        mask = listing["name"].str.contains(query, na=False)
        listing = listing[mask]
    return listing[["code", "name"]].to_dict(orient="records")
|
||||
|
||||
|
||||
# --- HK stock service functions ---
|
||||
|
||||
|
||||
async def get_hk_quote(symbol: str) -> dict[str, Any] | None:
    """Fetch the real-time HK stock quote for *symbol*.

    Returns an English-keyed dict, or None when the symbol is not found.
    AKShare exceptions propagate to the caller.
    """
    # AKShare is synchronous; run the fetch in a worker thread.
    spot_df: pd.DataFrame = await asyncio.to_thread(ak.stock_hk_spot_em)
    return _parse_spot_row(spot_df, symbol)
|
||||
|
||||
|
||||
async def get_hk_historical(
    symbol: str, *, days: int = 365
) -> list[dict[str, Any]]:
    """Fetch daily OHLCV history for a HK stock symbol (qfq adjusted).

    Propagates AKShare exceptions to the caller.
    """
    begin, finish = _date_range(days)
    raw: pd.DataFrame = await asyncio.to_thread(
        ak.stock_hk_hist,
        symbol=symbol,
        period="daily",
        start_date=begin,
        end_date=finish,
        adjust="qfq",
    )
    return _parse_hist_df(raw)
|
||||
372
backtest_service.py
Normal file
372
backtest_service.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Backtesting engine using pure pandas/numpy - no external backtesting libraries."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from obb_utils import fetch_historical
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cap on the number of equity-curve points returned, to keep payloads small.
_EQUITY_CURVE_MAX_POINTS = 20
# Buy-and-hold needs at least two bars to compute a return.
_MIN_BARS_FOR_SINGLE_POINT = 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal signal computation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_closes(result: Any) -> pd.Series:
|
||||
"""Pull close prices from an OBBject result into a float Series."""
|
||||
bars = result.results
|
||||
closes = [getattr(bar, "close", None) for bar in bars]
|
||||
return pd.Series(closes, dtype=float).dropna().reset_index(drop=True)
|
||||
|
||||
|
||||
def _compute_sma_signals(
|
||||
prices: pd.Series, short_window: int, long_window: int
|
||||
) -> pd.Series:
|
||||
"""Return position series (1=long, 0=flat) from SMA crossover strategy.
|
||||
|
||||
Buy when short SMA crosses above long SMA; sell when it crosses below.
|
||||
"""
|
||||
short_ma = prices.rolling(short_window).mean()
|
||||
long_ma = prices.rolling(long_window).mean()
|
||||
|
||||
# 1 where short > long, else 0; NaN before long_window filled with 0
|
||||
signal = (short_ma > long_ma).astype(int)
|
||||
signal.iloc[: long_window - 1] = 0
|
||||
return signal
|
||||
|
||||
|
||||
def _compute_rsi(prices: pd.Series, period: int) -> pd.Series:
|
||||
"""Compute Wilder RSI for a price series."""
|
||||
delta = prices.diff()
|
||||
gain = delta.clip(lower=0)
|
||||
loss = (-delta).clip(lower=0)
|
||||
|
||||
avg_gain = gain.ewm(alpha=1 / period, adjust=False).mean()
|
||||
avg_loss = loss.ewm(alpha=1 / period, adjust=False).mean()
|
||||
|
||||
# When avg_loss == 0 and avg_gain > 0, RSI = 100; avoid division by zero.
|
||||
rsi = pd.Series(np.where(
|
||||
avg_loss == 0,
|
||||
np.where(avg_gain == 0, 50.0, 100.0),
|
||||
100 - (100 / (1 + avg_gain / avg_loss)),
|
||||
), index=prices.index, dtype=float)
|
||||
# Preserve NaN for the initial diff period
|
||||
rsi[avg_gain.isna()] = np.nan
|
||||
return rsi
|
||||
|
||||
|
||||
def _compute_rsi_signals(
    prices: pd.Series, period: int, oversold: float, overbought: float
) -> pd.Series:
    """Return position series (1=long, 0=flat) from a mean-reversion RSI strategy.

    Enter when RSI drops below *oversold*; exit once RSI rises above *overbought*.
    """
    rsi = _compute_rsi(prices, period)
    position = pd.Series(0, index=prices.index, dtype=int)
    holding = False

    for idx, value in enumerate(rsi):
        if pd.isna(value):
            # Warm-up bars produce no signal and never mark a position.
            continue
        if not holding and value < oversold:
            holding = True
        elif holding and value > overbought:
            holding = False
        if holding:
            position.iloc[idx] = 1

    return position
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared metrics computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics(equity: pd.Series, trades: int) -> dict[str, Any]:
    """Compute standard backtest performance metrics from an equity curve.

    Parameters
    ----------
    equity:
        Daily portfolio value series starting from initial_capital.
    trades:
        Number of completed round-trip trades.

    Returns
    -------
    dict with keys: total_return, annualized_return, sharpe_ratio,
    max_drawdown, win_rate, total_trades, equity_curve.

    Notes
    -----
    win_rate is always None here; callers that track per-trade outcomes
    overwrite it. (A dead branch that conditionally re-assigned None when
    trades > 0 has been removed.)
    """
    n = len(equity)
    initial = float(equity.iloc[0])
    final = float(equity.iloc[-1])

    # Guard against a degenerate zero starting capital.
    total_return = (final - initial) / initial if initial != 0 else 0.0

    # Annualize assuming 252 trading days per year.
    trading_days = max(n - 1, 1)
    annualized_return = (1 + total_return) ** (252 / trading_days) - 1

    # Sharpe ratio (annualized, risk-free rate = 0); undefined with <2 points
    # or zero volatility.
    sharpe_ratio: float | None = None
    if n > 1:
        daily_returns = equity.pct_change().dropna()
        std = float(daily_returns.std())
        if std > 0:
            sharpe_ratio = float(daily_returns.mean() / std * np.sqrt(252))

    # Maximum drawdown: worst peak-to-trough decline.
    rolling_max = equity.cummax()
    drawdown = (equity - rolling_max) / rolling_max
    max_drawdown = float(drawdown.min())

    # Win rate is unknown at this level; callers that track trade outcomes
    # overwrite it after calling this function.
    win_rate: float | None = None

    # Equity curve - last N points as plain Python floats.
    last_n = equity.iloc[-_EQUITY_CURVE_MAX_POINTS:]
    equity_curve = [round(float(v), 4) for v in last_n]

    return {
        "total_return": round(total_return, 6),
        "annualized_return": round(annualized_return, 6),
        "sharpe_ratio": round(sharpe_ratio, 6) if sharpe_ratio is not None else None,
        "max_drawdown": round(max_drawdown, 6),
        "win_rate": win_rate,
        "total_trades": trades,
        "equity_curve": equity_curve,
    }
|
||||
|
||||
|
||||
def _simulate_positions(
|
||||
prices: pd.Series,
|
||||
positions: pd.Series,
|
||||
initial_capital: float,
|
||||
) -> tuple[pd.Series, int, int]:
|
||||
"""Simulate portfolio equity given a position series and prices.
|
||||
|
||||
Returns (equity_curve, total_trades, winning_trades).
|
||||
A trade is a complete buy->sell round-trip.
|
||||
"""
|
||||
# Daily returns when in position
|
||||
price_returns = prices.pct_change().fillna(0.0)
|
||||
strategy_returns = positions.shift(1).fillna(0).astype(float) * price_returns
|
||||
|
||||
equity = initial_capital * (1 + strategy_returns).cumprod()
|
||||
equity.iloc[0] = initial_capital
|
||||
|
||||
# Count round trips
|
||||
trade_changes = positions.diff().abs()
|
||||
entries = int((trade_changes == 1).sum())
|
||||
exits = int((trade_changes == -1).sum())
|
||||
total_trades = min(entries, exits) # only completed round trips
|
||||
|
||||
# Count wins: each completed trade where exit value > entry value
|
||||
winning_trades = 0
|
||||
in_trade = False
|
||||
entry_price = 0.0
|
||||
for i in range(len(positions)):
|
||||
pos = int(positions.iloc[i])
|
||||
price = float(prices.iloc[i])
|
||||
if not in_trade and pos == 1:
|
||||
in_trade = True
|
||||
entry_price = price
|
||||
elif in_trade and pos == 0:
|
||||
in_trade = False
|
||||
if price > entry_price:
|
||||
winning_trades += 1
|
||||
|
||||
return equity, total_trades, winning_trades
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public strategy functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def backtest_sma_crossover(
    symbol: str,
    short_window: int,
    long_window: int,
    days: int,
    initial_capital: float,
) -> dict[str, Any]:
    """Backtest an SMA crossover strategy on one symbol.

    Raises ValueError when no data is available or the history is too short.
    """
    hist = await fetch_historical(symbol, days)
    if hist is None:
        raise ValueError(f"No historical data available for {symbol}")

    prices = _extract_closes(hist)
    if len(prices) <= long_window:
        raise ValueError(
            f"Insufficient data: need >{long_window} bars, got {len(prices)}"
        )

    signal = _compute_sma_signals(prices, short_window, long_window)
    equity, n_trades, n_wins = _simulate_positions(prices, signal, initial_capital)

    metrics = _compute_metrics(equity, n_trades)
    if n_trades:
        metrics["win_rate"] = round(n_wins / n_trades, 6)
    return metrics
|
||||
|
||||
|
||||
async def backtest_rsi(
    symbol: str,
    period: int,
    oversold: float,
    overbought: float,
    days: int,
    initial_capital: float,
) -> dict[str, Any]:
    """Backtest a mean-reversion RSI strategy on one symbol.

    Raises ValueError when no data is available or the history is too short.
    """
    hist = await fetch_historical(symbol, days)
    if hist is None:
        raise ValueError(f"No historical data available for {symbol}")

    prices = _extract_closes(hist)
    if len(prices) <= period:
        raise ValueError(
            f"Insufficient data: need >{period} bars, got {len(prices)}"
        )

    signal = _compute_rsi_signals(prices, period, oversold, overbought)
    equity, n_trades, n_wins = _simulate_positions(prices, signal, initial_capital)

    metrics = _compute_metrics(equity, n_trades)
    if n_trades:
        metrics["win_rate"] = round(n_wins / n_trades, 6)
    return metrics
|
||||
|
||||
|
||||
async def backtest_buy_and_hold(
    symbol: str,
    days: int,
    initial_capital: float,
) -> dict[str, Any]:
    """Benchmark backtest: buy at the first bar and hold through the last."""
    hist = await fetch_historical(symbol, days)
    if hist is None:
        raise ValueError(f"No historical data available for {symbol}")

    prices = _extract_closes(hist)
    if len(prices) < _MIN_BARS_FOR_SINGLE_POINT:
        raise ValueError(
            f"Insufficient data: need at least 2 bars, got {len(prices)}"
        )

    # Fully invested from the first bar onward.
    always_long = pd.Series(1, index=prices.index, dtype=int)
    equity, _, _ = _simulate_positions(prices, always_long, initial_capital)

    metrics = _compute_metrics(equity, trades=1)
    # A single round trip: it "wins" exactly when the position ends profitable.
    metrics["win_rate"] = 1.0 if metrics["total_return"] > 0 else 0.0
    metrics["total_trades"] = 1
    return metrics
|
||||
|
||||
|
||||
async def backtest_momentum(
    symbols: list[str],
    lookback: int,
    top_n: int,
    rebalance_days: int,
    days: int,
    initial_capital: float,
) -> dict[str, Any]:
    """Run momentum strategy: every rebalance_days pick top_n symbols by lookback return.

    Parameters
    ----------
    symbols:
        Universe of candidate symbols; those without enough history are dropped.
    lookback:
        Bars used to measure each symbol's trailing return at rebalance time.
    top_n:
        Number of top-ranked symbols held (equal-weighted) after each rebalance.
    rebalance_days:
        Rebalance every this many bars (once past the lookback warm-up).
    days:
        Calendar window of history to fetch per symbol.
    initial_capital:
        Starting portfolio value.

    Returns
    -------
    The _compute_metrics dict plus an "allocation_history" list of
    {bar, symbols, weights} records, one per rebalance.

    Raises
    ------
    ValueError when no symbol has usable price data.
    """
    # Fetch all price series
    price_map: dict[str, pd.Series] = {}
    for sym in symbols:
        hist = await fetch_historical(sym, days)
        if hist is not None:
            closes = _extract_closes(hist)
            # Symbols with fewer bars than the lookback cannot be ranked.
            if len(closes) > lookback:
                price_map[sym] = closes

    if not price_map:
        raise ValueError("No price data available for any of the requested symbols")

    # Align all price series to the same length (min across symbols)
    min_len = min(len(v) for v in price_map.values())
    aligned = {sym: s.iloc[:min_len].reset_index(drop=True) for sym, s in price_map.items()}

    n_bars = min_len
    portfolio_value = initial_capital
    equity_values: list[float] = [initial_capital]
    allocation_history: list[dict[str, Any]] = []
    total_trades = 0

    # Holdings state: currently held symbols, their weights, and the price
    # at which each was entered (used to score wins at the next rebalance).
    current_symbols: list[str] = []
    current_weights: list[float] = []
    entry_prices: dict[str, float] = {}
    winning_trades = 0

    for bar in range(1, n_bars):
        # Rebalance check
        if bar % rebalance_days == 0 and bar >= lookback:
            # Rank symbols by lookback-period return
            returns: dict[str, float] = {}
            for sym, prices in aligned.items():
                if bar >= lookback:
                    ret = (prices.iloc[bar] / prices.iloc[bar - lookback]) - 1
                    returns[sym] = ret

            sorted_syms = sorted(returns, key=returns.get, reverse=True)  # type: ignore[arg-type]
            selected = sorted_syms[:top_n]
            # Equal weight across the selected basket.
            weight = 1.0 / len(selected) if selected else 0.0

            # Count closed positions as trades
            for sym in current_symbols:
                if sym in aligned:
                    exit_price = float(aligned[sym].iloc[bar])
                    # Fall back to exit_price (zero P&L) if entry is unknown.
                    entry_price = entry_prices.get(sym, exit_price)
                    total_trades += 1
                    if exit_price > entry_price:
                        winning_trades += 1

            current_symbols = selected
            current_weights = [weight] * len(selected)
            entry_prices = {sym: float(aligned[sym].iloc[bar]) for sym in selected}

            allocation_history.append({
                "bar": bar,
                "symbols": selected,
                "weights": current_weights,
            })

        # Compute portfolio daily return
        if current_symbols:
            daily_ret = 0.0
            for sym, w in zip(current_symbols, current_weights):
                prev_bar = bar - 1
                prev_price = float(aligned[sym].iloc[prev_bar])
                curr_price = float(aligned[sym].iloc[bar])
                # Guard against a zero price producing a division error.
                if prev_price != 0:
                    daily_ret += w * (curr_price / prev_price - 1)
            portfolio_value = portfolio_value * (1 + daily_ret)

        equity_values.append(portfolio_value)

    equity = pd.Series(equity_values, dtype=float)
    result = _compute_metrics(equity, total_trades)
    if total_trades > 0:
        result["win_rate"] = round(winning_trades / total_trades, 6)
    result["allocation_history"] = allocation_history
    return result
|
||||
68
congress_service.py
Normal file
68
congress_service.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Congress trading data: member trades and bill search."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openbb import obb
|
||||
|
||||
from obb_utils import to_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _try_obb_call(fn, *args, **kwargs) -> list[dict[str, Any]] | None:
    """Attempt a single OBB call and return to_list result, or None on failure."""
    try:
        # OBB calls are blocking; run them in a worker thread.
        raw = await asyncio.to_thread(fn, *args, **kwargs)
        return to_list(raw)
    except Exception as exc:
        logger.debug("OBB call failed: %s", exc)
        return None
|
||||
|
||||
|
||||
def _get_congress_fn():
    """Resolve the congress trading OBB function safely.

    Returns the bound OBB callable, or None (after a debug log) when the
    installed OpenBB build does not expose this endpoint.
    """
    try:
        return obb.regulators.government_us.congress_trading
    except AttributeError:
        # Older/partial OpenBB installs lack this provider path.
        logger.debug("obb.regulators.government_us.congress_trading not available")
        return None
|
||||
|
||||
|
||||
async def get_congress_trades() -> list[dict[str, Any]]:
    """Get recent US congress member stock trades.

    Returns an empty list if the data provider is unavailable.
    """
    fn = _get_congress_fn()
    if fn is None:
        return []

    # Try providers in order of preference; first success wins.
    for provider in ("quiverquant", "fmp"):
        trades = await _try_obb_call(fn, provider=provider)
        if trades is not None:
            return trades

    logger.warning("All congress trades providers failed")
    return []
|
||||
|
||||
|
||||
async def search_congress_bills(query: str) -> list[dict[str, Any]]:
    """Search US congress bills by keyword.

    Returns an empty list if the data provider is unavailable.

    NOTE(review): this reuses the *congress trading* function resolved by
    _get_congress_fn and forwards *query* positionally — confirm that OBB
    callable actually accepts a search keyword; it looks like it may hit
    the trades endpoint rather than a bill-search one.
    """
    fn = _get_congress_fn()
    if fn is None:
        return []

    providers = ["quiverquant", "fmp"]
    for provider in providers:
        # query is passed positionally; provider selects the data source.
        data = await _try_obb_call(fn, query, provider=provider)
        if data is not None:
            return data

    logger.warning("All congress bills providers failed for query: %s", query)
    return []
|
||||
195
defi_service.py
Normal file
195
defi_service.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""DeFi data service via DefiLlama API (no API key required)."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LLAMA_BASE = "https://api.llama.fi"
|
||||
STABLES_BASE = "https://stablecoins.llama.fi"
|
||||
YIELDS_BASE = "https://yields.llama.fi"
|
||||
TIMEOUT = 15.0
|
||||
|
||||
|
||||
async def get_top_protocols(limit: int = 20) -> list[dict[str, Any]]:
    """Fetch top DeFi protocols ranked by TVL from DefiLlama.

    Returns at most *limit* entries; empty list on any fetch/parse error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{LLAMA_BASE}/protocols")
            resp.raise_for_status()
            raw = resp.json()
            results: list[dict[str, Any]] = []
            for proto in raw[:limit]:
                results.append({
                    "name": proto.get("name"),
                    "symbol": proto.get("symbol"),
                    "tvl": proto.get("tvl"),
                    "chain": proto.get("chain"),
                    "chains": proto.get("chains", []),
                    "category": proto.get("category"),
                    "change_1d": proto.get("change_1d"),
                    "change_7d": proto.get("change_7d"),
                })
            return results
    except Exception:
        logger.exception("Failed to fetch top protocols from DefiLlama")
        return []
|
||||
|
||||
|
||||
async def get_chain_tvls() -> list[dict[str, Any]]:
    """Fetch TVL rankings for all chains from DefiLlama.

    Returns an empty list on any fetch/parse error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{LLAMA_BASE}/v2/chains")
            resp.raise_for_status()
            chains = resp.json()
            results: list[dict[str, Any]] = []
            for chain in chains:
                results.append({
                    "name": chain.get("name"),
                    "tvl": chain.get("tvl"),
                    "tokenSymbol": chain.get("tokenSymbol"),
                })
            return results
    except Exception:
        logger.exception("Failed to fetch chain TVLs from DefiLlama")
        return []
|
||||
|
||||
|
||||
async def get_protocol_tvl(protocol: str) -> float | None:
    """Fetch current TVL for a specific DefiLlama protocol slug.

    Returns a float, or None on network errors or a non-numeric payload.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{LLAMA_BASE}/tvl/{protocol}")
            resp.raise_for_status()
            # FIX: the endpoint returns a bare JSON number, but the previous
            # code returned resp.json() unvalidated despite the float | None
            # annotation. Coerce here; a non-numeric payload raises and is
            # caught below, yielding None.
            return float(resp.json())
    except Exception:
        logger.exception("Failed to fetch TVL for protocol %s", protocol)
        return None
|
||||
|
||||
|
||||
async def get_yield_pools(
    chain: str | None = None,
    project: str | None = None,
) -> list[dict[str, Any]]:
    """Fetch yield pools from DefiLlama, optionally filtered by chain and/or project.

    Returns the top 20 pools by TVL (descending); empty list on error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{YIELDS_BASE}/pools")
            resp.raise_for_status()
            body = resp.json()
            candidates: list[dict[str, Any]] = body.get("data", [])

            if chain is not None:
                candidates = [p for p in candidates if p.get("chain") == chain]
            if project is not None:
                candidates = [p for p in candidates if p.get("project") == project]

            def tvl_key(pool: dict[str, Any]) -> float:
                # Treat missing/None TVL as zero so those pools sort last.
                return pool.get("tvlUsd") or 0

            top_pools = sorted(candidates, key=tvl_key, reverse=True)[:20]

            results: list[dict[str, Any]] = []
            for p in top_pools:
                results.append({
                    "pool": p.get("pool"),
                    "chain": p.get("chain"),
                    "project": p.get("project"),
                    "symbol": p.get("symbol"),
                    "tvlUsd": p.get("tvlUsd"),
                    "apy": p.get("apy"),
                    "apyBase": p.get("apyBase"),
                    "apyReward": p.get("apyReward"),
                })
            return results
    except Exception:
        logger.exception("Failed to fetch yield pools from DefiLlama")
        return []
|
||||
|
||||
|
||||
def _extract_circulating(asset: dict[str, Any]) -> float | None:
|
||||
"""Extract the primary circulating supply value from a stablecoin asset dict."""
|
||||
raw = asset.get("circulating")
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (int, float)):
|
||||
return float(raw)
|
||||
if isinstance(raw, dict):
|
||||
# DefiLlama returns {"peggedUSD": <amount>, ...}
|
||||
values = [v for v in raw.values() if isinstance(v, (int, float))]
|
||||
return values[0] if values else None
|
||||
return None
|
||||
|
||||
|
||||
async def get_stablecoins(limit: int = 20) -> list[dict[str, Any]]:
    """Fetch top stablecoins by circulating supply from DefiLlama.

    Returns at most *limit* entries; empty list on any fetch/parse error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{STABLES_BASE}/stablecoins")
            resp.raise_for_status()
            payload = resp.json()
            assets: list[dict[str, Any]] = payload.get("peggedAssets", [])

            results: list[dict[str, Any]] = []
            for asset in assets[:limit]:
                results.append({
                    "name": asset.get("name"),
                    "symbol": asset.get("symbol"),
                    "pegType": asset.get("pegType"),
                    "circulating": _extract_circulating(asset),
                    "price": asset.get("price"),
                })
            return results
    except Exception:
        logger.exception("Failed to fetch stablecoins from DefiLlama")
        return []
|
||||
|
||||
|
||||
async def get_dex_volumes() -> dict[str, Any] | None:
    """Fetch DEX volume overview from DefiLlama.

    Returns a summary dict with 24h/7d totals and per-protocol volumes,
    or None on any fetch/parse error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{LLAMA_BASE}/overview/dexs")
            resp.raise_for_status()
            payload = resp.json()

            protocols: list[dict[str, Any]] = []
            for proto in payload.get("protocols", []):
                protocols.append({
                    "name": proto.get("name"),
                    "volume24h": proto.get("total24h"),
                })

            return {
                "totalVolume24h": payload.get("total24h"),
                "totalVolume7d": payload.get("total7d"),
                "protocols": protocols,
            }
    except Exception:
        logger.exception("Failed to fetch DEX volumes from DefiLlama")
        return None
|
||||
|
||||
|
||||
async def get_protocol_fees() -> list[dict[str, Any]]:
    """Fetch protocol fees and revenue overview from DefiLlama.

    Returns an empty list on any fetch/parse error.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(f"{LLAMA_BASE}/overview/fees")
            resp.raise_for_status()
            payload = resp.json()

            results: list[dict[str, Any]] = []
            for proto in payload.get("protocols", []):
                results.append({
                    "name": proto.get("name"),
                    "fees24h": proto.get("total24h"),
                    "revenue24h": proto.get("revenue24h"),
                })
            return results
    except Exception:
        logger.exception("Failed to fetch protocol fees from DefiLlama")
        return []
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Finnhub API client for sentiment, insider trades, and analyst data."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
@@ -104,12 +105,67 @@ async def get_upgrade_downgrade(
|
||||
return data if isinstance(data, list) else []
|
||||
|
||||
|
||||
async def get_social_sentiment(symbol: str) -> dict[str, Any]:
    """Get social media sentiment from Reddit and Twitter.

    Returns mention counts, positive/negative scores, and trends.
    Finnhub gates this endpoint behind a premium plan; a 401/403 is
    reported via ``premium_required`` instead of raising.
    """
    if not _is_configured():
        return {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY"}
    # NOTE(review): naive local time is used here, not UTC — confirm Finnhub
    # expects the server's local date for the "from" parameter.
    start = (datetime.now() - timedelta(days=3)).strftime("%Y-%m-%d")
    async with _client() as client:
        resp = await client.get(
            "/stock/social-sentiment",
            params={"symbol": symbol, "from": start},
        )
        if resp.status_code in (403, 401):
            logger.debug("social-sentiment requires premium, skipping")
            return {"configured": True, "symbol": symbol, "premium_required": True, "reddit": [], "twitter": []}
        resp.raise_for_status()
        data = resp.json()

    if not isinstance(data, dict):
        return {"configured": True, "symbol": symbol, "reddit": [], "twitter": []}

    reddit_entries = data.get("reddit", [])
    twitter_entries = data.get("twitter", [])

    return {
        "configured": True,
        "symbol": symbol,
        # Aggregate stats are omitted (None) when a feed has no entries.
        "reddit_summary": _summarize_social(reddit_entries) if reddit_entries else None,
        "twitter_summary": _summarize_social(twitter_entries) if twitter_entries else None,
        # Only the 20 most recent raw entries per feed are returned.
        "reddit": reddit_entries[-20:],
        "twitter": twitter_entries[-20:],
    }
|
||||
|
||||
|
||||
def _summarize_social(entries: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Summarize social sentiment entries into aggregate stats."""
|
||||
if not entries:
|
||||
return {}
|
||||
total_mentions = sum(e.get("mention", 0) for e in entries)
|
||||
total_positive = sum(e.get("positiveScore", 0) for e in entries)
|
||||
total_negative = sum(e.get("negativeScore", 0) for e in entries)
|
||||
avg_score = sum(e.get("score", 0) for e in entries) / len(entries)
|
||||
return {
|
||||
"total_mentions": total_mentions,
|
||||
"total_positive": total_positive,
|
||||
"total_negative": total_negative,
|
||||
"avg_score": round(avg_score, 4),
|
||||
"data_points": len(entries),
|
||||
}
|
||||
|
||||
|
||||
# Reddit sentiment moved to reddit_service.py
|
||||
|
||||
|
||||
async def get_sentiment_summary(symbol: str) -> dict[str, Any]:
|
||||
"""Aggregate all sentiment data for a symbol into one response."""
|
||||
if not _is_configured():
|
||||
return {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY to enable sentiment data"}
|
||||
|
||||
import asyncio
|
||||
news_sentiment, company_news, recommendations, upgrades = await asyncio.gather(
|
||||
get_news_sentiment(symbol),
|
||||
get_company_news(symbol, days=7),
|
||||
|
||||
8
main.py
8
main.py
@@ -37,6 +37,10 @@ from routes_sentiment import router as sentiment_router # noqa: E402
|
||||
from routes_shorts import router as shorts_router # noqa: E402
|
||||
from routes_surveys import router as surveys_router # noqa: E402
|
||||
from routes_technical import router as technical_router # noqa: E402
|
||||
from routes_portfolio import router as portfolio_router # noqa: E402
|
||||
from routes_backtest import router as backtest_router # noqa: E402
|
||||
from routes_cn import router as cn_router # noqa: E402
|
||||
from routes_defi import router as defi_router # noqa: E402
|
||||
|
||||
logging.basicConfig(
|
||||
level=settings.log_level.upper(),
|
||||
@@ -81,6 +85,10 @@ app.include_router(fixed_income_router)
|
||||
app.include_router(economy_router)
|
||||
app.include_router(surveys_router)
|
||||
app.include_router(regulators_router)
|
||||
app.include_router(portfolio_router)
|
||||
app.include_router(backtest_router)
|
||||
app.include_router(cn_router)
|
||||
app.include_router(defi_router)
|
||||
|
||||
|
||||
@app.get("/health", response_model=dict[str, str])
|
||||
|
||||
@@ -2,17 +2,14 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any
|
||||
|
||||
from openbb import obb
|
||||
|
||||
from obb_utils import to_list
|
||||
from obb_utils import to_list, days_ago, PROVIDER
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER = "yfinance"
|
||||
|
||||
|
||||
# --- ETF ---
|
||||
|
||||
@@ -30,7 +27,7 @@ async def get_etf_info(symbol: str) -> dict[str, Any]:
|
||||
|
||||
async def get_etf_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
|
||||
"""Get ETF price history."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.etf.historical, symbol, start_date=start, provider=PROVIDER
|
||||
@@ -66,7 +63,7 @@ async def get_available_indices() -> list[dict[str, Any]]:
|
||||
|
||||
async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
|
||||
"""Get index price history."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.index.price.historical, symbol, start_date=start, provider=PROVIDER
|
||||
@@ -82,7 +79,7 @@ async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, A
|
||||
|
||||
async def get_crypto_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
|
||||
"""Get cryptocurrency price history."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.crypto.price.historical, symbol, start_date=start, provider=PROVIDER
|
||||
@@ -110,7 +107,7 @@ async def get_currency_historical(
|
||||
symbol: str, days: int = 365
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get forex price history (e.g., EURUSD)."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.currency.price.historical, symbol, start_date=start, provider=PROVIDER
|
||||
@@ -140,7 +137,7 @@ async def get_futures_historical(
|
||||
symbol: str, days: int = 365
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get futures price history."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.derivatives.futures.historical, symbol, start_date=start, provider=PROVIDER
|
||||
|
||||
@@ -66,11 +66,16 @@ def first_or_empty(result: Any) -> dict[str, Any]:
|
||||
return items[0] if items else {}
|
||||
|
||||
|
||||
def days_ago(days: int) -> str:
    """Return a YYYY-MM-DD date string for N days ago (UTC)."""
    cutoff = datetime.now(tz=timezone.utc) - timedelta(days=days)
    return cutoff.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
async def fetch_historical(
|
||||
symbol: str, days: int = 365, provider: str = PROVIDER,
|
||||
) -> Any | None:
|
||||
"""Fetch historical price data, returning the OBBject result or None."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.price.historical, symbol, start_date=start, provider=provider,
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import yfinance as yf
|
||||
from openbb import obb
|
||||
|
||||
from obb_utils import to_list, first_or_empty
|
||||
from obb_utils import to_list, first_or_empty, days_ago, PROVIDER
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER = "yfinance"
|
||||
|
||||
|
||||
async def get_quote(symbol: str) -> dict[str, Any]:
|
||||
result = await asyncio.to_thread(
|
||||
@@ -21,7 +18,7 @@ async def get_quote(symbol: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def get_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(days)
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.price.historical,
|
||||
symbol,
|
||||
|
||||
372
portfolio_service.py
Normal file
372
portfolio_service.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Portfolio optimization: HRP, correlation matrix, risk parity, t-SNE clustering."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from math import isqrt
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from obb_utils import fetch_historical
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def fetch_historical_prices(symbols: list[str], days: int = 365) -> pd.DataFrame:
    """Fetch closing prices for multiple symbols and return as a DataFrame.

    Columns are symbol names; rows are dates (as strings). Symbols with
    no data are skipped; rows that are NaN for every symbol are dropped.
    """
    results = await asyncio.gather(*(fetch_historical(sym, days=days) for sym in symbols))

    columns: dict[str, pd.Series] = {}
    for sym, result in zip(symbols, results):
        if result is None or result.results is None:
            logger.warning("No historical data for %s, skipping", sym)
            continue
        rows = result.results
        if not rows:
            continue
        index: list[str] = []
        values: list[float] = []
        for row in rows:
            day = getattr(row, "date", None)
            close = getattr(row, "close", None)
            # Keep only rows where both the date and a close price exist.
            if day is not None and close is not None:
                index.append(str(day))
                values.append(float(close))
        if index:
            columns[sym] = pd.Series(values, index=index)

    if not columns:
        return pd.DataFrame()
    return pd.DataFrame(columns).dropna(how="all")
|
||||
|
||||
|
||||
def _compute_returns(prices: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Compute daily log returns from a price DataFrame."""
|
||||
return prices.pct_change().dropna()
|
||||
|
||||
|
||||
def _inverse_volatility_weights(returns: pd.DataFrame) -> dict[str, float]:
|
||||
"""Compute inverse-volatility weights."""
|
||||
vols = returns.std()
|
||||
inv_vols = 1.0 / vols
|
||||
weights = inv_vols / inv_vols.sum()
|
||||
return {sym: float(w) for sym, w in weights.items()}
|
||||
|
||||
|
||||
def _hrp_weights(returns: pd.DataFrame) -> dict[str, float]:
|
||||
"""Compute Hierarchical Risk Parity weights via scipy clustering.
|
||||
|
||||
Falls back to inverse-volatility if scipy is unavailable.
|
||||
"""
|
||||
symbols = list(returns.columns)
|
||||
n = len(symbols)
|
||||
|
||||
if n == 1:
|
||||
return {symbols[0]: 1.0}
|
||||
|
||||
try:
|
||||
from scipy.cluster.hierarchy import linkage, leaves_list
|
||||
from scipy.spatial.distance import squareform
|
||||
|
||||
corr = returns.corr().fillna(0).values
|
||||
# Convert correlation to distance: d = sqrt(0.5 * (1 - corr))
|
||||
dist = np.sqrt(np.clip(0.5 * (1 - corr), 0, 1))
|
||||
np.fill_diagonal(dist, 0.0)
|
||||
condensed = squareform(dist)
|
||||
link = linkage(condensed, method="single")
|
||||
order = leaves_list(link)
|
||||
sorted_symbols = [symbols[i] for i in order]
|
||||
except ImportError:
|
||||
logger.warning("scipy not available; using inverse-volatility for HRP")
|
||||
return _inverse_volatility_weights(returns)
|
||||
|
||||
cov = returns.cov().values
|
||||
|
||||
def _bisect_weights(items: list[str]) -> dict[str, float]:
|
||||
if len(items) == 1:
|
||||
return {items[0]: 1.0}
|
||||
mid = len(items) // 2
|
||||
left_items = items[:mid]
|
||||
right_items = items[mid:]
|
||||
left_idx = [sorted_symbols.index(s) for s in left_items]
|
||||
right_idx = [sorted_symbols.index(s) for s in right_items]
|
||||
|
||||
def _cluster_var(idx: list[int]) -> float:
|
||||
sub_cov = cov[np.ix_(idx, idx)]
|
||||
w = np.ones(len(idx)) / len(idx)
|
||||
return float(w @ sub_cov @ w)
|
||||
|
||||
v_left = _cluster_var(left_idx)
|
||||
v_right = _cluster_var(right_idx)
|
||||
total = v_left + v_right
|
||||
alpha = 1.0 - v_left / total if total > 0 else 0.5
|
||||
|
||||
w_left = _bisect_weights(left_items)
|
||||
w_right = _bisect_weights(right_items)
|
||||
|
||||
result = {}
|
||||
for sym, w in w_left.items():
|
||||
result[sym] = w * (1.0 - alpha)
|
||||
for sym, w in w_right.items():
|
||||
result[sym] = w * alpha
|
||||
return result
|
||||
|
||||
raw = _bisect_weights(sorted_symbols)
|
||||
total = sum(raw.values())
|
||||
return {sym: float(w / total) for sym, w in raw.items()}
|
||||
|
||||
|
||||
async def optimize_hrp(symbols: list[str], days: int = 365) -> dict[str, Any]:
    """Compute Hierarchical Risk Parity portfolio weights.

    Args:
        symbols: List of ticker symbols (1-50).
        days: Number of historical days to use.

    Returns:
        Dict with keys ``weights`` (symbol -> float) and ``method``.

    Raises:
        ValueError: If symbols is empty or no price data is available.
    """
    if not symbols:
        raise ValueError("symbols must not be empty")

    prices = await fetch_historical_prices(symbols, days=days)
    if prices.empty:
        raise ValueError("No price data available for the given symbols")

    daily_returns = _compute_returns(prices)
    return {"weights": _hrp_weights(daily_returns), "method": "hrp"}
|
||||
|
||||
|
||||
async def compute_correlation(
    symbols: list[str], days: int = 365
) -> dict[str, Any]:
    """Compute correlation matrix for a list of symbols.

    Args:
        symbols: List of ticker symbols (1-50).
        days: Number of historical days to use.

    Returns:
        Dict with keys ``symbols`` (list) and ``matrix`` (list of lists).

    Raises:
        ValueError: If symbols is empty or no price data is available.
    """
    if not symbols:
        raise ValueError("symbols must not be empty")

    prices = await fetch_historical_prices(symbols, days=days)
    if prices.empty:
        raise ValueError("No price data available for the given symbols")

    daily_returns = _compute_returns(prices)
    # NaN correlations (e.g. constant series) are reported as 0.
    corr = daily_returns.corr().fillna(0)
    return {"symbols": list(daily_returns.columns), "matrix": corr.values.tolist()}
|
||||
|
||||
|
||||
async def compute_risk_parity(
    symbols: list[str], days: int = 365
) -> dict[str, Any]:
    """Compute equal risk contribution (inverse-volatility) weights.

    Args:
        symbols: List of ticker symbols (1-50).
        days: Number of historical days to use.

    Returns:
        Dict with keys ``weights``, ``risk_contributions``, and ``method``.

    Raises:
        ValueError: If symbols is empty or no price data is available.
    """
    if not symbols:
        raise ValueError("symbols must not be empty")

    prices = await fetch_historical_prices(symbols, days=days)
    if prices.empty:
        raise ValueError("No price data available for the given symbols")

    daily_returns = _compute_returns(prices)
    weights = _inverse_volatility_weights(daily_returns)

    # Risk contribution of symbol i: w_i * sigma_i / sum_j(w_j * sigma_j)
    vols = daily_returns.std()
    weighted = {sym: w * float(vols[sym]) for sym, w in weights.items()}
    total = sum(weighted.values())
    if total > 0:
        contributions = {sym: v / total for sym, v in weighted.items()}
    else:
        # Degenerate case (zero total risk): split contributions evenly.
        even = 1.0 / len(weights)
        contributions = {sym: even for sym in weights}

    return {
        "weights": weights,
        "risk_contributions": contributions,
        "method": "risk_parity",
    }
|
||||
|
||||
|
||||
def _auto_n_clusters(n: int) -> int:
|
||||
"""Return a sensible default cluster count: max(2, floor(sqrt(n)))."""
|
||||
return max(2, isqrt(n))
|
||||
|
||||
|
||||
def _run_tsne_kmeans(
    returns_matrix: np.ndarray, n_clusters: int
) -> tuple[np.ndarray, np.ndarray]:
    """Run t-SNE then KMeans on a (n_symbols x n_days) returns matrix.

    Returns (coords, labels) where coords has shape (n_symbols, 2).
    CPU-heavy: caller must wrap in asyncio.to_thread.
    """
    from sklearn.cluster import KMeans
    from sklearn.manifold import TSNE

    # t-SNE requires perplexity < n_samples; cap at 5 for small universes.
    n_samples = returns_matrix.shape[0]
    perplexity = min(5, n_samples - 1)

    # Add tiny noise to prevent numerical singularity when returns are identical
    noise = np.random.default_rng(42).normal(0, 1e-10, returns_matrix.shape)
    embedded = TSNE(
        n_components=2, perplexity=perplexity, random_state=42, method="exact"
    ).fit_transform(returns_matrix + noise)

    clusterer = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")
    return embedded, clusterer.fit_predict(embedded)
|
||||
|
||||
|
||||
async def cluster_stocks(
    symbols: list[str],
    days: int = 180,
    n_clusters: int | None = None,
) -> dict[str, Any]:
    """Cluster stocks by return similarity using t-SNE + KMeans.

    Args:
        symbols: List of ticker symbols. Minimum 3, maximum 50.
        days: Number of historical trading days to use.
        n_clusters: Number of clusters. Defaults to floor(sqrt(n_symbols)).

    Returns:
        Dict with keys ``symbols``, ``coordinates``, ``clusters``,
        ``method``, ``n_clusters``, and ``days``.

    Raises:
        ValueError: Fewer than 3 symbols, or no price data available.
    """
    if len(symbols) < 3:
        raise ValueError("cluster_stocks requires at least 3 symbols")

    prices = await fetch_historical_prices(symbols, days=days)
    if prices.empty:
        raise ValueError("No price data available for the given symbols")

    returns = _compute_returns(prices)
    available = list(returns.columns)

    k = n_clusters if n_clusters is not None else _auto_n_clusters(len(available))

    # Build an (n_symbols x n_days) matrix; remaining NaNs become 0.0.
    matrix = returns[available].T.fillna(0).values.astype(float)

    # t-SNE + KMeans are CPU-heavy, so run them off the event loop.
    coords, labels = await asyncio.to_thread(_run_tsne_kmeans, matrix, k)

    coordinates: list[dict[str, Any]] = []
    cluster_members: dict[str, list[str]] = {}
    for i, sym in enumerate(available):
        label = int(labels[i])
        coordinates.append(
            {
                "symbol": sym,
                "x": float(coords[i, 0]),
                "y": float(coords[i, 1]),
                "cluster": label,
            }
        )
        cluster_members.setdefault(str(label), []).append(sym)

    return {
        "symbols": available,
        "coordinates": coordinates,
        "clusters": cluster_members,
        "method": "t-SNE + KMeans",
        "n_clusters": k,
        "days": days,
    }
|
||||
|
||||
|
||||
async def find_similar_stocks(
    symbol: str,
    universe: list[str],
    days: int = 180,
    top_n: int = 5,
) -> dict[str, Any]:
    """Find stocks most/least similar to a target by return correlation.

    Args:
        symbol: Target ticker symbol.
        universe: List of candidate symbols to compare against.
        days: Number of historical trading days to use.
        top_n: Number of most- and least-similar stocks to return.

    Returns:
        Dict with keys ``symbol``, ``most_similar``, ``least_similar``.

    Raises:
        ValueError: No price data available, or target symbol missing from data.
    """
    candidates = [symbol] + [s for s in universe if s != symbol]
    prices = await fetch_historical_prices(candidates, days=days)

    if prices.empty:
        raise ValueError("No price data available for the given symbols")

    if symbol not in prices.columns:
        raise ValueError(
            f"{symbol} not found in price data; it may have no available history"
        )

    returns = _compute_returns(prices)
    target = returns[symbol]

    scored: list[dict[str, Any]] = []
    for peer in universe:
        if peer == symbol or peer not in returns.columns:
            continue
        corr_val = float(target.corr(returns[peer]))
        # Skip peers whose overlap with the target is too thin to correlate.
        if not np.isnan(corr_val):
            scored.append({"symbol": peer, "correlation": corr_val})

    n = min(top_n, len(scored))
    descending = sorted(scored, key=lambda e: e["correlation"], reverse=True)
    ascending = sorted(scored, key=lambda e: e["correlation"])

    return {
        "symbol": symbol,
        "most_similar": descending[:n],
        "least_similar": ascending[:n],
    }
|
||||
@@ -10,6 +10,7 @@ dependencies = [
|
||||
"pydantic-settings",
|
||||
"httpx",
|
||||
"curl_cffi==0.7.4",
|
||||
"akshare",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -2,17 +2,14 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any
|
||||
|
||||
from openbb import obb
|
||||
|
||||
from obb_utils import extract_single, safe_last, fetch_historical, to_list
|
||||
from obb_utils import extract_single, safe_last, fetch_historical, to_list, days_ago, PROVIDER
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER = "yfinance"
|
||||
|
||||
# Need 252+ trading days for default window; 730 calendar days is safe
|
||||
PERF_DAYS = 730
|
||||
TARGET = "close"
|
||||
@@ -22,7 +19,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any
|
||||
"""Calculate Sharpe ratio, summary stats, and volatility for a symbol."""
|
||||
# Need at least 252 trading days for Sharpe window
|
||||
fetch_days = max(days, PERF_DAYS)
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(fetch_days)
|
||||
|
||||
try:
|
||||
hist = await asyncio.to_thread(
|
||||
@@ -64,7 +61,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any
|
||||
|
||||
async def get_capm(symbol: str) -> dict[str, Any]:
|
||||
"""Calculate CAPM metrics: beta, alpha, systematic/idiosyncratic risk."""
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=PERF_DAYS)).strftime("%Y-%m-%d")
|
||||
start = days_ago(PERF_DAYS)
|
||||
|
||||
try:
|
||||
hist = await asyncio.to_thread(
|
||||
@@ -85,7 +82,7 @@ async def get_capm(symbol: str) -> dict[str, Any]:
|
||||
async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]:
|
||||
"""Run normality tests (Jarque-Bera, Shapiro-Wilk, etc.) on returns."""
|
||||
fetch_days = max(days, PERF_DAYS)
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(fetch_days)
|
||||
|
||||
try:
|
||||
hist = await asyncio.to_thread(
|
||||
@@ -106,7 +103,7 @@ async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]:
|
||||
async def get_unitroot_test(symbol: str, days: int = 365) -> dict[str, Any]:
|
||||
"""Run unit root tests (ADF, KPSS) for stationarity."""
|
||||
fetch_days = max(days, PERF_DAYS)
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
|
||||
start = days_ago(fetch_days)
|
||||
|
||||
try:
|
||||
hist = await asyncio.to_thread(
|
||||
|
||||
80
reddit_service.py
Normal file
80
reddit_service.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Reddit stock sentiment via ApeWisdom API (free, no key needed)."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
APEWISDOM_URL = "https://apewisdom.io/api/v1.0/filter/all-stocks/page/1"
|
||||
TIMEOUT = 10.0
|
||||
|
||||
|
||||
async def get_reddit_sentiment(symbol: str) -> dict[str, Any]:
    """Get Reddit sentiment for a symbol.

    Tracks mentions and upvotes across r/wallstreetbets, r/stocks, r/investing.
    """
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(APEWISDOM_URL)
            resp.raise_for_status()
            results = resp.json().get("results", [])

        # ApeWisdom only lists trending tickers, so a miss is not an error.
        wanted = symbol.upper()
        match = None
        for row in results:
            if row.get("ticker", "").upper() == wanted:
                match = row
                break

        if match is None:
            return {
                "symbol": symbol,
                "found": False,
                "message": f"{symbol} not in Reddit top trending (not enough mentions)",
            }

        mentions_prev = match.get("mentions_24h_ago", 0)
        mentions_now = match.get("mentions", 0)
        # Percent change is undefined without a prior-day baseline.
        if mentions_prev > 0:
            change_pct = round((mentions_now - mentions_prev) / mentions_prev * 100, 1)
        else:
            change_pct = None

        return {
            "symbol": symbol,
            "found": True,
            "rank": match.get("rank"),
            "mentions_24h": mentions_now,
            "mentions_24h_ago": mentions_prev,
            "mentions_change_pct": change_pct,
            "upvotes": match.get("upvotes"),
            "rank_24h_ago": match.get("rank_24h_ago"),
        }
    except Exception:
        logger.warning("Reddit sentiment failed for %s", symbol, exc_info=True)
        return {"symbol": symbol, "error": "Failed to fetch Reddit sentiment"}
|
||||
|
||||
|
||||
async def get_reddit_trending() -> list[dict[str, Any]]:
    """Get top trending stocks on Reddit (free, no key)."""
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            resp = await client.get(APEWISDOM_URL)
            resp.raise_for_status()
            rows = resp.json().get("results", [])

        # Cap the response at the top 25 trending tickers.
        trending: list[dict[str, Any]] = []
        for row in rows[:25]:
            trending.append(
                {
                    "rank": row.get("rank"),
                    "symbol": row.get("ticker"),
                    "name": row.get("name"),
                    "mentions_24h": row.get("mentions"),
                    "upvotes": row.get("upvotes"),
                    "rank_24h_ago": row.get("rank_24h_ago"),
                    "mentions_24h_ago": row.get("mentions_24h_ago"),
                }
            )
        return trending
    except Exception:
        logger.warning("Reddit trending failed", exc_info=True)
        return []
|
||||
@@ -23,13 +23,19 @@ def validate_symbol(symbol: str) -> str:
|
||||
|
||||
|
||||
def safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
"""Decorator to catch upstream errors and return 502."""
|
||||
"""Decorator to catch upstream errors and return 502.
|
||||
|
||||
ValueError is caught separately and returned as 400 (bad request).
|
||||
All other non-HTTP exceptions become 502 (upstream error).
|
||||
"""
|
||||
@functools.wraps(fn)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
except Exception:
|
||||
logger.exception("Upstream data error")
|
||||
raise HTTPException(
|
||||
|
||||
106
routes_backtest.py
Normal file
106
routes_backtest.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Routes for backtesting strategies."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import backtest_service
|
||||
from models import ApiResponse
|
||||
from route_utils import safe
|
||||
|
||||
router = APIRouter(prefix="/api/v1/backtest", tags=["backtest"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BacktestRequest(BaseModel):
    """Common parameters shared by all single-symbol backtest requests."""

    # Ticker symbol to backtest (1-20 chars; upstream service validates further).
    symbol: str = Field(..., min_length=1, max_length=20)
    # Historical window in days: 30 days up to ~10 years.
    days: int = Field(default=365, ge=30, le=3650)
    # Starting cash; must be positive, capped at 1e9.
    initial_capital: float = Field(default=10000.0, gt=0, le=1_000_000_000)
|
||||
|
||||
|
||||
class SMARequest(BacktestRequest):
    """Parameters for the SMA-crossover strategy."""

    # Fast moving-average window (trading days).
    short_window: int = Field(default=20, ge=5, le=100)
    # Slow moving-average window (trading days).
    # NOTE(review): nothing enforces short_window < long_window — a request
    # with short_window=100, long_window=10 passes validation; confirm the
    # service handles (or should reject) that case.
    long_window: int = Field(default=50, ge=10, le=400)
|
||||
|
||||
|
||||
class RSIRequest(BacktestRequest):
    """Parameters for the RSI mean-reversion strategy."""

    # RSI lookback period in trading days.
    period: int = Field(default=14, ge=2, le=50)
    # Buy threshold; bounds keep it strictly below the neutral 50 line.
    oversold: float = Field(default=30.0, ge=1, le=49)
    # Sell threshold; bounds keep it strictly above the neutral 50 line.
    overbought: float = Field(default=70.0, ge=51, le=99)
|
||||
|
||||
|
||||
class BuyAndHoldRequest(BacktestRequest):
    """Buy-and-hold benchmark; uses only the base request fields."""

    pass
|
||||
|
||||
|
||||
class MomentumRequest(BaseModel):
    """Parameters for the cross-sectional momentum rotation strategy."""

    # Candidate universe (2-20 tickers).
    symbols: list[str] = Field(..., min_length=2, max_length=20)
    # Lookback window (trading days) used to rank returns.
    lookback: int = Field(default=60, ge=5, le=252)
    # How many top performers to hold each period.
    # NOTE(review): no upper bound — top_n may exceed len(symbols); confirm
    # the service handles that gracefully.
    top_n: int = Field(default=2, ge=1)
    # Rebalancing interval in days.
    rebalance_days: int = Field(default=30, ge=5, le=252)
    # Total simulation window in days.
    days: int = Field(default=365, ge=60, le=3650)
    # Starting cash; unlike BacktestRequest, no upper cap here.
    initial_capital: float = Field(default=10000.0, gt=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/sma-crossover", response_model=ApiResponse)
|
||||
@safe
|
||||
async def sma_crossover(req: SMARequest) -> ApiResponse:
|
||||
"""SMA crossover strategy: buy when short SMA crosses above long SMA."""
|
||||
result = await backtest_service.backtest_sma_crossover(
|
||||
req.symbol,
|
||||
short_window=req.short_window,
|
||||
long_window=req.long_window,
|
||||
days=req.days,
|
||||
initial_capital=req.initial_capital,
|
||||
)
|
||||
return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/rsi", response_model=ApiResponse)
|
||||
@safe
|
||||
async def rsi_strategy(req: RSIRequest) -> ApiResponse:
|
||||
"""RSI strategy: buy when RSI < oversold, sell when RSI > overbought."""
|
||||
result = await backtest_service.backtest_rsi(
|
||||
req.symbol,
|
||||
period=req.period,
|
||||
oversold=req.oversold,
|
||||
overbought=req.overbought,
|
||||
days=req.days,
|
||||
initial_capital=req.initial_capital,
|
||||
)
|
||||
return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/buy-and-hold", response_model=ApiResponse)
|
||||
@safe
|
||||
async def buy_and_hold(req: BuyAndHoldRequest) -> ApiResponse:
|
||||
"""Buy-and-hold benchmark: buy on day 1, hold through end of period."""
|
||||
result = await backtest_service.backtest_buy_and_hold(
|
||||
req.symbol,
|
||||
days=req.days,
|
||||
initial_capital=req.initial_capital,
|
||||
)
|
||||
return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/momentum", response_model=ApiResponse)
|
||||
@safe
|
||||
async def momentum_strategy(req: MomentumRequest) -> ApiResponse:
|
||||
"""Momentum strategy: every rebalance_days pick top_n by lookback return."""
|
||||
result = await backtest_service.backtest_momentum(
|
||||
symbols=req.symbols,
|
||||
lookback=req.lookback,
|
||||
top_n=req.top_n,
|
||||
rebalance_days=req.rebalance_days,
|
||||
days=req.days,
|
||||
initial_capital=req.initial_capital,
|
||||
)
|
||||
return ApiResponse(data=result)
|
||||
105
routes_cn.py
Normal file
105
routes_cn.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Routes for A-share (China) and Hong Kong stock market data via AKShare."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Query
|
||||
|
||||
from models import ApiResponse
|
||||
from route_utils import safe
|
||||
import akshare_service
|
||||
|
||||
router = APIRouter(prefix="/api/v1/cn", tags=["China & HK Markets"])
|
||||
|
||||
|
||||
# --- Validation helpers ---
|
||||
|
||||
|
||||
def _validate_a_share(symbol: str) -> str:
    """Return *symbol* unchanged when it is a well-formed A-share code.

    Raises:
        HTTPException: 400 when the symbol fails akshare_service validation
            (must be 6 digits starting with 0, 3, or 6).
    """
    # Guard clause: the happy path returns immediately.
    if akshare_service.validate_a_share_symbol(symbol):
        return symbol
    raise HTTPException(
        status_code=400,
        detail=(
            f"Invalid A-share symbol '{symbol}'. "
            "Must be 6 digits starting with 0, 3, or 6 (e.g. 000001, 300001, 600519)."
        ),
    )
|
||||
|
||||
|
||||
def _validate_hk(symbol: str) -> str:
    """Return *symbol* unchanged when it is a well-formed HK stock code.

    Raises:
        HTTPException: 400 when the symbol fails akshare_service validation
            (must be exactly 5 digits).
    """
    # Guard clause: the happy path returns immediately.
    if akshare_service.validate_hk_symbol(symbol):
        return symbol
    raise HTTPException(
        status_code=400,
        detail=(
            f"Invalid HK symbol '{symbol}'. "
            "Must be exactly 5 digits (e.g. 00700, 09988)."
        ),
    )
|
||||
|
||||
|
||||
# --- A-share routes ---
|
||||
# NOTE: /a-share/search MUST be registered before /a-share/{symbol} to avoid shadowing.
|
||||
|
||||
|
||||
@router.get("/a-share/search", response_model=ApiResponse)
@safe
async def a_share_search(
    query: str = Query(..., description="Stock name to search for (partial match)"),
) -> ApiResponse:
    """Search A-share stocks whose name partially matches *query*."""
    matches = await akshare_service.search_a_shares(query)
    return ApiResponse(data=matches)
|
||||
|
||||
|
||||
@router.get("/a-share/{symbol}/quote", response_model=ApiResponse)
@safe
async def a_share_quote(
    symbol: str = Path(..., min_length=6, max_length=6),
) -> ApiResponse:
    """Real-time quote for a single A-share (Shanghai/Shenzhen) listing.

    Returns 404 when the symbol is well-formed but unknown.
    """
    symbol = _validate_a_share(symbol)
    quote = await akshare_service.get_a_share_quote(symbol)
    if quote is None:
        raise HTTPException(status_code=404, detail=f"A-share symbol '{symbol}' not found.")
    return ApiResponse(data=quote)
|
||||
|
||||
|
||||
@router.get("/a-share/{symbol}/historical", response_model=ApiResponse)
@safe
async def a_share_historical(
    symbol: str = Path(..., min_length=6, max_length=6),
    days: int = Query(default=365, ge=1, le=3650),
) -> ApiResponse:
    """Daily OHLCV history for an A-share, forward-adjusted (qfq)."""
    symbol = _validate_a_share(symbol)
    bars = await akshare_service.get_a_share_historical(symbol, days=days)
    return ApiResponse(data=bars)
|
||||
|
||||
|
||||
# --- HK stock routes ---
|
||||
|
||||
|
||||
@router.get("/hk/{symbol}/quote", response_model=ApiResponse)
@safe
async def hk_quote(
    symbol: str = Path(..., min_length=5, max_length=5),
) -> ApiResponse:
    """Real-time quote for a Hong Kong stock.

    Returns 404 when the symbol is well-formed but unknown.
    """
    symbol = _validate_hk(symbol)
    quote = await akshare_service.get_hk_quote(symbol)
    if quote is None:
        raise HTTPException(status_code=404, detail=f"HK symbol '{symbol}' not found.")
    return ApiResponse(data=quote)
|
||||
|
||||
|
||||
@router.get("/hk/{symbol}/historical", response_model=ApiResponse)
@safe
async def hk_historical(
    symbol: str = Path(..., min_length=5, max_length=5),
    days: int = Query(default=365, ge=1, le=3650),
) -> ApiResponse:
    """Daily OHLCV history for a Hong Kong stock, forward-adjusted (qfq)."""
    symbol = _validate_hk(symbol)
    bars = await akshare_service.get_hk_historical(symbol, days=days)
    return ApiResponse(data=bars)
|
||||
75
routes_defi.py
Normal file
75
routes_defi.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""DeFi data routes via DefiLlama API."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
import defi_service
|
||||
from models import ApiResponse
|
||||
from route_utils import safe
|
||||
|
||||
router = APIRouter(prefix="/api/v1/defi")
|
||||
|
||||
|
||||
@router.get("/tvl/protocols", response_model=ApiResponse)
@safe
async def tvl_protocols() -> ApiResponse:
    """Top DeFi protocols ranked by total value locked (TVL)."""
    return ApiResponse(data=await defi_service.get_top_protocols())
|
||||
|
||||
|
||||
@router.get("/tvl/chains", response_model=ApiResponse)
@safe
async def tvl_chains() -> ApiResponse:
    """TVL rankings across all chains."""
    return ApiResponse(data=await defi_service.get_chain_tvls())
|
||||
|
||||
|
||||
@router.get("/tvl/{protocol}", response_model=ApiResponse)
@safe
async def protocol_tvl(protocol: str) -> ApiResponse:
    """Current TVL for one protocol, looked up by its slug.

    Returns 404 when the slug is unknown.
    """
    tvl = await defi_service.get_protocol_tvl(protocol)
    if tvl is None:
        raise HTTPException(status_code=404, detail=f"Protocol '{protocol}' not found")
    return ApiResponse(data={"protocol": protocol, "tvl": tvl})
|
||||
|
||||
|
||||
@router.get("/yields", response_model=ApiResponse)
@safe
async def yield_pools(
    chain: str | None = Query(default=None, description="Filter by chain name"),
    project: str | None = Query(default=None, description="Filter by project name"),
) -> ApiResponse:
    """Top yield pools; optional chain/project filters narrow the result."""
    pools = await defi_service.get_yield_pools(chain=chain, project=project)
    return ApiResponse(data=pools)
|
||||
|
||||
|
||||
@router.get("/stablecoins", response_model=ApiResponse)
@safe
async def stablecoins() -> ApiResponse:
    """Top stablecoins by circulating supply."""
    return ApiResponse(data=await defi_service.get_stablecoins())
|
||||
|
||||
|
||||
@router.get("/volumes/dexs", response_model=ApiResponse)
@safe
async def dex_volumes() -> ApiResponse:
    """DEX volume overview including top protocols.

    Returns 502 when the upstream fetch yields no data.
    """
    overview = await defi_service.get_dex_volumes()
    if overview is None:
        raise HTTPException(
            status_code=502,
            detail="Failed to fetch DEX volume data from DefiLlama",
        )
    return ApiResponse(data=overview)
|
||||
|
||||
|
||||
@router.get("/fees", response_model=ApiResponse)
@safe
async def protocol_fees() -> ApiResponse:
    """Protocol fees and revenue overview."""
    return ApiResponse(data=await defi_service.get_protocol_fees())
|
||||
96
routes_portfolio.py
Normal file
96
routes_portfolio.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Routes for portfolio optimization (HRP, correlation, risk parity)."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models import ApiResponse
|
||||
from route_utils import safe
|
||||
import portfolio_service
|
||||
|
||||
router = APIRouter(prefix="/api/v1/portfolio")
|
||||
|
||||
|
||||
class PortfolioOptimizeRequest(BaseModel):
    """Request body for the optimize/correlation/risk-parity endpoints."""

    # 1-50 ticker symbols to include in the calculation.
    symbols: list[str] = Field(..., min_length=1, max_length=50)
    # Lookback window in calendar days (1 to ~10 years).
    days: int = Field(default=365, ge=1, le=3650)
|
||||
|
||||
|
||||
@router.post("/optimize", response_model=ApiResponse)
@safe
async def portfolio_optimize(request: PortfolioOptimizeRequest):
    """Compute HRP (Hierarchical Risk Parity) optimal weights for the symbols.

    Raises:
        HTTPException: 400 when the service raises ValueError for the inputs.
    """
    try:
        result = await portfolio_service.optimize_hrp(
            request.symbols, days=request.days
        )
    except ValueError as exc:
        # Chain the cause (`from exc`) so the original ValueError is kept
        # in the traceback instead of being masked (flake8 B904).
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/correlation", response_model=ApiResponse)
@safe
async def portfolio_correlation(request: PortfolioOptimizeRequest):
    """Compute the pairwise correlation matrix for the symbols.

    Raises:
        HTTPException: 400 when the service raises ValueError for the inputs.
    """
    try:
        result = await portfolio_service.compute_correlation(
            request.symbols, days=request.days
        )
    except ValueError as exc:
        # Chain the cause (`from exc`) so the original ValueError is kept
        # in the traceback instead of being masked (flake8 B904).
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/risk-parity", response_model=ApiResponse)
@safe
async def portfolio_risk_parity(request: PortfolioOptimizeRequest):
    """Compute equal risk contribution (risk parity) weights for the symbols.

    Raises:
        HTTPException: 400 when the service raises ValueError for the inputs.
    """
    try:
        result = await portfolio_service.compute_risk_parity(
            request.symbols, days=request.days
        )
    except ValueError as exc:
        # Chain the cause (`from exc`) so the original ValueError is kept
        # in the traceback instead of being masked (flake8 B904).
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return ApiResponse(data=result)
|
||||
|
||||
|
||||
class ClusterRequest(BaseModel):
    """Request body for /cluster."""

    # 3-50 tickers; the schema requires at least 3.
    symbols: list[str] = Field(..., min_length=3, max_length=50)
    # Lookback window in days (minimum 30).
    days: int = Field(default=180, ge=30, le=3650)
    # Optional fixed cluster count; None lets the service decide.
    n_clusters: int | None = Field(default=None, ge=2, le=20)
|
||||
|
||||
|
||||
class SimilarRequest(BaseModel):
    """Request body for /similar."""

    # Target ticker to compare against the universe.
    symbol: str = Field(..., min_length=1, max_length=20)
    # Candidate universe (2-50 tickers).
    universe: list[str] = Field(..., min_length=2, max_length=50)
    # Lookback window in days (minimum 30).
    days: int = Field(default=180, ge=30, le=3650)
    # Number of matches to return.
    top_n: int = Field(default=5, ge=1, le=20)
|
||||
|
||||
|
||||
@router.post("/cluster", response_model=ApiResponse)
@safe
async def portfolio_cluster(request: ClusterRequest):
    """Cluster stocks by return similarity using t-SNE + KMeans.

    Raises:
        HTTPException: 400 when the service raises ValueError for the inputs.
    """
    try:
        result = await portfolio_service.cluster_stocks(
            request.symbols, days=request.days, n_clusters=request.n_clusters
        )
    except ValueError as exc:
        # Chain the cause (`from exc`) so the original ValueError is kept
        # in the traceback instead of being masked (flake8 B904).
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return ApiResponse(data=result)
|
||||
|
||||
|
||||
@router.post("/similar", response_model=ApiResponse)
@safe
async def portfolio_similar(request: SimilarRequest):
    """Find stocks most/least similar to a target by return correlation.

    Raises:
        HTTPException: 400 when the service raises ValueError for the inputs.
    """
    try:
        result = await portfolio_service.find_similar_stocks(
            request.symbol,
            request.universe,
            days=request.days,
            top_n=request.top_n,
        )
    except ValueError as exc:
        # Chain the cause (`from exc`) so the original ValueError is kept
        # in the traceback instead of being masked (flake8 B904).
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return ApiResponse(data=result)
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Routes for regulatory data (CFTC, SEC)."""
|
||||
"""Routes for regulatory data (CFTC, SEC, Congress)."""
|
||||
|
||||
from fastapi import APIRouter, Path, Query
|
||||
|
||||
from models import ApiResponse
|
||||
from route_utils import safe, validate_symbol
|
||||
import regulators_service
|
||||
import congress_service
|
||||
|
||||
router = APIRouter(prefix="/api/v1/regulators")
|
||||
|
||||
@@ -49,3 +50,22 @@ async def sec_cik_map(symbol: str = Path(..., min_length=1, max_length=20)):
|
||||
symbol = validate_symbol(symbol)
|
||||
data = await regulators_service.get_cik_map(symbol)
|
||||
return ApiResponse(data=data)
|
||||
|
||||
|
||||
# --- Congress Trading ---
|
||||
|
||||
|
||||
@router.get("/congress/trades", response_model=ApiResponse)
@safe
async def congress_trades():
    """Recent stock trades disclosed by US congress members."""
    return ApiResponse(data=await congress_service.get_congress_trades())
|
||||
|
||||
|
||||
@router.get("/congress/bills", response_model=ApiResponse)
@safe
async def congress_bills(query: str = Query(..., min_length=1, max_length=200)):
    """Keyword search over US congress bills."""
    bills = await congress_service.search_congress_bills(query)
    return ApiResponse(data=bills)
|
||||
|
||||
@@ -9,6 +9,7 @@ from route_utils import safe, validate_symbol
|
||||
import alphavantage_service
|
||||
import finnhub_service
|
||||
import openbb_service
|
||||
import reddit_service
|
||||
|
||||
import logging
|
||||
|
||||
@@ -23,22 +24,111 @@ router = APIRouter(prefix="/api/v1")
|
||||
@router.get("/stock/{symbol}/sentiment", response_model=ApiResponse)
@safe
async def stock_sentiment(symbol: str = Path(..., min_length=1, max_length=20)):
    """Aggregated sentiment from all sources with composite score.

    Combines: Alpha Vantage news sentiment, Finnhub analyst data,
    Reddit mentions, and analyst upgrades into a single composite score.
    Score range: -1.0 (extreme bearish) to +1.0 (extreme bullish).
    """
    symbol = validate_symbol(symbol)

    # Fetch all sources in parallel; return_exceptions=True keeps one
    # failing provider from taking down the whole endpoint.
    av_data, finnhub_data, reddit_data, upgrades_data, recs_data = await asyncio.gather(
        alphavantage_service.get_news_sentiment(symbol, limit=20),
        finnhub_service.get_sentiment_summary(symbol),
        reddit_service.get_reddit_sentiment(symbol),
        openbb_service.get_upgrades_downgrades(symbol, limit=10),
        finnhub_service.get_recommendation_trends(symbol),
        return_exceptions=True,
    )

    def _safe(result, default):
        # Replace a gathered exception with the source's neutral default.
        return default if isinstance(result, BaseException) else result

    av_data = _safe(av_data, {})
    finnhub_data = _safe(finnhub_data, {})
    reddit_data = _safe(reddit_data, {})
    upgrades_data = _safe(upgrades_data, [])
    recs_data = _safe(recs_data, [])

    # --- Score each source ---
    scores: list[tuple[str, float, float]] = []  # (source, score, weight)

    # 1. News sentiment (Alpha Vantage): avg_score ranges ~-0.35 to +0.35
    if isinstance(av_data, dict) and av_data.get("overall_sentiment"):
        av_score = av_data["overall_sentiment"].get("avg_score")
        if av_score is not None:
            # Normalize to -1..+1 (AV scores are typically -0.35 to +0.35)
            normalized = max(-1.0, min(1.0, av_score * 2.5))
            scores.append(("news", round(normalized, 3), 0.25))

    # 2. Analyst recommendations (Finnhub): buy/sell ratio
    if isinstance(recs_data, list) and recs_data:
        latest = recs_data[0]
        total = sum(latest.get(k, 0) for k in ("strongBuy", "buy", "hold", "sell", "strongSell"))
        if total > 0:
            bullish = latest.get("strongBuy", 0) + latest.get("buy", 0)
            bearish = latest.get("sell", 0) + latest.get("strongSell", 0)
            ratio = (bullish - bearish) / total  # -1 to +1
            scores.append(("analysts", round(ratio, 3), 0.30))

    # 3. Analyst upgrades vs downgrades (yfinance).
    # Note: the enclosing truthiness check already guarantees a non-empty
    # list, so no separate length check is needed before dividing.
    if isinstance(upgrades_data, list) and upgrades_data:
        ups = sum(1 for u in upgrades_data if u.get("action") in ("up", "init"))
        downs = sum(1 for u in upgrades_data if u.get("action") == "down")
        upgrade_score = (ups - downs) / len(upgrades_data)
        scores.append(("upgrades", round(upgrade_score, 3), 0.20))

    # 4. Reddit buzz (ApeWisdom)
    if isinstance(reddit_data, dict) and reddit_data.get("found"):
        mentions = reddit_data.get("mentions_24h", 0)
        change = reddit_data.get("mentions_change_pct")
        if change is not None and mentions > 10:
            # Positive change = bullish buzz, capped at +/- 1
            reddit_score = max(-1.0, min(1.0, change / 100))
            scores.append(("reddit", round(reddit_score, 3), 0.25))

    # --- Compute weighted composite ---
    if scores:
        total_weight = sum(w for _, _, w in scores)
        composite = sum(s * w for _, s, w in scores) / total_weight
        composite = round(composite, 3)
    else:
        composite = None

    # Map the numeric composite onto a coarse human-readable label.
    if composite is None:
        label = "Unknown"
    elif composite >= 0.5:
        label = "Strong Bullish"
    elif composite >= 0.15:
        label = "Bullish"
    elif composite >= -0.15:
        label = "Neutral"
    elif composite >= -0.5:
        label = "Bearish"
    else:
        label = "Strong Bearish"

    return ApiResponse(data={
        "symbol": symbol,
        "composite_score": composite,
        "composite_label": label,
        "source_scores": {name: score for name, score, _ in scores},
        "source_weights": {name: weight for name, _, weight in scores},
        "details": {
            "news_sentiment": av_data if isinstance(av_data, dict) else {},
            "analyst_recommendations": recs_data[0] if isinstance(recs_data, list) and recs_data else {},
            "recent_upgrades": upgrades_data[:5] if isinstance(upgrades_data, list) else [],
            "reddit": reddit_data if isinstance(reddit_data, dict) else {},
            "finnhub_news": (
                finnhub_data.get("recent_news", [])[:5]
                if isinstance(finnhub_data, dict)
                else []
            ),
        },
    })
|
||||
|
||||
|
||||
@router.get("/stock/{symbol}/news-sentiment", response_model=ApiResponse)
|
||||
@@ -101,3 +191,33 @@ async def stock_upgrades(symbol: str = Path(..., min_length=1, max_length=20)):
|
||||
symbol = validate_symbol(symbol)
|
||||
data = await openbb_service.get_upgrades_downgrades(symbol)
|
||||
return ApiResponse(data=data)
|
||||
|
||||
|
||||
@router.get("/stock/{symbol}/social-sentiment", response_model=ApiResponse)
@safe
async def stock_social_sentiment(
    symbol: str = Path(..., min_length=1, max_length=20),
):
    """Social media sentiment (Reddit and Twitter) for *symbol*, via Finnhub."""
    symbol = validate_symbol(symbol)
    return ApiResponse(data=await finnhub_service.get_social_sentiment(symbol))
|
||||
|
||||
|
||||
@router.get("/stock/{symbol}/reddit-sentiment", response_model=ApiResponse)
@safe
async def stock_reddit_sentiment(
    symbol: str = Path(..., min_length=1, max_length=20),
):
    """Reddit mention stats for *symbol* (WSB/stocks/investing; free, no key)."""
    symbol = validate_symbol(symbol)
    return ApiResponse(data=await reddit_service.get_reddit_sentiment(symbol))
|
||||
|
||||
|
||||
@router.get("/discover/reddit-trending", response_model=ApiResponse)
@safe
async def reddit_trending():
    """Top 25 trending stocks on Reddit (WSB, r/stocks, r/investing); no API key."""
    return ApiResponse(data=await reddit_service.get_reddit_trending())
|
||||
|
||||
442
tests/test_akshare_service.py
Normal file
442
tests/test_akshare_service.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""Unit tests for akshare_service.py - written FIRST (TDD RED phase)."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import akshare_service
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _make_hist_df(rows: int = 3) -> pd.DataFrame:
|
||||
"""Return a minimal historical DataFrame with Chinese column names."""
|
||||
dates = pd.date_range("2026-01-01", periods=rows, freq="D")
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"日期": dates,
|
||||
"开盘": [10.0] * rows,
|
||||
"收盘": [10.5] * rows,
|
||||
"最高": [11.0] * rows,
|
||||
"最低": [9.5] * rows,
|
||||
"成交量": [1_000_000] * rows,
|
||||
"成交额": [10_500_000.0] * rows,
|
||||
"振幅": [1.5] * rows,
|
||||
"涨跌幅": [0.5] * rows,
|
||||
"涨跌额": [0.05] * rows,
|
||||
"换手率": [0.3] * rows,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _make_spot_df(code: str = "000001", name: str = "平安银行") -> pd.DataFrame:
|
||||
"""Return a minimal real-time quote DataFrame with Chinese column names."""
|
||||
return pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"代码": code,
|
||||
"名称": name,
|
||||
"最新价": 12.34,
|
||||
"涨跌幅": 1.23,
|
||||
"涨跌额": 0.15,
|
||||
"成交量": 500_000,
|
||||
"成交额": 6_170_000.0,
|
||||
"今开": 12.10,
|
||||
"最高": 12.50,
|
||||
"最低": 12.00,
|
||||
"昨收": 12.19,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _make_code_name_df() -> pd.DataFrame:
|
||||
"""Return a minimal code/name mapping DataFrame."""
|
||||
return pd.DataFrame(
|
||||
[
|
||||
{"code": "000001", "name": "平安银行"},
|
||||
{"code": "600519", "name": "贵州茅台"},
|
||||
{"code": "000002", "name": "万科A"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Symbol validation
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestValidateAShare:
    """validate_a_share_symbol: 6 digits starting with 0, 3, or 6."""

    @pytest.mark.parametrize("symbol", ["000001", "300001", "600519"])
    def test_accepts_valid_prefixes(self, symbol):
        assert akshare_service.validate_a_share_symbol(symbol) is True

    @pytest.mark.parametrize(
        "symbol",
        ["100001", "00001", "0000011", "00000A", ""],
        ids=["bad-prefix", "too-short", "too-long", "letters", "empty"],
    )
    def test_rejects_malformed_symbols(self, symbol):
        assert akshare_service.validate_a_share_symbol(symbol) is False
|
||||
|
||||
|
||||
class TestValidateHKSymbol:
    """validate_hk_symbol: exactly 5 digits."""

    @pytest.mark.parametrize("symbol", ["00700", "99999"])
    def test_accepts_five_digit_codes(self, symbol):
        assert akshare_service.validate_hk_symbol(symbol) is True

    @pytest.mark.parametrize(
        "symbol",
        ["0070", "007000", "0070A", ""],
        ids=["too-short", "too-long", "letters", "empty"],
    )
    def test_rejects_malformed_symbols(self, symbol):
        assert akshare_service.validate_hk_symbol(symbol) is False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _parse_hist_df
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestParseHistDf:
    """_parse_hist_df converts Chinese-column history frames to English dicts."""

    def test_returns_list_of_dicts(self):
        parsed = akshare_service._parse_hist_df(_make_hist_df(2))
        assert isinstance(parsed, list)
        assert len(parsed) == 2

    def test_keys_are_english(self):
        row = akshare_service._parse_hist_df(_make_hist_df(1))[0]
        required = {
            "date", "open", "close", "high", "low",
            "volume", "turnover", "change_percent", "turnover_rate",
        }
        assert required.issubset(row.keys())

    def test_no_chinese_keys_remain(self):
        row = akshare_service._parse_hist_df(_make_hist_df(1))[0]
        for key in row:
            assert all(ord(c) <= 127 for c in key), f"Non-ASCII key found: {key}"

    def test_date_is_string(self):
        row = akshare_service._parse_hist_df(_make_hist_df(1))[0]
        assert isinstance(row["date"], str)

    def test_values_are_correct(self):
        row = akshare_service._parse_hist_df(_make_hist_df(1))[0]
        assert row["open"] == pytest.approx(10.0)
        assert row["close"] == pytest.approx(10.5)

    def test_empty_df_returns_empty_list(self):
        assert akshare_service._parse_hist_df(pd.DataFrame()) == []
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _parse_spot_row
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestParseSpotRow:
    """_parse_spot_row extracts one symbol's quote row with English keys."""

    def test_returns_dict_with_english_keys(self):
        parsed = akshare_service._parse_spot_row(_make_spot_df("000001", "平安银行"), "000001")
        assert parsed is not None
        assert {"symbol", "name", "price"}.issubset(parsed.keys())

    def test_correct_symbol_extracted(self):
        parsed = akshare_service._parse_spot_row(_make_spot_df("000001"), "000001")
        assert parsed["symbol"] == "000001"

    def test_returns_none_when_symbol_not_found(self):
        assert akshare_service._parse_spot_row(_make_spot_df("000001"), "999999") is None

    def test_price_value_correct(self):
        frame = _make_spot_df("600519")
        frame["代码"] = "600519"
        parsed = akshare_service._parse_spot_row(frame, "600519")
        assert parsed["price"] == pytest.approx(12.34)

    def test_all_quote_fields_present(self):
        parsed = akshare_service._parse_spot_row(_make_spot_df("000001"), "000001")
        required = {
            "symbol", "name", "price", "change", "change_percent",
            "volume", "turnover", "open", "high", "low", "prev_close",
        }
        assert required.issubset(set(parsed.keys()))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_a_share_quote
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetAShareQuote:
    """get_a_share_quote wraps ak.stock_zh_a_spot_em (mocked here)."""

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_spot_em")
    async def test_returns_quote_dict(self, spot_mock):
        spot_mock.return_value = _make_spot_df("000001")
        quote = await akshare_service.get_a_share_quote("000001")
        assert quote is not None
        assert quote["symbol"] == "000001"
        assert quote["name"] == "平安银行"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_spot_em")
    async def test_returns_none_for_unknown_symbol(self, spot_mock):
        spot_mock.return_value = _make_spot_df("000001")
        assert await akshare_service.get_a_share_quote("999999") is None

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_spot_em")
    async def test_propagates_exception(self, spot_mock):
        spot_mock.side_effect = RuntimeError("AKShare unavailable")
        with pytest.raises(RuntimeError):
            await akshare_service.get_a_share_quote("000001")

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_spot_em")
    async def test_akshare_called_once(self, spot_mock):
        spot_mock.return_value = _make_spot_df("000001")
        await akshare_service.get_a_share_quote("000001")
        spot_mock.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_a_share_historical
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetAShareHistorical:
    """get_a_share_historical wraps ak.stock_zh_a_hist (mocked here)."""

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_returns_list_of_bars(self, hist_mock):
        hist_mock.return_value = _make_hist_df(5)
        bars = await akshare_service.get_a_share_historical("000001", days=30)
        assert isinstance(bars, list)
        assert len(bars) == 5

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_bars_have_english_keys(self, hist_mock):
        hist_mock.return_value = _make_hist_df(1)
        bars = await akshare_service.get_a_share_historical("000001", days=30)
        assert {"date", "open", "close"}.issubset(bars[0].keys())

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_called_with_correct_symbol(self, hist_mock):
        hist_mock.return_value = _make_hist_df(1)
        await akshare_service.get_a_share_historical("600519", days=90)
        call = hist_mock.call_args
        assert call.kwargs.get("symbol") == "600519" or call.args[0] == "600519"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_adjust_is_qfq(self, hist_mock):
        hist_mock.return_value = _make_hist_df(1)
        await akshare_service.get_a_share_historical("000001", days=30)
        assert hist_mock.call_args.kwargs.get("adjust") == "qfq"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_empty_df_returns_empty_list(self, hist_mock):
        hist_mock.return_value = pd.DataFrame()
        assert await akshare_service.get_a_share_historical("000001", days=30) == []

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_zh_a_hist")
    async def test_propagates_exception(self, hist_mock):
        hist_mock.side_effect = RuntimeError("network error")
        with pytest.raises(RuntimeError):
            await akshare_service.get_a_share_historical("000001", days=30)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# search_a_shares
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestSearchAShares:
    """search_a_shares filters ak.stock_info_a_code_name by name substring."""

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_returns_matching_results(self, codes_mock):
        codes_mock.return_value = _make_code_name_df()
        matches = await akshare_service.search_a_shares("平安")
        assert len(matches) == 1
        assert matches[0]["code"] == "000001"
        assert matches[0]["name"] == "平安银行"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_returns_empty_list_when_no_match(self, codes_mock):
        codes_mock.return_value = _make_code_name_df()
        assert await akshare_service.search_a_shares("NONEXISTENT") == []

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_returns_multiple_matches(self, codes_mock):
        codes_mock.return_value = _make_code_name_df()
        matches = await akshare_service.search_a_shares("万")
        assert len(matches) == 1
        assert matches[0]["code"] == "000002"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_result_has_code_and_name_keys(self, codes_mock):
        codes_mock.return_value = _make_code_name_df()
        matches = await akshare_service.search_a_shares("茅台")
        assert len(matches) == 1
        assert {"code", "name"}.issubset(matches[0].keys())

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_propagates_exception(self, codes_mock):
        codes_mock.side_effect = RuntimeError("timeout")
        with pytest.raises(RuntimeError):
            await akshare_service.search_a_shares("平安")

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_info_a_code_name")
    async def test_empty_query_returns_all(self, codes_mock):
        codes_mock.return_value = _make_code_name_df()
        assert len(await akshare_service.search_a_shares("")) == 3
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_hk_quote
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetHKQuote:
    """get_hk_quote wraps ak.stock_hk_spot_em (mocked here)."""

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_spot_em")
    async def test_returns_quote_dict(self, spot_mock):
        spot_mock.return_value = _make_spot_df("00700", "腾讯控股")
        quote = await akshare_service.get_hk_quote("00700")
        assert quote is not None
        assert quote["symbol"] == "00700"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_spot_em")
    async def test_returns_none_for_unknown_symbol(self, spot_mock):
        spot_mock.return_value = _make_spot_df("00700")
        assert await akshare_service.get_hk_quote("99999") is None

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_spot_em")
    async def test_propagates_exception(self, spot_mock):
        spot_mock.side_effect = RuntimeError("AKShare unavailable")
        with pytest.raises(RuntimeError):
            await akshare_service.get_hk_quote("00700")

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_spot_em")
    async def test_all_fields_present(self, spot_mock):
        spot_mock.return_value = _make_spot_df("00700", "腾讯控股")
        quote = await akshare_service.get_hk_quote("00700")
        required = {
            "symbol", "name", "price", "change", "change_percent",
            "volume", "turnover", "open", "high", "low", "prev_close",
        }
        assert required.issubset(set(quote.keys()))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_hk_historical
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetHKHistorical:
    """Behaviour of akshare_service.get_hk_historical with ak.stock_hk_hist mocked."""

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_returns_list_of_bars(self, mock_hist):
        """Each DataFrame row becomes one bar dict in the returned list."""
        mock_hist.return_value = _make_hist_df(4)
        bars = await akshare_service.get_hk_historical("00700", days=30)
        assert isinstance(bars, list)
        assert len(bars) == 4

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_bars_have_english_keys(self, mock_hist):
        """Chinese column names are translated to English keys."""
        mock_hist.return_value = _make_hist_df(1)
        bars = await akshare_service.get_hk_historical("00700", days=30)
        assert "date" in bars[0]
        assert "close" in bars[0]

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_called_with_correct_symbol(self, mock_hist):
        """The requested symbol must reach ak.stock_hk_hist.

        FIX: the old check fell back to ``call_args.args[0]`` which raises
        IndexError (instead of failing the assertion) when the call used only
        keyword arguments without ``symbol``.  Collect both positional and
        keyword values safely before asserting.
        """
        mock_hist.return_value = _make_hist_df(1)
        await akshare_service.get_hk_historical("09988", days=90)
        call = mock_hist.call_args
        passed = list(call.args) + [call.kwargs.get("symbol")]
        assert "09988" in passed

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_adjust_is_qfq(self, mock_hist):
        """Historical data is requested forward-adjusted (qfq)."""
        mock_hist.return_value = _make_hist_df(1)
        await akshare_service.get_hk_historical("00700", days=30)
        assert mock_hist.call_args.kwargs.get("adjust") == "qfq"

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_empty_df_returns_empty_list(self, mock_hist):
        """An empty DataFrame maps to an empty list, not an error."""
        mock_hist.return_value = pd.DataFrame()
        bars = await akshare_service.get_hk_historical("00700", days=30)
        assert bars == []

    @pytest.mark.asyncio
    @patch("akshare_service.ak.stock_hk_hist")
    async def test_propagates_exception(self, mock_hist):
        """Network/provider failures bubble up to the caller."""
        mock_hist.side_effect = RuntimeError("network error")
        with pytest.raises(RuntimeError):
            await akshare_service.get_hk_historical("00700", days=30)
|
||||
627
tests/test_backtest_service.py
Normal file
627
tests/test_backtest_service.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""Unit tests for backtest_service - written FIRST (TDD RED phase)."""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import backtest_service
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_equity(values: list[float]) -> pd.Series:
|
||||
"""Build a simple equity-curve Series from a list of values."""
|
||||
return pd.Series(values, dtype=float)
|
||||
|
||||
|
||||
def _rising_prices(n: int = 100, start: float = 100.0, step: float = 1.0) -> pd.Series:
|
||||
"""Linearly rising price series."""
|
||||
return pd.Series([start + i * step for i in range(n)], dtype=float)
|
||||
|
||||
|
||||
def _flat_prices(n: int = 100, price: float = 100.0) -> pd.Series:
|
||||
"""Flat price series - no movement."""
|
||||
return pd.Series([price] * n, dtype=float)
|
||||
|
||||
|
||||
def _oscillating_prices(n: int = 200, period: int = 40) -> pd.Series:
|
||||
"""Sinusoidal price series to generate crossover signals."""
|
||||
t = np.arange(n)
|
||||
prices = 100 + 20 * np.sin(2 * np.pi * t / period)
|
||||
return pd.Series(prices, dtype=float)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compute_metrics tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeMetrics:
    """Unit tests for backtest_service._compute_metrics."""

    def test_total_return_positive(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 11000, 12000]), trades=1)
        assert metrics["total_return"] == pytest.approx(0.2, abs=1e-6)

    def test_total_return_negative(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 9000, 8000]), trades=1)
        assert metrics["total_return"] == pytest.approx(-0.2, abs=1e-6)

    def test_total_return_zero_on_flat(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 10000, 10000]), trades=0)
        assert metrics["total_return"] == pytest.approx(0.0, abs=1e-6)

    def test_annualized_return_shape(self):
        # 253 daily points compounding at 10%/252 per bar spans one trading
        # year, so the annualized figure should land near +10%.
        curve = _make_equity([10000 * (1.0 + 0.1 / 252) ** i for i in range(253)])
        metrics = backtest_service._compute_metrics(curve, trades=5)
        assert metrics["annualized_return"] == pytest.approx(0.1, abs=0.01)

    def test_sharpe_ratio_positive_drift(self):
        # Steady upward drift with small daily increments -> positive Sharpe.
        curve = _make_equity([10000 + i * 10 for i in range(252)])
        metrics = backtest_service._compute_metrics(curve, trades=5)
        assert metrics["sharpe_ratio"] > 0

    def test_sharpe_ratio_none_on_single_point(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000]), trades=0)
        assert metrics["sharpe_ratio"] is None

    def test_sharpe_ratio_none_on_zero_std(self):
        # Zero volatility leaves the Sharpe ratio undefined.
        metrics = backtest_service._compute_metrics(_make_equity([10000] * 50), trades=0)
        assert metrics["sharpe_ratio"] is None

    def test_max_drawdown_known_value(self):
        # Peak 12000 down to trough 8000 is a -1/3 drawdown.
        curve = _make_equity([10000, 12000, 8000, 9000])
        metrics = backtest_service._compute_metrics(curve, trades=2)
        assert metrics["max_drawdown"] == pytest.approx(-1 / 3, abs=1e-6)

    def test_max_drawdown_zero_on_monotone_rise(self):
        curve = _make_equity([10000, 11000, 12000, 13000])
        metrics = backtest_service._compute_metrics(curve, trades=1)
        assert metrics["max_drawdown"] == pytest.approx(0.0, abs=1e-6)

    def test_total_trades_propagated(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 11000]), trades=7)
        assert metrics["total_trades"] == 7

    def test_win_rate_zero_trades(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 10000]), trades=0)
        assert metrics["win_rate"] is None

    def test_equity_curve_last_20_points(self):
        curve = _make_equity(list(range(100, 160)))  # 60 points
        metrics = backtest_service._compute_metrics(curve, trades=10)
        assert len(metrics["equity_curve"]) == 20
        assert metrics["equity_curve"][-1] == pytest.approx(159.0, abs=1e-6)

    def test_equity_curve_shorter_than_20(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 11000, 12000]), trades=1)
        assert len(metrics["equity_curve"]) == 3

    def test_result_keys_present(self):
        metrics = backtest_service._compute_metrics(_make_equity([10000, 11000]), trades=1)
        required = {
            "total_return",
            "annualized_return",
            "sharpe_ratio",
            "max_drawdown",
            "win_rate",
            "total_trades",
            "equity_curve",
        }
        assert required.issubset(metrics.keys())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compute_sma_signals tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeSmaSignals:
    """Unit tests for backtest_service._compute_sma_signals."""

    @staticmethod
    def _signals(prices):
        """Run the signal generator with the canonical 5/20 windows."""
        return backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)

    def test_returns_series_with_position_column(self):
        prices = _oscillating_prices(200, period=40)
        positions = self._signals(prices)
        assert isinstance(positions, pd.Series)
        assert len(positions) == len(prices)

    def test_positions_are_zero_or_one(self):
        positions = self._signals(_oscillating_prices(200, period=40))
        assert set(positions.dropna().unique()) <= {0, 1}

    def test_no_position_before_long_window(self):
        # During SMA warm-up (first long_window-1 bars) no position is taken.
        positions = self._signals(_oscillating_prices(200, period=40))
        assert (positions.iloc[:19] == 0).all()

    def test_generates_at_least_one_signal_on_oscillating(self):
        # A sine wave must flip the position between 0 and 1 at least once.
        positions = self._signals(_oscillating_prices(300, period=60))
        assert positions.diff().abs().sum() > 0

    def test_flat_prices_produce_no_signals(self):
        # With flat prices both SMAs coincide after warm-up, so the short
        # average is never strictly above the long one.
        positions = self._signals(_flat_prices(100))
        assert (positions == 0).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compute_rsi tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeRsi:
    """Unit tests for backtest_service._compute_rsi."""

    def test_rsi_length(self):
        prices = _rising_prices(50)
        rsi = backtest_service._compute_rsi(prices, period=14)
        assert len(rsi) == len(prices)

    def test_rsi_range(self):
        # RSI is bounded to [0, 100] wherever it is defined.
        rsi = backtest_service._compute_rsi(_oscillating_prices(100, period=20), period=14)
        observed = rsi.dropna()
        assert ((observed >= 0) & (observed <= 100)).all()

    def test_rsi_rising_prices_high(self):
        # A monotone uptrend should push RSI into overbought territory.
        rsi = backtest_service._compute_rsi(_rising_prices(80, step=1.0), period=14)
        assert rsi.iloc[-1] >= 70

    def test_rsi_falling_prices_low(self):
        # A monotone downtrend should push RSI into oversold territory.
        falling = pd.Series([100 - i * 0.8 for i in range(80)], dtype=float)
        rsi = backtest_service._compute_rsi(falling, period=14)
        assert rsi.iloc[-1] <= 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compute_rsi_signals tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeRsiSignals:
    """Unit tests for backtest_service._compute_rsi_signals."""

    @staticmethod
    def _signals(prices):
        """Run the RSI signal generator with the standard 14/30/70 settings."""
        return backtest_service._compute_rsi_signals(
            prices, period=14, oversold=30, overbought=70
        )

    def test_returns_series(self):
        prices = _oscillating_prices(200, period=40)
        positions = self._signals(prices)
        assert isinstance(positions, pd.Series)
        assert len(positions) == len(prices)

    def test_positions_are_zero_or_one(self):
        positions = self._signals(_oscillating_prices(200, period=40))
        assert set(positions.dropna().unique()) <= {0, 1}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backtest_sma_crossover tests (async integration of service layer)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBacktestSmaCrossover:
    """Service-level tests for backtest_sma_crossover with fetch_historical stubbed.

    Refactor: the FakeBar/FakeResult/fake_fetch scaffolding was duplicated in
    every fixture/test; it now lives in one `_install_fake_fetch` helper.  The
    old fixture also returned an unused FakeResult — dropped.
    """

    # Metric keys every backtest result dict must expose.
    REQUIRED_KEYS = {
        "total_return",
        "annualized_return",
        "sharpe_ratio",
        "max_drawdown",
        "win_rate",
        "total_trades",
        "equity_curve",
    }

    @staticmethod
    def _install_fake_fetch(monkeypatch, prices):
        """Patch backtest_service.fetch_historical with a stub serving *prices*."""

        class _Bar:
            def __init__(self, close):
                self.close = close

        class _Result:
            def __init__(self):
                self.results = [_Bar(p) for p in prices]

        async def _fetch(symbol, days, **kwargs):
            return _Result()

        monkeypatch.setattr(backtest_service, "fetch_historical", _fetch)

    @pytest.fixture
    def mock_hist(self, monkeypatch):
        """Oscillating synthetic history so SMA crossovers actually occur."""
        self._install_fake_fetch(monkeypatch, _oscillating_prices(300, period=60).tolist())

    @pytest.mark.asyncio
    async def test_returns_all_required_keys(self, mock_hist):
        result = await backtest_service.backtest_sma_crossover(
            "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
        )
        assert self.REQUIRED_KEYS.issubset(result.keys())

    @pytest.mark.asyncio
    async def test_equity_curve_max_20_points(self, mock_hist):
        result = await backtest_service.backtest_sma_crossover(
            "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
        )
        assert len(result["equity_curve"]) <= 20

    @pytest.mark.asyncio
    async def test_raises_value_error_on_no_data(self, monkeypatch):
        async def fake_fetch(symbol, days, **kwargs):
            return None

        monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
        with pytest.raises(ValueError, match="No historical data"):
            await backtest_service.backtest_sma_crossover(
                "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
            )

    @pytest.mark.asyncio
    async def test_initial_capital_reflected_in_equity(self, mock_hist):
        result = await backtest_service.backtest_sma_crossover(
            "AAPL", short_window=5, long_window=20, days=365, initial_capital=50000
        )
        # Equity values must be positive and scale with initial capital.
        assert result["equity_curve"][0] > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backtest_rsi tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBacktestRsi:
    """Service-level tests for backtest_rsi with fetch_historical stubbed.

    Refactor: the duplicated FakeBar/FakeResult/fake_fetch scaffolding is
    collapsed into a single `_install_fake_fetch` helper.
    """

    # Metric keys every backtest result dict must expose.
    REQUIRED_KEYS = {
        "total_return",
        "annualized_return",
        "sharpe_ratio",
        "max_drawdown",
        "win_rate",
        "total_trades",
        "equity_curve",
    }

    @staticmethod
    def _install_fake_fetch(monkeypatch, prices):
        """Patch backtest_service.fetch_historical with a stub serving *prices*."""

        class _Bar:
            def __init__(self, close):
                self.close = close

        class _Result:
            def __init__(self):
                self.results = [_Bar(p) for p in prices]

        async def _fetch(symbol, days, **kwargs):
            return _Result()

        monkeypatch.setattr(backtest_service, "fetch_historical", _fetch)

    @pytest.fixture
    def mock_hist(self, monkeypatch):
        """Oscillating synthetic history so RSI actually crosses its thresholds."""
        self._install_fake_fetch(monkeypatch, _oscillating_prices(300, period=60).tolist())

    @pytest.mark.asyncio
    async def test_returns_all_required_keys(self, mock_hist):
        result = await backtest_service.backtest_rsi(
            "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
        )
        assert self.REQUIRED_KEYS.issubset(result.keys())

    @pytest.mark.asyncio
    async def test_equity_curve_max_20_points(self, mock_hist):
        result = await backtest_service.backtest_rsi(
            "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
        )
        assert len(result["equity_curve"]) <= 20

    @pytest.mark.asyncio
    async def test_raises_value_error_on_no_data(self, monkeypatch):
        async def fake_fetch(symbol, days, **kwargs):
            return None

        monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
        with pytest.raises(ValueError, match="No historical data"):
            await backtest_service.backtest_rsi(
                "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backtest_buy_and_hold tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBacktestBuyAndHold:
    """Service-level tests for backtest_buy_and_hold with fetch_historical stubbed.

    Refactor: three verbatim copies of the FakeBar/FakeResult/fake_fetch
    scaffolding are collapsed into one `_install_fake_fetch` helper.
    """

    # Metric keys every backtest result dict must expose.
    REQUIRED_KEYS = {
        "total_return",
        "annualized_return",
        "sharpe_ratio",
        "max_drawdown",
        "win_rate",
        "total_trades",
        "equity_curve",
    }

    @staticmethod
    def _install_fake_fetch(monkeypatch, prices):
        """Patch backtest_service.fetch_historical with a stub serving *prices*."""

        class _Bar:
            def __init__(self, close):
                self.close = close

        class _Result:
            def __init__(self):
                self.results = [_Bar(p) for p in prices]

        async def _fetch(symbol, days, **kwargs):
            return _Result()

        monkeypatch.setattr(backtest_service, "fetch_historical", _fetch)

    @pytest.fixture
    def mock_hist_rising(self, monkeypatch):
        """One trading year of linearly rising prices."""
        self._install_fake_fetch(
            monkeypatch, _rising_prices(252, start=100.0, step=1.0).tolist()
        )

    @pytest.mark.asyncio
    async def test_returns_all_required_keys(self, mock_hist_rising):
        result = await backtest_service.backtest_buy_and_hold(
            "AAPL", days=365, initial_capital=10000
        )
        assert self.REQUIRED_KEYS.issubset(result.keys())

    @pytest.mark.asyncio
    async def test_total_trades_always_one(self, mock_hist_rising):
        # Buy-and-hold makes exactly one trade: the initial purchase.
        result = await backtest_service.backtest_buy_and_hold(
            "AAPL", days=365, initial_capital=10000
        )
        assert result["total_trades"] == 1

    @pytest.mark.asyncio
    async def test_rising_prices_positive_return(self, mock_hist_rising):
        result = await backtest_service.backtest_buy_and_hold(
            "AAPL", days=365, initial_capital=10000
        )
        assert result["total_return"] > 0

    @pytest.mark.asyncio
    async def test_known_return_value(self, monkeypatch):
        # 100 -> 200 doubles the capital: +100% total return, 20000 final equity.
        self._install_fake_fetch(monkeypatch, [100.0, 200.0])
        result = await backtest_service.backtest_buy_and_hold(
            "AAPL", days=365, initial_capital=10000
        )
        assert result["total_return"] == pytest.approx(1.0, abs=1e-6)
        assert result["equity_curve"][-1] == pytest.approx(20000.0, abs=1e-6)

    @pytest.mark.asyncio
    async def test_raises_value_error_on_no_data(self, monkeypatch):
        async def fake_fetch(symbol, days, **kwargs):
            return None

        monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
        with pytest.raises(ValueError, match="No historical data"):
            await backtest_service.backtest_buy_and_hold("AAPL", days=365, initial_capital=10000)

    @pytest.mark.asyncio
    async def test_flat_prices_zero_return(self, monkeypatch):
        self._install_fake_fetch(monkeypatch, _flat_prices(50).tolist())
        result = await backtest_service.backtest_buy_and_hold(
            "AAPL", days=365, initial_capital=10000
        )
        assert result["total_return"] == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# backtest_momentum tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBacktestMomentum:
    """Service-level tests for the multi-symbol momentum backtest.

    Refactor: the per-symbol FakeBar/FakeResult scaffolding is extracted into
    `_install_fake_fetch`, and the 7-argument call repeated in every test is
    centralized in `_run`.
    """

    @staticmethod
    def _install_fake_fetch(monkeypatch, price_map):
        """Patch fetch_historical with a stub serving per-symbol close prices."""

        class _Bar:
            def __init__(self, close):
                self.close = close

        class _Result:
            def __init__(self, prices):
                self.results = [_Bar(p) for p in prices]

        async def _fetch(symbol, days, **kwargs):
            return _Result(price_map[symbol])

        monkeypatch.setattr(backtest_service, "fetch_historical", _fetch)

    @pytest.fixture
    def mock_multi_hist(self, monkeypatch):
        """Three symbols with distinct return profiles: strong/mild riser, flat."""
        self._install_fake_fetch(monkeypatch, {
            "AAPL": _rising_prices(200, start=100.0, step=2.0).tolist(),
            "MSFT": _rising_prices(200, start=100.0, step=0.5).tolist(),
            "GOOGL": _flat_prices(200, price=150.0).tolist(),
        })

    @staticmethod
    async def _run():
        """Run the momentum backtest with the canonical test parameters."""
        return await backtest_service.backtest_momentum(
            symbols=["AAPL", "MSFT", "GOOGL"],
            lookback=20,
            top_n=2,
            rebalance_days=30,
            days=365,
            initial_capital=10000,
        )

    @pytest.mark.asyncio
    async def test_returns_all_required_keys(self, mock_multi_hist):
        result = await self._run()
        required = {
            "total_return",
            "annualized_return",
            "sharpe_ratio",
            "max_drawdown",
            "win_rate",
            "total_trades",
            "equity_curve",
            "allocation_history",
        }
        assert required.issubset(result.keys())

    @pytest.mark.asyncio
    async def test_allocation_history_is_list(self, mock_multi_hist):
        result = await self._run()
        assert isinstance(result["allocation_history"], list)

    @pytest.mark.asyncio
    async def test_top_n_respected_in_allocations(self, mock_multi_hist):
        # With top_n=2 no rebalance may ever hold more than two symbols.
        result = await self._run()
        for entry in result["allocation_history"]:
            assert len(entry["symbols"]) <= 2

    @pytest.mark.asyncio
    async def test_raises_value_error_on_no_data(self, monkeypatch):
        async def fake_fetch(symbol, days, **kwargs):
            return None

        monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
        with pytest.raises(ValueError, match="No price data"):
            await backtest_service.backtest_momentum(
                symbols=["AAPL", "MSFT"],
                lookback=20,
                top_n=1,
                rebalance_days=30,
                days=365,
                initial_capital=10000,
            )

    @pytest.mark.asyncio
    async def test_equity_curve_max_20_points(self, mock_multi_hist):
        result = await self._run()
        assert len(result["equity_curve"]) <= 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge case: insufficient data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Every strategy must reject price series too short for its indicators.

    Refactor: three verbatim copies of the FakeBar/FakeResult/fake_fetch
    scaffolding are collapsed into one `_install_fake_fetch` helper.
    """

    @staticmethod
    def _install_fake_fetch(monkeypatch, prices):
        """Patch backtest_service.fetch_historical with a stub serving *prices*."""

        class _Bar:
            def __init__(self, close):
                self.close = close

        class _Result:
            def __init__(self):
                self.results = [_Bar(p) for p in prices]

        async def _fetch(symbol, days, **kwargs):
            return _Result()

        monkeypatch.setattr(backtest_service, "fetch_historical", _fetch)

    @pytest.mark.asyncio
    async def test_sma_crossover_insufficient_bars_raises(self, monkeypatch):
        """Fewer bars than long_window must raise ValueError."""
        self._install_fake_fetch(monkeypatch, [100.0, 101.0, 102.0])  # only 3 bars
        with pytest.raises(ValueError, match="Insufficient data"):
            await backtest_service.backtest_sma_crossover(
                "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
            )

    @pytest.mark.asyncio
    async def test_rsi_insufficient_bars_raises(self, monkeypatch):
        """Fewer bars than the RSI period must raise ValueError."""
        self._install_fake_fetch(monkeypatch, [100.0, 101.0])
        with pytest.raises(ValueError, match="Insufficient data"):
            await backtest_service.backtest_rsi(
                "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
            )

    @pytest.mark.asyncio
    async def test_buy_and_hold_single_bar_raises(self, monkeypatch):
        """A single bar cannot produce a return; must raise ValueError."""
        self._install_fake_fetch(monkeypatch, [100.0])
        with pytest.raises(ValueError, match="Insufficient data"):
            await backtest_service.backtest_buy_and_hold(
                "AAPL", days=365, initial_capital=10000
            )
|
||||
187
tests/test_congress_service.py
Normal file
187
tests/test_congress_service.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Tests for congress trading service (TDD - RED phase first)."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- get_congress_trades ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_congress_trades_happy_path():
    """A successful OBB call passes the trade dicts straight through."""
    import congress_service

    sample_trades = [
        {
            "representative": "Nancy Pelosi",
            "ticker": "NVDA",
            "transaction_date": "2024-01-15",
            "transaction_type": "Purchase",
            "amount": "$1,000,001-$5,000,000",
        }
    ]

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=sample_trades):
        trades = await congress_service.get_congress_trades()

    assert isinstance(trades, list)
    assert len(trades) == 1
    assert trades[0]["representative"] == "Nancy Pelosi"


@pytest.mark.asyncio
async def test_get_congress_trades_returns_empty_when_fn_not_available():
    """When the OBB congress function is missing, the service returns []."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=None):
        trades = await congress_service.get_congress_trades()

    assert trades == []


@pytest.mark.asyncio
async def test_get_congress_trades_returns_empty_on_all_provider_failures():
    """When every provider fails (_try_obb_call -> None), the service returns []."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=None):
        trades = await congress_service.get_congress_trades()

    assert trades == []


@pytest.mark.asyncio
async def test_get_congress_trades_empty_list_result():
    """An empty provider result list is returned unchanged."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=[]):
        trades = await congress_service.get_congress_trades()

    assert trades == []
|
||||
|
||||
|
||||
# --- _get_congress_fn ---
|
||||
|
||||
|
||||
def test_get_congress_fn_returns_none_when_attribute_missing():
    """Gracefully returns None when obb.regulators.government_us is absent."""
    import congress_service

    bare_obb = MagicMock(spec=[])  # a mock exposing no attributes at all
    with patch.object(congress_service, "obb", bare_obb):
        assert congress_service._get_congress_fn() is None


def test_get_congress_fn_returns_callable_when_available():
    """Returns the congress_trading callable when the attribute chain exists."""
    import congress_service

    trading_fn = MagicMock()
    fake_obb = MagicMock()
    fake_obb.regulators.government_us.congress_trading = trading_fn

    with patch.object(congress_service, "obb", fake_obb):
        assert congress_service._get_congress_fn() is trading_fn
|
||||
|
||||
|
||||
# --- _try_obb_call ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_try_obb_call_returns_list_on_success():
    """A successful OBB result is converted to a plain list via to_list."""
    import congress_service

    converted = [{"ticker": "AAPL"}]

    with patch.object(congress_service, "to_list", return_value=converted), \
            patch("congress_service.asyncio.to_thread", new_callable=AsyncMock,
                  return_value=MagicMock()):
        outcome = await congress_service._try_obb_call(MagicMock())

    assert outcome == converted


@pytest.mark.asyncio
async def test_try_obb_call_returns_none_on_exception():
    """Exceptions from asyncio.to_thread are swallowed and None is returned."""
    import congress_service

    with patch("congress_service.asyncio.to_thread", new_callable=AsyncMock,
               side_effect=Exception("fail")):
        outcome = await congress_service._try_obb_call(MagicMock())

    assert outcome is None
|
||||
|
||||
|
||||
# --- search_congress_bills ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_congress_bills_happy_path():
    """A successful OBB call passes the bill dicts straight through."""
    import congress_service

    sample_bills = [
        {"title": "Infrastructure Investment and Jobs Act", "bill_id": "HR3684"},
        {"title": "Inflation Reduction Act", "bill_id": "HR5376"},
    ]

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=sample_bills):
        bills = await congress_service.search_congress_bills("infrastructure")

    assert isinstance(bills, list)
    assert len(bills) == 2
    assert bills[0]["bill_id"] == "HR3684"


@pytest.mark.asyncio
async def test_search_congress_bills_returns_empty_when_fn_not_available():
    """When the OBB function is missing, bill search returns []."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=None):
        bills = await congress_service.search_congress_bills("taxes")

    assert bills == []


@pytest.mark.asyncio
async def test_search_congress_bills_returns_empty_on_failure():
    """When every provider fails, bill search returns []."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=None):
        bills = await congress_service.search_congress_bills("taxes")

    assert bills == []


@pytest.mark.asyncio
async def test_search_congress_bills_empty_results():
    """An empty provider result list is returned unchanged."""
    import congress_service

    with patch.object(congress_service, "_get_congress_fn", return_value=MagicMock()), \
            patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock,
                         return_value=[]):
        bills = await congress_service.search_congress_bills("nonexistent")

    assert bills == []
|
||||
655
tests/test_defi_service.py
Normal file
655
tests/test_defi_service.py
Normal file
@@ -0,0 +1,655 @@
|
||||
"""Tests for defi_service.py - DefiLlama API integration.
|
||||
|
||||
TDD: these tests are written before implementation.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from defi_service import (
|
||||
get_chain_tvls,
|
||||
get_dex_volumes,
|
||||
get_protocol_fees,
|
||||
get_protocol_tvl,
|
||||
get_stablecoins,
|
||||
get_top_protocols,
|
||||
get_yield_pools,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Two realistic entries from DefiLlama's /protocols endpoint: a lending
# protocol and a DEX, with both positive and negative daily changes.
SAMPLE_PROTOCOLS = [
    {
        "name": "Aave",
        "symbol": "AAVE",
        "tvl": 10_000_000_000.0,
        "chain": "Ethereum",
        "chains": ["Ethereum", "Polygon"],
        "category": "Lending",
        "change_1d": 0.5,
        "change_7d": -1.2,
    },
    {
        "name": "Uniswap",
        "symbol": "UNI",
        "tvl": 8_000_000_000.0,
        "chain": "Ethereum",
        "chains": ["Ethereum"],
        "category": "DEX",
        "change_1d": 1.0,
        "change_7d": 2.0,
    },
]

# Chain-level TVL rows as served by /v2/chains.
SAMPLE_CHAINS = [
    {"name": "Ethereum", "tvl": 50_000_000_000.0, "tokenSymbol": "ETH"},
    {"name": "BSC", "tvl": 5_000_000_000.0, "tokenSymbol": "BNB"},
]

# Yield pool rows in the shape of yields.llama.fi/pools "data" entries;
# apy == apyBase + apyReward for each pool.
SAMPLE_POOLS = [
    {
        "pool": "0xabcd",
        "chain": "Ethereum",
        "project": "aave-v3",
        "symbol": "USDC",
        "tvlUsd": 1_000_000_000.0,
        "apy": 3.5,
        "apyBase": 3.0,
        "apyReward": 0.5,
    },
    {
        "pool": "0x1234",
        "chain": "Polygon",
        "project": "curve",
        "symbol": "DAI",
        "tvlUsd": 500_000_000.0,
        "apy": 4.2,
        "apyBase": 2.5,
        "apyReward": 1.7,
    },
]

# Stablecoins response: note "circulating" is a nested dict keyed by peg type,
# which get_stablecoins is expected to flatten to a single number.
SAMPLE_STABLECOINS_RESPONSE = {
    "peggedAssets": [
        {
            "name": "Tether",
            "symbol": "USDT",
            "pegType": "peggedUSD",
            "circulating": {"peggedUSD": 100_000_000_000.0},
            "price": 1.0,
        },
        {
            "name": "USD Coin",
            "symbol": "USDC",
            "pegType": "peggedUSD",
            "circulating": {"peggedUSD": 40_000_000_000.0},
            "price": 1.0,
        },
    ]
}

# DEX overview response: aggregate 24h/7d totals plus per-protocol volumes.
SAMPLE_DEX_RESPONSE = {
    "total24h": 5_000_000_000.0,
    "total7d": 30_000_000_000.0,
    "protocols": [
        {"name": "Uniswap", "total24h": 2_000_000_000.0},
        {"name": "Curve", "total24h": 500_000_000.0},
    ],
}

# Fees overview response: per-protocol 24h fees and revenue.
SAMPLE_FEES_RESPONSE = {
    "protocols": [
        {"name": "Uniswap", "total24h": 1_000_000.0, "revenue24h": 500_000.0},
        {"name": "Aave", "total24h": 800_000.0, "revenue24h": 800_000.0},
    ],
}
|
||||
|
||||
|
||||
def _make_mock_client(json_return):
|
||||
"""Build a fully configured AsyncMock for httpx.AsyncClient context manager.
|
||||
|
||||
Uses MagicMock for the response object so that resp.json() (a sync method
|
||||
in httpx) returns the value directly rather than a coroutine.
|
||||
"""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = json_return
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_resp)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_top_protocols
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_top_20(mock_client_cls):
    """Only the first 20 of 30 protocols are returned by default."""
    payload = [
        {
            "name": f"Protocol{i}",
            "symbol": f"P{i}",
            "tvl": float(100 - i),
            "chain": "Ethereum",
            "chains": ["Ethereum"],
            "category": "Lending",
            "change_1d": 0.1,
            "change_7d": 0.2,
        }
        for i in range(30)
    ]
    mock_client_cls.return_value = _make_mock_client(payload)

    protocols = await get_top_protocols()

    assert len(protocols) == 20


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_correct_fields(mock_client_cls):
    """Every field of the first sample protocol survives the mapping."""
    mock_client_cls.return_value = _make_mock_client(SAMPLE_PROTOCOLS)

    protocols = await get_top_protocols()

    assert len(protocols) == 2
    aave = protocols[0]
    assert aave["name"] == "Aave"
    assert aave["symbol"] == "AAVE"
    assert aave["tvl"] == 10_000_000_000.0
    assert aave["chain"] == "Ethereum"
    assert aave["chains"] == ["Ethereum", "Polygon"]
    assert aave["category"] == "Lending"
    assert aave["change_1d"] == 0.5
    assert aave["change_7d"] == -1.2


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_respects_custom_limit(mock_client_cls):
    """An explicit limit overrides the default of 20."""
    payload = [
        {"name": f"P{i}", "symbol": f"S{i}", "tvl": float(i), "chain": "ETH",
         "chains": [], "category": "DEX", "change_1d": 0.0, "change_7d": 0.0}
        for i in range(25)
    ]
    mock_client_cls.return_value = _make_mock_client(payload)

    assert len(await get_top_protocols(limit=5)) == 5


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_handles_missing_fields(mock_client_cls):
    """Absent optional fields come back as None rather than raising."""
    mock_client_cls.return_value = _make_mock_client([{"name": "Sparse"}])

    protocols = await get_top_protocols()

    assert len(protocols) == 1
    assert protocols[0]["name"] == "Sparse"
    assert protocols[0]["tvl"] is None
    assert protocols[0]["symbol"] is None


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_empty_on_http_error(mock_client_cls):
    """An HTTP failure degrades to an empty list instead of propagating."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("HTTP 500")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_top_protocols() == []


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_calls_correct_url(mock_client_cls):
    """The DefiLlama /protocols endpoint is hit exactly once."""
    client = _make_mock_client([])
    mock_client_cls.return_value = client

    await get_top_protocols()

    client.get.assert_called_once_with("https://api.llama.fi/protocols")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_chain_tvls
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_returns_all_chains(mock_client_cls):
    """Every chain row is returned with name/tvl/tokenSymbol intact."""
    mock_client_cls.return_value = _make_mock_client(SAMPLE_CHAINS)

    chains = await get_chain_tvls()

    assert len(chains) == 2
    assert chains[0]["name"] == "Ethereum"
    assert chains[0]["tvl"] == 50_000_000_000.0
    assert chains[0]["tokenSymbol"] == "ETH"


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_handles_missing_token_symbol(mock_client_cls):
    """A chain without a native token symbol maps it to None."""
    mock_client_cls.return_value = _make_mock_client([{"name": "SomeChain", "tvl": 100.0}])

    chains = await get_chain_tvls()

    assert chains[0]["tokenSymbol"] is None


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_returns_empty_on_error(mock_client_cls):
    """Network errors degrade to an empty list."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("timeout")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_chain_tvls() == []


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_calls_correct_url(mock_client_cls):
    """The v2 chains endpoint is queried exactly once."""
    client = _make_mock_client([])
    mock_client_cls.return_value = client

    await get_chain_tvls()

    client.get.assert_called_once_with("https://api.llama.fi/v2/chains")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_protocol_tvl
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_returns_numeric_value(mock_client_cls):
    """The bare numeric TVL from the API is returned unchanged."""
    mock_client_cls.return_value = _make_mock_client(10_000_000_000.0)

    assert await get_protocol_tvl("aave") == 10_000_000_000.0


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_calls_correct_url(mock_client_cls):
    """The protocol slug is interpolated into the /tvl/{slug} URL."""
    client = _make_mock_client(1234.0)
    mock_client_cls.return_value = client

    await get_protocol_tvl("uniswap")

    client.get.assert_called_once_with("https://api.llama.fi/tvl/uniswap")


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_returns_none_on_error(mock_client_cls):
    """Unknown protocols (HTTP error) yield None rather than an exception."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("404 not found")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_protocol_tvl("nonexistent") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_yield_pools
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_top_20_by_tvl(mock_client_cls):
    """With 30 pools available, only the top 20 by TVL are returned."""
    payload = {
        "data": [
            {
                "pool": f"pool{i}",
                "chain": "Ethereum",
                "project": "aave",
                "symbol": "USDC",
                "tvlUsd": float(1000 - i),
                "apy": 3.0,
                "apyBase": 2.5,
                "apyReward": 0.5,
            }
            for i in range(30)
        ]
    }
    mock_client_cls.return_value = _make_mock_client(payload)

    pools = await get_yield_pools()

    assert len(pools) == 20


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_correct_fields(mock_client_cls):
    """All pool fields are carried through from the API response."""
    mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})

    pools = await get_yield_pools()

    assert len(pools) == 2
    top = pools[0]
    assert top["pool"] == "0xabcd"
    assert top["chain"] == "Ethereum"
    assert top["project"] == "aave-v3"
    assert top["symbol"] == "USDC"
    assert top["tvlUsd"] == 1_000_000_000.0
    assert top["apy"] == 3.5
    assert top["apyBase"] == 3.0
    assert top["apyReward"] == 0.5


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_chain(mock_client_cls):
    """The chain filter drops pools on other chains."""
    mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})

    pools = await get_yield_pools(chain="Ethereum")

    assert len(pools) == 1
    assert pools[0]["chain"] == "Ethereum"


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_project(mock_client_cls):
    """The project filter drops pools from other projects."""
    mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})

    pools = await get_yield_pools(project="curve")

    assert len(pools) == 1
    assert pools[0]["project"] == "curve"


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_chain_and_project(mock_client_cls):
    """Chain and project filters are ANDed together."""
    both_match = {
        "pool": "0xzzzz",
        "chain": "Ethereum",
        "project": "curve",
        "symbol": "USDT",
        "tvlUsd": 200_000_000.0,
        "apy": 2.0,
        "apyBase": 2.0,
        "apyReward": 0.0,
    }
    mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS + [both_match]})

    pools = await get_yield_pools(chain="Ethereum", project="curve")

    assert len(pools) == 1
    assert pools[0]["pool"] == "0xzzzz"


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_empty_on_error(mock_client_cls):
    """HTTP failures degrade to an empty list."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("Connection error")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_yield_pools() == []


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_calls_correct_url(mock_client_cls):
    """The yields subdomain endpoint is queried exactly once."""
    client = _make_mock_client({"data": []})
    mock_client_cls.return_value = client

    await get_yield_pools()

    client.get.assert_called_once_with("https://yields.llama.fi/pools")


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_sorts_by_tvl_descending(mock_client_cls):
    """Pools come back ordered by TVL, largest first."""
    shuffled = [
        {
            "pool": "low", "chain": "Ethereum", "project": "aave",
            "symbol": "DAI", "tvlUsd": 100.0, "apy": 1.0,
            "apyBase": 1.0, "apyReward": 0.0,
        },
        {
            "pool": "high", "chain": "Ethereum", "project": "aave",
            "symbol": "USDC", "tvlUsd": 9000.0, "apy": 2.0,
            "apyBase": 2.0, "apyReward": 0.0,
        },
    ]
    mock_client_cls.return_value = _make_mock_client({"data": shuffled})

    pools = await get_yield_pools()

    assert pools[0]["pool"] == "high"
    assert pools[1]["pool"] == "low"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_stablecoins
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_correct_fields(mock_client_cls):
    """The nested circulating dict is flattened to a single number."""
    mock_client_cls.return_value = _make_mock_client(SAMPLE_STABLECOINS_RESPONSE)

    coins = await get_stablecoins()

    assert len(coins) == 2
    tether = coins[0]
    assert tether["name"] == "Tether"
    assert tether["symbol"] == "USDT"
    assert tether["pegType"] == "peggedUSD"
    assert tether["circulating"] == 100_000_000_000.0
    assert tether["price"] == 1.0


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_top_20(mock_client_cls):
    """Of 25 pegged assets, only the top 20 are returned."""
    pegged = [
        {
            "name": f"Stable{i}",
            "symbol": f"S{i}",
            "pegType": "peggedUSD",
            "circulating": {"peggedUSD": float(1000 - i)},
            "price": 1.0,
        }
        for i in range(25)
    ]
    mock_client_cls.return_value = _make_mock_client({"peggedAssets": pegged})

    assert len(await get_stablecoins()) == 20


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_handles_missing_circulating(mock_client_cls):
    """An asset with no circulating dict maps it to None."""
    payload = {"peggedAssets": [{"name": "NoCirc", "symbol": "NC", "pegType": "peggedUSD", "price": 1.0}]}
    mock_client_cls.return_value = _make_mock_client(payload)

    coins = await get_stablecoins()

    assert coins[0]["circulating"] is None


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_empty_on_error(mock_client_cls):
    """HTTP failures degrade to an empty list."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("timeout")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_stablecoins() == []


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_calls_correct_url(mock_client_cls):
    """The stablecoins subdomain endpoint is queried exactly once."""
    client = _make_mock_client({"peggedAssets": []})
    mock_client_cls.return_value = client

    await get_stablecoins()

    client.get.assert_called_once_with("https://stablecoins.llama.fi/stablecoins")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_dex_volumes
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_returns_correct_structure(mock_client_cls):
    """Totals are renamed and per-protocol 24h volume is extracted."""
    mock_client_cls.return_value = _make_mock_client(SAMPLE_DEX_RESPONSE)

    volumes = await get_dex_volumes()

    assert volumes["totalVolume24h"] == 5_000_000_000.0
    assert volumes["totalVolume7d"] == 30_000_000_000.0
    assert len(volumes["protocols"]) == 2
    assert volumes["protocols"][0]["name"] == "Uniswap"
    assert volumes["protocols"][0]["volume24h"] == 2_000_000_000.0


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_returns_none_on_error(mock_client_cls):
    """An HTTP failure yields None (this endpoint returns a dict, not a list)."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("server error")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_dex_volumes() is None


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_calls_correct_url(mock_client_cls):
    """The DEX overview endpoint is queried exactly once."""
    client = _make_mock_client({"total24h": 0, "total7d": 0, "protocols": []})
    mock_client_cls.return_value = client

    await get_dex_volumes()

    client.get.assert_called_once_with("https://api.llama.fi/overview/dexs")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# get_protocol_fees
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_returns_correct_structure(mock_client_cls):
    """total24h/revenue24h are exposed as fees24h/revenue24h per protocol."""
    mock_client_cls.return_value = _make_mock_client(SAMPLE_FEES_RESPONSE)

    fees = await get_protocol_fees()

    assert len(fees) == 2
    uniswap = fees[0]
    assert uniswap["name"] == "Uniswap"
    assert uniswap["fees24h"] == 1_000_000.0
    assert uniswap["revenue24h"] == 500_000.0


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_returns_empty_on_error(mock_client_cls):
    """HTTP failures degrade to an empty list."""
    bad_resp = MagicMock()
    bad_resp.raise_for_status.side_effect = Exception("connection reset")
    bad_client = AsyncMock()
    bad_client.get = AsyncMock(return_value=bad_resp)
    bad_client.__aenter__ = AsyncMock(return_value=bad_client)
    bad_client.__aexit__ = AsyncMock(return_value=None)
    mock_client_cls.return_value = bad_client

    assert await get_protocol_fees() == []


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_calls_correct_url(mock_client_cls):
    """The fees overview endpoint is queried exactly once."""
    client = _make_mock_client({"protocols": []})
    mock_client_cls.return_value = client

    await get_protocol_fees()

    client.get.assert_called_once_with("https://api.llama.fi/overview/fees")


@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_handles_missing_revenue(mock_client_cls):
    """A protocol with no revenue24h field maps it to None."""
    payload = {"protocols": [{"name": "SomeProtocol", "total24h": 500_000.0}]}
    mock_client_cls.return_value = _make_mock_client(payload)

    fees = await get_protocol_fees()

    assert fees[0]["revenue24h"] is None
|
||||
368
tests/test_finnhub_service_social.py
Normal file
368
tests/test_finnhub_service_social.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Tests for the new social sentiment functions in finnhub_service."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import finnhub_service
|
||||
import reddit_service
|
||||
|
||||
|
||||
# --- get_social_sentiment ---


def _client_with_response(resp):
    """Wrap *resp* in an AsyncMock httpx-style async-context-manager client."""
    client = AsyncMock()
    client.get.return_value = resp
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=False)
    return client


@pytest.mark.asyncio
@patch("finnhub_service.settings")
async def test_social_sentiment_not_configured(mock_settings):
    """Without an API key the call reports configured=False and names the env var."""
    mock_settings.finnhub_api_key = ""

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["configured"] is False
    assert "INVEST_API_FINNHUB_API_KEY" in result["message"]


@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_premium_required_403(mock_client_cls, mock_settings):
    """A 403 from Finnhub marks the endpoint as premium-only with empty data."""
    mock_settings.finnhub_api_key = "test_key"
    mock_client_cls.return_value = _client_with_response(MagicMock(status_code=403))

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["configured"] is True
    assert result["premium_required"] is True
    assert result["reddit"] == []
    assert result["twitter"] == []


@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_premium_required_401(mock_client_cls, mock_settings):
    """A 401 is treated the same as a 403: premium required."""
    mock_settings.finnhub_api_key = "test_key"
    mock_client_cls.return_value = _client_with_response(MagicMock(status_code=401))

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["premium_required"] is True


@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_success_with_data(mock_client_cls, mock_settings):
    """Reddit/Twitter entries are passed through and summarised per platform."""
    mock_settings.finnhub_api_key = "test_key"
    ok_resp = MagicMock(status_code=200)
    ok_resp.raise_for_status = MagicMock()
    ok_resp.json.return_value = {
        "reddit": [
            {"mention": 50, "positiveScore": 30, "negativeScore": 10, "score": 0.5},
            {"mention": 30, "positiveScore": 20, "negativeScore": 5, "score": 0.6},
        ],
        "twitter": [
            {"mention": 100, "positiveScore": 60, "negativeScore": 20, "score": 0.4},
        ],
    }
    mock_client_cls.return_value = _client_with_response(ok_resp)

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["configured"] is True
    assert result["symbol"] == "AAPL"
    assert result["reddit_summary"]["total_mentions"] == 80
    assert result["reddit_summary"]["data_points"] == 2
    assert result["twitter_summary"]["total_mentions"] == 100
    assert len(result["reddit"]) == 2
    assert len(result["twitter"]) == 1


@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_empty_lists(mock_client_cls, mock_settings):
    """Empty platform lists produce None summaries."""
    mock_settings.finnhub_api_key = "test_key"
    ok_resp = MagicMock(status_code=200)
    ok_resp.raise_for_status = MagicMock()
    ok_resp.json.return_value = {"reddit": [], "twitter": []}
    mock_client_cls.return_value = _client_with_response(ok_resp)

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["configured"] is True
    assert result["reddit_summary"] is None
    assert result["twitter_summary"] is None


@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_non_dict_response(mock_client_cls, mock_settings):
    """A non-dict JSON body is treated as no data for both platforms."""
    mock_settings.finnhub_api_key = "test_key"
    ok_resp = MagicMock(status_code=200)
    ok_resp.raise_for_status = MagicMock()
    ok_resp.json.return_value = "unexpected string response"
    mock_client_cls.return_value = _client_with_response(ok_resp)

    result = await finnhub_service.get_social_sentiment("AAPL")

    assert result["reddit"] == []
    assert result["twitter"] == []
|
||||
|
||||
|
||||
# --- _summarize_social ---


def test_summarize_social_empty():
    """No data points produce an empty summary dict."""
    assert finnhub_service._summarize_social([]) == {}


def test_summarize_social_single_entry():
    """A lone entry's fields map straight onto the totals."""
    summary = finnhub_service._summarize_social(
        [{"mention": 10, "positiveScore": 7, "negativeScore": 2, "score": 0.5}]
    )
    assert summary["total_mentions"] == 10
    assert summary["total_positive"] == 7
    assert summary["total_negative"] == 2
    assert summary["avg_score"] == 0.5
    assert summary["data_points"] == 1


def test_summarize_social_multiple_entries():
    """Mentions and scores sum across entries; avg_score is the mean."""
    summary = finnhub_service._summarize_social([
        {"mention": 100, "positiveScore": 60, "negativeScore": 20, "score": 0.4},
        {"mention": 50, "positiveScore": 30, "negativeScore": 10, "score": 0.6},
    ])
    assert summary["total_mentions"] == 150
    assert summary["total_positive"] == 90
    assert summary["total_negative"] == 30
    assert summary["avg_score"] == 0.5
    assert summary["data_points"] == 2


def test_summarize_social_missing_fields():
    """Absent positive/negative fields default to zero."""
    summary = finnhub_service._summarize_social([{"mention": 5}])
    assert summary["total_mentions"] == 5
    assert summary["total_positive"] == 0
    assert summary["total_negative"] == 0
|
||||
|
||||
|
||||
# --- get_reddit_sentiment ---


def _reddit_client(payload):
    """AsyncMock httpx-style client whose GET returns ApeWisdom-shaped *payload*."""
    resp = MagicMock()
    resp.raise_for_status = MagicMock()
    resp.json.return_value = payload
    client = AsyncMock()
    client.get.return_value = resp
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=False)
    return client


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_symbol_found(mock_client_cls):
    """A ticker present in the results yields rank/mention stats and change %."""
    mock_client_cls.return_value = _reddit_client({
        "results": [
            {"ticker": "AAPL", "name": "Apple Inc", "rank": 3, "mentions": 150, "mentions_24h_ago": 100, "upvotes": 500, "rank_24h_ago": 5},
            {"ticker": "TSLA", "name": "Tesla", "rank": 1, "mentions": 300, "mentions_24h_ago": 280, "upvotes": 900, "rank_24h_ago": 2},
        ]
    })

    result = await reddit_service.get_reddit_sentiment("AAPL")

    assert result["found"] is True
    assert result["symbol"] == "AAPL"
    assert result["rank"] == 3
    assert result["mentions_24h"] == 150
    assert result["mentions_24h_ago"] == 100
    assert result["mentions_change_pct"] == 50.0


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_symbol_not_found(mock_client_cls):
    """A ticker absent from the results reports found=False with a message."""
    mock_client_cls.return_value = _reddit_client({
        "results": [
            {"ticker": "TSLA", "rank": 1, "mentions": 300, "mentions_24h_ago": 280}
        ]
    })

    result = await reddit_service.get_reddit_sentiment("AAPL")

    assert result["found"] is False
    assert result["symbol"] == "AAPL"
    assert "not in Reddit" in result["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("reddit_service.httpx.AsyncClient")
|
||||
async def test_reddit_sentiment_zero_mentions_prev(mock_client_cls):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
mock_resp.json.return_value = {
|
||||
"results": [
|
||||
{"ticker": "AAPL", "rank": 1, "mentions": 50, "mentions_24h_ago": 0, "upvotes": 200, "rank_24h_ago": None}
|
||||
]
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_resp
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
result = await reddit_service.get_reddit_sentiment("AAPL")
|
||||
assert result["found"] is True
|
||||
assert result["mentions_change_pct"] is None # division by zero handled
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_api_failure(mock_client_cls):
    """A transport-level failure is reported via an 'error' key, not raised."""
    mock_client = AsyncMock()
    mock_client.get.side_effect = Exception("Connection error")
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_sentiment("AAPL")
    assert result["symbol"] == "AAPL"
    assert "error" in result


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_case_insensitive(mock_client_cls):
    """Ticker matching ignores case (API returns lowercase 'aapl')."""
    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.json.return_value = {
        "results": [
            {"ticker": "aapl", "rank": 1, "mentions": 100, "mentions_24h_ago": 80, "upvotes": 400, "rank_24h_ago": 2}
        ]
    }

    mock_client = AsyncMock()
    mock_client.get.return_value = mock_resp
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_sentiment("AAPL")
    assert result["found"] is True
|
||||
|
||||
|
||||
# --- get_reddit_trending ---


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_happy_path(mock_client_cls):
    """Trending list preserves API rank order and per-entry fields."""
    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.json.return_value = {
        "results": [
            {"ticker": "TSLA", "name": "Tesla", "rank": 1, "mentions": 500, "upvotes": 1000, "rank_24h_ago": 2, "mentions_24h_ago": 400},
            {"ticker": "AAPL", "name": "Apple", "rank": 2, "mentions": 300, "upvotes": 700, "rank_24h_ago": 1, "mentions_24h_ago": 350},
            {"ticker": "GME", "name": "GameStop", "rank": 3, "mentions": 200, "upvotes": 500, "rank_24h_ago": 3, "mentions_24h_ago": 180},
        ]
    }

    # Stub `async with httpx.AsyncClient()` context-manager protocol.
    mock_client = AsyncMock()
    mock_client.get.return_value = mock_resp
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_trending()
    assert len(result) == 3
    assert result[0]["symbol"] == "TSLA"
    assert result[0]["rank"] == 1
    assert result[1]["symbol"] == "AAPL"
    assert "mentions_24h" in result[0]
    assert "upvotes" in result[0]


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_limits_to_25(mock_client_cls):
    """A 30-entry API payload is truncated to the top 25."""
    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.json.return_value = {
        "results": [
            {"ticker": f"SYM{i}", "rank": i + 1, "mentions": 100 - i, "upvotes": 50, "rank_24h_ago": i, "mentions_24h_ago": 80}
            for i in range(30)
        ]
    }

    mock_client = AsyncMock()
    mock_client.get.return_value = mock_resp
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_trending()
    assert len(result) == 25


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_empty_results(mock_client_cls):
    """An empty API result list maps to an empty trending list."""
    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.json.return_value = {"results": []}

    mock_client = AsyncMock()
    mock_client.get.return_value = mock_resp
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_trending()
    assert result == []


@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_api_failure(mock_client_cls):
    """A transport failure degrades to an empty list rather than raising."""
    mock_client = AsyncMock()
    mock_client.get.side_effect = Exception("ApeWisdom down")
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client_cls.return_value = mock_client

    result = await reddit_service.get_reddit_trending()
    assert result == []
|
||||
559
tests/test_portfolio_service.py
Normal file
559
tests/test_portfolio_service.py
Normal file
@@ -0,0 +1,559 @@
|
||||
"""Tests for portfolio optimization service (TDD - RED phase first)."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# --- HRP Optimization ---


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_happy_path(mock_fetch):
    """HRP returns weights that sum to ~1.0 for valid symbols."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame(
        {
            "AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
            "MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
            "GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
        }
    )

    outcome = await portfolio_service.optimize_hrp(["AAPL", "MSFT", "GOOGL"], days=365)

    assert outcome["method"] == "hrp"
    assert set(outcome["weights"].keys()) == {"AAPL", "MSFT", "GOOGL"}
    # Weights form a full allocation.
    assert abs(sum(outcome["weights"].values()) - 1.0) < 0.01


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_single_symbol(mock_fetch):
    """Single symbol gets weight of 1.0."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame({"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0]})

    outcome = await portfolio_service.optimize_hrp(["AAPL"], days=365)

    assert outcome["weights"]["AAPL"] == pytest.approx(1.0, abs=0.01)


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_no_data_raises(mock_fetch):
    """Raises ValueError when no price data is available."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame()

    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.optimize_hrp(["AAPL", "MSFT"], days=365)


@pytest.mark.asyncio
async def test_hrp_optimize_empty_symbols_raises():
    """Raises ValueError for empty symbol list."""
    import portfolio_service

    with pytest.raises(ValueError, match="symbols"):
        await portfolio_service.optimize_hrp([], days=365)
|
||||
|
||||
|
||||
# --- Correlation Matrix ---


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_matrix_happy_path(mock_fetch):
    """Correlation matrix has 1.0 on diagonal and valid shape."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame(
        {
            "AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
            "MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
            "GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
        }
    )

    outcome = await portfolio_service.compute_correlation(["AAPL", "MSFT", "GOOGL"], days=365)

    assert outcome["symbols"] == ["AAPL", "MSFT", "GOOGL"]
    grid = outcome["matrix"]
    assert len(grid) == 3
    assert len(grid[0]) == 3
    # Self-correlation on the diagonal must be 1.0.
    for idx in range(3):
        assert abs(grid[idx][idx] - 1.0) < 0.01


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_matrix_two_symbols(mock_fetch):
    """Two-symbol correlation is symmetric."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame(
        {
            "AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
            "MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
        }
    )

    outcome = await portfolio_service.compute_correlation(["AAPL", "MSFT"], days=365)

    grid = outcome["matrix"]
    # corr(A, B) == corr(B, A)
    assert abs(grid[0][1] - grid[1][0]) < 0.001


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_no_data_raises(mock_fetch):
    """Raises ValueError when no data is returned."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame()

    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.compute_correlation(["AAPL", "MSFT"], days=365)


@pytest.mark.asyncio
async def test_correlation_empty_symbols_raises():
    """Raises ValueError for empty symbol list."""
    import portfolio_service

    with pytest.raises(ValueError, match="symbols"):
        await portfolio_service.compute_correlation([], days=365)
|
||||
|
||||
|
||||
# --- Risk Parity ---


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_happy_path(mock_fetch):
    """Risk parity returns weights and risk_contributions summing to ~1.0."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame(
        {
            "AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
            "MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
            "GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
        }
    )

    outcome = await portfolio_service.compute_risk_parity(["AAPL", "MSFT", "GOOGL"], days=365)

    assert outcome["method"] == "risk_parity"
    assert set(outcome["weights"].keys()) == {"AAPL", "MSFT", "GOOGL"}
    assert set(outcome["risk_contributions"].keys()) == {"AAPL", "MSFT", "GOOGL"}
    assert abs(sum(outcome["weights"].values()) - 1.0) < 0.01


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_single_symbol(mock_fetch):
    """Single symbol gets weight 1.0 and risk_contribution 1.0."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame({"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0]})

    outcome = await portfolio_service.compute_risk_parity(["AAPL"], days=365)

    assert outcome["weights"]["AAPL"] == pytest.approx(1.0, abs=0.01)
    assert outcome["risk_contributions"]["AAPL"] == pytest.approx(1.0, abs=0.01)


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_no_data_raises(mock_fetch):
    """Raises ValueError when no price data is available."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame()

    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.compute_risk_parity(["AAPL", "MSFT"], days=365)


@pytest.mark.asyncio
async def test_risk_parity_empty_symbols_raises():
    """Raises ValueError for empty symbol list."""
    import portfolio_service

    with pytest.raises(ValueError, match="symbols"):
        await portfolio_service.compute_risk_parity([], days=365)
|
||||
|
||||
|
||||
# --- fetch_historical_prices helper ---


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical")
async def test_fetch_historical_prices_returns_dataframe(mock_fetch_hist):
    """fetch_historical_prices assembles a price DataFrame from OBBject results."""
    import pandas as pd
    from unittest.mock import MagicMock

    # Simulate an OBBject-style object with .results rows carrying date/close.
    mock_result = MagicMock()
    mock_result.results = [
        MagicMock(date="2024-01-01", close=150.0),
        MagicMock(date="2024-01-02", close=151.0),
    ]
    mock_fetch_hist.return_value = mock_result

    import portfolio_service

    df = await portfolio_service.fetch_historical_prices(["AAPL"], days=30)

    assert isinstance(df, pd.DataFrame)
    assert "AAPL" in df.columns


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical")
async def test_fetch_historical_prices_skips_none(mock_fetch_hist):
    """fetch_historical_prices returns empty DataFrame when all fetches fail."""
    import pandas as pd

    # None signals a failed fetch for every symbol.
    mock_fetch_hist.return_value = None

    import portfolio_service

    df = await portfolio_service.fetch_historical_prices(["AAPL", "MSFT"], days=30)

    assert isinstance(df, pd.DataFrame)
    assert df.empty
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cluster_stocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_prices(symbols: list[str], n_days: int = 60):
|
||||
"""Build a deterministic price DataFrame with enough rows for t-SNE."""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
data = {}
|
||||
for sym in symbols:
|
||||
prices = 100.0 + np.cumsum(rng.normal(0, 1, n_days))
|
||||
data[sym] = prices
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_happy_path(mock_fetch):
    """cluster_stocks returns valid structure for 6 symbols."""
    import portfolio_service

    symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(symbols)

    result = await portfolio_service.cluster_stocks(symbols, days=180)

    assert result["method"] == "t-SNE + KMeans"
    assert result["days"] == 180
    assert set(result["symbols"]) == set(symbols)

    # Every symbol gets one 2-D coordinate with a cluster label, and all
    # values must be plain Python types (JSON-serializable).
    coords = result["coordinates"]
    assert len(coords) == len(symbols)
    for c in coords:
        assert "symbol" in c
        assert "x" in c
        assert "y" in c
        assert "cluster" in c
        assert isinstance(c["x"], float)
        assert isinstance(c["y"], float)
        assert isinstance(c["cluster"], int)

    # The cluster membership lists must partition the full symbol set.
    clusters = result["clusters"]
    assert isinstance(clusters, dict)
    all_in_clusters = []
    for members in clusters.values():
        all_in_clusters.extend(members)
    assert set(all_in_clusters) == set(symbols)

    assert "n_clusters" in result
    assert result["n_clusters"] >= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_custom_n_clusters(mock_fetch):
    """Custom n_clusters is respected in the output."""
    import portfolio_service

    universe = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(universe)

    outcome = await portfolio_service.cluster_stocks(universe, days=180, n_clusters=3)

    assert outcome["n_clusters"] == 3
    assert len(outcome["clusters"]) == 3


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_minimum_three_symbols(mock_fetch):
    """cluster_stocks works correctly with exactly 3 symbols (minimum)."""
    import portfolio_service

    universe = ["AAPL", "MSFT", "GOOGL"]
    mock_fetch.return_value = _make_prices(universe)

    outcome = await portfolio_service.cluster_stocks(universe, days=180)

    assert len(outcome["coordinates"]) == 3
    assert set(outcome["symbols"]) == set(universe)


@pytest.mark.asyncio
async def test_cluster_stocks_too_few_symbols_raises():
    """cluster_stocks raises ValueError when fewer than 3 symbols are provided."""
    import portfolio_service

    with pytest.raises(ValueError, match="at least 3"):
        await portfolio_service.cluster_stocks(["AAPL", "MSFT"], days=180)


@pytest.mark.asyncio
async def test_cluster_stocks_empty_symbols_raises():
    """cluster_stocks raises ValueError for empty symbol list."""
    import portfolio_service

    with pytest.raises(ValueError, match="at least 3"):
        await portfolio_service.cluster_stocks([], days=180)


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_no_data_raises(mock_fetch):
    """cluster_stocks raises ValueError when fetch returns empty DataFrame."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame()

    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.cluster_stocks(["AAPL", "MSFT", "GOOGL"], days=180)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_identical_returns_still_works(mock_fetch):
    """t-SNE should not raise even when all symbols have identical returns."""
    import pandas as pd
    import portfolio_service

    # All columns identical — edge case for t-SNE
    flat = pd.DataFrame(
        {
            "AAPL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
            "MSFT": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
            "GOOGL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
        }
    )

    mock_fetch.return_value = flat

    result = await portfolio_service.cluster_stocks(
        ["AAPL", "MSFT", "GOOGL"], days=180
    )

    assert len(result["coordinates"]) == 3


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_coordinates_are_floats(mock_fetch):
    """x and y coordinates must be Python floats (JSON-serializable)."""
    import portfolio_service

    symbols = ["AAPL", "MSFT", "GOOGL", "AMZN"]
    mock_fetch.return_value = _make_prices(symbols)

    result = await portfolio_service.cluster_stocks(symbols, days=180)

    # `type(...) is float` (not isinstance) deliberately rejects numpy
    # float64, which json.dumps cannot serialize.
    for c in result["coordinates"]:
        assert type(c["x"]) is float
        assert type(c["y"]) is float


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_clusters_key_is_str(mock_fetch):
    """clusters dict keys must be strings (JSON object keys)."""
    import portfolio_service

    symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(symbols)

    result = await portfolio_service.cluster_stocks(symbols, days=180)

    for key in result["clusters"]:
        assert isinstance(key, str), f"Expected str key, got {type(key)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# find_similar_stocks
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_happy_path(mock_fetch):
    """most_similar is sorted descending by correlation; least_similar ascending."""
    import portfolio_service

    symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(symbols)

    result = await portfolio_service.find_similar_stocks(
        "AAPL", ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, top_n=3
    )

    assert result["symbol"] == "AAPL"
    most = result["most_similar"]
    least = result["least_similar"]

    # top_n caps both lists.
    assert len(most) <= 3
    assert len(least) <= 3

    # most_similar sorted descending
    corrs_most = [e["correlation"] for e in most]
    assert corrs_most == sorted(corrs_most, reverse=True)

    # least_similar sorted ascending
    corrs_least = [e["correlation"] for e in least]
    assert corrs_least == sorted(corrs_least)

    # Each entry has symbol and correlation
    for entry in most + least:
        assert "symbol" in entry
        assert "correlation" in entry
        assert isinstance(entry["correlation"], float)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_top_n_larger_than_universe(mock_fetch):
    """top_n larger than universe size is handled gracefully (returns all)."""
    import portfolio_service

    mock_fetch.return_value = _make_prices(["AAPL", "MSFT", "GOOGL"])

    outcome = await portfolio_service.find_similar_stocks(
        "AAPL", ["MSFT", "GOOGL"], days=180, top_n=10
    )

    # No crash; at most len(universe) entries on each side.
    assert len(outcome["most_similar"]) <= 2
    assert len(outcome["least_similar"]) <= 2


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_overlap_with_most_and_least(mock_fetch):
    """most_similar and least_similar should not contain the target symbol."""
    import portfolio_service

    mock_fetch.return_value = _make_prices(["AAPL", "MSFT", "GOOGL", "AMZN", "JPM"])

    outcome = await portfolio_service.find_similar_stocks(
        "AAPL", ["MSFT", "GOOGL", "AMZN", "JPM"], days=180, top_n=2
    )

    returned = [
        entry["symbol"]
        for entry in outcome["most_similar"] + outcome["least_similar"]
    ]
    assert "AAPL" not in returned


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_data_raises(mock_fetch):
    """find_similar_stocks raises ValueError when no price data is returned."""
    import pandas as pd
    import portfolio_service

    mock_fetch.return_value = pd.DataFrame()

    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.find_similar_stocks(
            "AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
        )


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_target_not_in_data_raises(mock_fetch):
    """find_similar_stocks raises ValueError when target symbol has no data."""
    import portfolio_service

    # Only universe symbols have data, not the target.
    mock_fetch.return_value = _make_prices(["MSFT", "GOOGL"])

    with pytest.raises(ValueError, match="AAPL"):
        await portfolio_service.find_similar_stocks(
            "AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
        )


@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_default_top_n(mock_fetch):
    """Default top_n=5 returns at most 5 entries in most_similar."""
    import portfolio_service

    universe = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"]
    mock_fetch.return_value = _make_prices(universe)

    outcome = await portfolio_service.find_similar_stocks(
        "AAPL",
        ["MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"],
        days=180,
    )

    assert len(outcome["most_similar"]) <= 5
    assert len(outcome["least_similar"]) <= 5
|
||||
443
tests/test_routes_backtest.py
Normal file
443
tests/test_routes_backtest.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""Integration tests for backtest routes - written FIRST (TDD RED phase)."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Async HTTP client wired straight to the ASGI app (no real server)."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as c:
        yield c
|
||||
|
||||
|
||||
# Shared mock response used across strategy tests
MOCK_BACKTEST_RESULT = {
    "total_return": 0.15,
    "annualized_return": 0.14,
    "sharpe_ratio": 1.2,
    "max_drawdown": -0.08,
    "win_rate": 0.6,
    "total_trades": 10,
    # 20-point linear equity curve starting at the 10k initial capital.
    "equity_curve": [10000 + i * 75 for i in range(20)],
}

# Momentum backtests additionally report how capital was allocated over time.
MOCK_MOMENTUM_RESULT = {
    **MOCK_BACKTEST_RESULT,
    "allocation_history": [
        {"date": "2024-01-01", "symbols": ["AAPL", "MSFT"], "weights": [0.5, 0.5]},
    ],
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/sma-crossover
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_happy_path(mock_fn, client):
    """Valid request returns the service result wrapped in the success envelope."""
    mock_fn.return_value = MOCK_BACKTEST_RESULT
    resp = await client.post(
        "/api/v1/backtest/sma-crossover",
        json={
            "symbol": "AAPL",
            "short_window": 20,
            "long_window": 50,
            "days": 365,
            "initial_capital": 10000,
        },
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["success"] is True
    assert data["data"]["total_return"] == pytest.approx(0.15)
    assert data["data"]["total_trades"] == 10
    assert len(data["data"]["equity_curve"]) == 20
    # Route forwards parameters verbatim; note capital is coerced to float.
    mock_fn.assert_called_once_with(
        "AAPL",
        short_window=20,
        long_window=50,
        days=365,
        initial_capital=10000.0,
    )


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_default_values(mock_fn, client):
    """Omitted fields fall back to the documented defaults (20/50, 365d, 10k)."""
    mock_fn.return_value = MOCK_BACKTEST_RESULT
    resp = await client.post(
        "/api/v1/backtest/sma-crossover",
        json={"symbol": "MSFT"},
    )
    assert resp.status_code == 200
    mock_fn.assert_called_once_with(
        "MSFT",
        short_window=20,
        long_window=50,
        days=365,
        initial_capital=10000.0,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sma_crossover_missing_symbol(client):
    """Request without a symbol fails Pydantic validation."""
    payload = {"short_window": 20, "long_window": 50}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sma_crossover_short_window_too_small(client):
    """short_window below the allowed minimum is rejected."""
    payload = {"symbol": "AAPL", "short_window": 2, "long_window": 50}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sma_crossover_long_window_too_large(client):
    """long_window above the allowed maximum is rejected."""
    payload = {"symbol": "AAPL", "short_window": 20, "long_window": 500}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sma_crossover_days_too_few(client):
    """A lookback shorter than the allowed minimum is rejected."""
    payload = {"symbol": "AAPL", "days": 5}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sma_crossover_days_too_many(client):
    """A lookback longer than the allowed maximum is rejected."""
    payload = {"symbol": "AAPL", "days": 9999}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_sma_crossover_capital_zero(client):
    """Zero initial capital is rejected by validation."""
    payload = {"symbol": "AAPL", "initial_capital": 0}
    response = await client.post("/api/v1/backtest/sma-crossover", json=payload)
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||
async def test_sma_crossover_service_error_returns_502(mock_fn, client):
|
||||
mock_fn.side_effect = RuntimeError("Data fetch failed")
|
||||
resp = await client.post(
|
||||
"/api/v1/backtest/sma-crossover",
|
||||
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
|
||||
)
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||
async def test_sma_crossover_value_error_returns_400(mock_fn, client):
|
||||
mock_fn.side_effect = ValueError("No historical data")
|
||||
resp = await client.post(
|
||||
"/api/v1/backtest/sma-crossover",
|
||||
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/rsi
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_happy_path(service_mock, client):
    """A fully-specified RSI backtest succeeds and forwards every parameter."""
    service_mock.return_value = MOCK_BACKTEST_RESULT
    response = await client.post(
        "/api/v1/backtest/rsi",
        json={
            "symbol": "AAPL",
            "period": 14,
            "oversold": 30,
            "overbought": 70,
            "days": 365,
            "initial_capital": 10000,
        },
    )
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"]["sharpe_ratio"] == pytest.approx(1.2)
    service_mock.assert_called_once_with(
        "AAPL",
        period=14,
        oversold=30.0,
        overbought=70.0,
        days=365,
        initial_capital=10000.0,
    )


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_default_values(service_mock, client):
    """Omitted optional fields fall back to the documented defaults."""
    service_mock.return_value = MOCK_BACKTEST_RESULT
    response = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
    assert response.status_code == 200
    service_mock.assert_called_once_with(
        "AAPL",
        period=14,
        oversold=30.0,
        overbought=70.0,
        days=365,
        initial_capital=10000.0,
    )


@pytest.mark.asyncio
async def test_rsi_missing_symbol(client):
    """Requests without a symbol are rejected by request validation."""
    response = await client.post("/api/v1/backtest/rsi", json={"period": 14})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_rsi_period_too_small(client):
    """period below the allowed minimum is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/rsi",
        json={"symbol": "AAPL", "period": 1},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_rsi_oversold_too_high(client):
    """oversold must stay below 50; anything higher is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/rsi",
        json={"symbol": "AAPL", "oversold": 55, "overbought": 70},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_rsi_overbought_too_low(client):
    """overbought must stay above 50; anything lower is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/rsi",
        json={"symbol": "AAPL", "oversold": 30, "overbought": 45},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("upstream error")
    response = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_value_error_returns_400(service_mock, client):
    """Domain validation errors raised by the service map to 400 Bad Request."""
    service_mock.side_effect = ValueError("Insufficient data")
    response = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
    assert response.status_code == 400
||||
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/buy-and-hold
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_happy_path(service_mock, client):
    """A fully-specified buy-and-hold request succeeds and forwards its params."""
    service_mock.return_value = MOCK_BACKTEST_RESULT
    response = await client.post(
        "/api/v1/backtest/buy-and-hold",
        json={"symbol": "AAPL", "days": 365, "initial_capital": 10000},
    )
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"]["total_return"] == pytest.approx(0.15)
    service_mock.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_default_values(service_mock, client):
    """Omitted optional fields fall back to the documented defaults."""
    service_mock.return_value = MOCK_BACKTEST_RESULT
    response = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
    assert response.status_code == 200
    service_mock.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)


@pytest.mark.asyncio
async def test_buy_and_hold_missing_symbol(client):
    """Requests without a symbol are rejected by request validation."""
    response = await client.post("/api/v1/backtest/buy-and-hold", json={"days": 365})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_buy_and_hold_days_too_few(client):
    """days below the allowed range is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/buy-and-hold",
        json={"symbol": "AAPL", "days": 10},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("provider down")
    response = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_value_error_returns_400(service_mock, client):
    """Domain validation errors raised by the service map to 400 Bad Request."""
    service_mock.side_effect = ValueError("No historical data")
    response = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
    assert response.status_code == 400
||||
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/momentum
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_happy_path(service_mock, client):
    """A fully-specified momentum request succeeds and forwards every param."""
    service_mock.return_value = MOCK_MOMENTUM_RESULT
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={
            "symbols": ["AAPL", "MSFT", "GOOGL"],
            "lookback": 60,
            "top_n": 2,
            "rebalance_days": 30,
            "days": 365,
            "initial_capital": 10000,
        },
    )
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert "allocation_history" in body["data"]
    assert body["data"]["total_trades"] == 10
    service_mock.assert_called_once_with(
        symbols=["AAPL", "MSFT", "GOOGL"],
        lookback=60,
        top_n=2,
        rebalance_days=30,
        days=365,
        initial_capital=10000.0,
    )


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_default_values(service_mock, client):
    """Omitted optional fields fall back to the documented defaults."""
    service_mock.return_value = MOCK_MOMENTUM_RESULT
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL", "MSFT"]},
    )
    assert response.status_code == 200
    service_mock.assert_called_once_with(
        symbols=["AAPL", "MSFT"],
        lookback=60,
        top_n=2,
        rebalance_days=30,
        days=365,
        initial_capital=10000.0,
    )


@pytest.mark.asyncio
async def test_momentum_missing_symbols(client):
    """Requests without a symbols list are rejected by request validation."""
    response = await client.post("/api/v1/backtest/momentum", json={"lookback": 60})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_momentum_too_few_symbols(client):
    """The symbols list enforces min_length=2; a single symbol is rejected."""
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL"]},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_momentum_too_many_symbols(client):
    """A symbols list beyond the allowed maximum is rejected (422)."""
    oversized = [f"SYM{i}" for i in range(25)]
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": oversized},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_momentum_lookback_too_small(client):
    """lookback below the allowed minimum is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL", "MSFT"], "lookback": 2},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_momentum_days_too_few(client):
    """days below the allowed range is rejected (422)."""
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL", "MSFT"], "days": 10},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("provider down")
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL", "MSFT"]},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_value_error_returns_400(service_mock, client):
    """Domain validation errors raised by the service map to 400 Bad Request."""
    service_mock.side_effect = ValueError("No price data available")
    response = await client.post(
        "/api/v1/backtest/momentum",
        json={"symbols": ["AAPL", "MSFT"]},
    )
    assert response.status_code == 400
||||
321
tests/test_routes_cn.py
Normal file
321
tests/test_routes_cn.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Integration tests for routes_cn.py - written FIRST (TDD RED phase)."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
# ============================================================
# A-share quote GET /api/v1/cn/a-share/{symbol}/quote
# ============================================================


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_happy_path(service_mock, client):
    """A valid six-digit A-share code returns the quote payload unchanged."""
    service_mock.return_value = {
        "symbol": "000001",
        "name": "平安银行",
        "price": 12.34,
        "change": 0.15,
        "change_percent": 1.23,
        "volume": 500_000,
        "turnover": 6_170_000.0,
        "open": 12.10,
        "high": 12.50,
        "low": 12.00,
        "prev_close": 12.19,
    }
    response = await client.get("/api/v1/cn/a-share/000001/quote")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"]["symbol"] == "000001"
    assert body["data"]["name"] == "平安银行"
    assert body["data"]["price"] == pytest.approx(12.34)


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_not_found_returns_404(service_mock, client):
    """A None result from the service maps to 404 Not Found."""
    service_mock.return_value = None
    response = await client.get("/api/v1/cn/a-share/000001/quote")
    assert response.status_code == 404


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("AKShare down")
    response = await client.get("/api/v1/cn/a-share/000001/quote")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_a_share_quote_invalid_symbol_returns_400(client):
    """A-share codes may not start with 1; such codes are rejected (400)."""
    response = await client.get("/api/v1/cn/a-share/100001/quote")
    assert response.status_code == 400


@pytest.mark.asyncio
async def test_a_share_quote_non_numeric_symbol_returns_400(client):
    """Non-numeric symbols are rejected with 400."""
    response = await client.get("/api/v1/cn/a-share/ABCDEF/quote")
    assert response.status_code == 400


@pytest.mark.asyncio
async def test_a_share_quote_too_short_returns_422(client):
    """Symbols shorter than six characters fail path validation (422)."""
    response = await client.get("/api/v1/cn/a-share/00001/quote")
    assert response.status_code == 422
||||
# ============================================================
# A-share historical GET /api/v1/cn/a-share/{symbol}/historical
# ============================================================


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_happy_path(service_mock, client):
    """Historical bars are returned as a list of daily records."""
    service_mock.return_value = [
        {
            "date": "2026-01-01",
            "open": 10.0,
            "close": 10.5,
            "high": 11.0,
            "low": 9.5,
            "volume": 1_000_000,
            "turnover": 10_500_000.0,
            "change_percent": 0.5,
            "turnover_rate": 0.3,
        }
    ]
    response = await client.get("/api/v1/cn/a-share/000001/historical?days=30")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert isinstance(body["data"], list)
    assert len(body["data"]) == 1
    assert body["data"][0]["date"] == "2026-01-01"


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_default_days(service_mock, client):
    """Omitting ?days defaults the window to 365 days."""
    service_mock.return_value = []
    response = await client.get("/api/v1/cn/a-share/600519/historical")
    assert response.status_code == 200
    service_mock.assert_called_once_with("600519", days=365)


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_empty_returns_200(service_mock, client):
    """An empty history is a valid 200 response, not an error."""
    service_mock.return_value = []
    response = await client.get("/api/v1/cn/a-share/000001/historical")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("AKShare down")
    response = await client.get("/api/v1/cn/a-share/000001/historical")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_a_share_historical_invalid_symbol_returns_400(client):
    """Codes starting with 1 are not valid A-share symbols (400)."""
    response = await client.get("/api/v1/cn/a-share/100001/historical")
    assert response.status_code == 400


@pytest.mark.asyncio
async def test_a_share_historical_days_out_of_range_returns_422(client):
    """days=0 is outside the allowed range and fails validation (422)."""
    response = await client.get("/api/v1/cn/a-share/000001/historical?days=0")
    assert response.status_code == 422
||||
# ============================================================
# A-share search GET /api/v1/cn/a-share/search
# ============================================================


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_happy_path(service_mock, client):
    """Search results are passed through in order."""
    service_mock.return_value = [
        {"code": "000001", "name": "平安银行"},
        {"code": "000002", "name": "平安地产"},
    ]
    response = await client.get("/api/v1/cn/a-share/search?query=平安")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    assert body["data"][0]["code"] == "000001"


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_empty_results(service_mock, client):
    """No matches yields 200 with an empty list."""
    service_mock.return_value = []
    response = await client.get("/api/v1/cn/a-share/search?query=NOMATCH")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("AKShare down")
    response = await client.get("/api/v1/cn/a-share/search?query=test")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_a_share_search_missing_query_returns_422(client):
    """The query parameter is required; omitting it fails validation."""
    response = await client.get("/api/v1/cn/a-share/search")
    assert response.status_code == 422
||||
# ============================================================
# HK quote GET /api/v1/cn/hk/{symbol}/quote
# ============================================================


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_happy_path(service_mock, client):
    """A valid five-digit HK code returns the quote payload unchanged."""
    service_mock.return_value = {
        "symbol": "00700",
        "name": "腾讯控股",
        "price": 380.0,
        "change": 5.0,
        "change_percent": 1.33,
        "volume": 10_000_000,
        "turnover": 3_800_000_000.0,
        "open": 375.0,
        "high": 385.0,
        "low": 374.0,
        "prev_close": 375.0,
    }
    response = await client.get("/api/v1/cn/hk/00700/quote")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"]["symbol"] == "00700"
    assert body["data"]["price"] == pytest.approx(380.0)


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_not_found_returns_404(service_mock, client):
    """A None result from the service maps to 404 Not Found."""
    service_mock.return_value = None
    response = await client.get("/api/v1/cn/hk/00700/quote")
    assert response.status_code == 404


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("AKShare down")
    response = await client.get("/api/v1/cn/hk/00700/quote")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_hk_quote_invalid_symbol_letters_returns_400(client):
    """Alphabetic symbols are not valid HK codes (400)."""
    response = await client.get("/api/v1/cn/hk/ABCDE/quote")
    assert response.status_code == 400


@pytest.mark.asyncio
async def test_hk_quote_too_short_returns_422(client):
    """HK codes shorter than five characters fail path validation (422)."""
    response = await client.get("/api/v1/cn/hk/0070/quote")
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_hk_quote_too_long_returns_422(client):
    """HK codes longer than five characters fail path validation (422)."""
    response = await client.get("/api/v1/cn/hk/007000/quote")
    assert response.status_code == 422
||||
# ============================================================
# HK historical GET /api/v1/cn/hk/{symbol}/historical
# ============================================================


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_happy_path(service_mock, client):
    """Historical bars are returned as a list of daily records."""
    service_mock.return_value = [
        {
            "date": "2026-01-01",
            "open": 375.0,
            "close": 380.0,
            "high": 385.0,
            "low": 374.0,
            "volume": 10_000_000,
            "turnover": 3_800_000_000.0,
            "change_percent": 1.33,
            "turnover_rate": 0.5,
        }
    ]
    response = await client.get("/api/v1/cn/hk/00700/historical?days=90")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert isinstance(body["data"], list)
    assert body["data"][0]["close"] == pytest.approx(380.0)


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_default_days(service_mock, client):
    """Omitting ?days defaults the window to 365 days."""
    service_mock.return_value = []
    response = await client.get("/api/v1/cn/hk/09988/historical")
    assert response.status_code == 200
    service_mock.assert_called_once_with("09988", days=365)


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_empty_returns_200(service_mock, client):
    """An empty history is a valid 200 response, not an error."""
    service_mock.return_value = []
    response = await client.get("/api/v1/cn/hk/00700/historical")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("AKShare down")
    response = await client.get("/api/v1/cn/hk/00700/historical")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_hk_historical_invalid_symbol_returns_400(client):
    """Alphabetic symbols are not valid HK codes (400)."""
    response = await client.get("/api/v1/cn/hk/ABCDE/historical")
    assert response.status_code == 400


@pytest.mark.asyncio
async def test_hk_historical_days_out_of_range_returns_422(client):
    """days=0 is outside the allowed range and fails validation (422)."""
    response = await client.get("/api/v1/cn/hk/00700/historical?days=0")
    assert response.status_code == 422
98
tests/test_routes_congress.py
Normal file
98
tests/test_routes_congress.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Tests for congress trading routes (TDD - RED phase first)."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
# --- GET /api/v1/regulators/congress/trades ---


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_happy_path(service_mock, client):
    """Trades returned by the service are passed through verbatim."""
    service_mock.return_value = [
        {
            "representative": "Nancy Pelosi",
            "ticker": "NVDA",
            "transaction_date": "2024-01-15",
            "transaction_type": "Purchase",
            "amount": "$1,000,001-$5,000,000",
        }
    ]
    response = await client.get("/api/v1/regulators/congress/trades")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["representative"] == "Nancy Pelosi"
    service_mock.assert_called_once()


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_empty(service_mock, client):
    """No trades yields 200 with an empty list."""
    service_mock.return_value = []
    response = await client.get("/api/v1/regulators/congress/trades")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("Data provider unavailable")
    response = await client.get("/api/v1/regulators/congress/trades")
    assert response.status_code == 502
||||
# --- GET /api/v1/regulators/congress/bills ---


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_happy_path(service_mock, client):
    """Bill search forwards the query string and returns results in order."""
    service_mock.return_value = [
        {"title": "Infrastructure Investment and Jobs Act", "bill_id": "HR3684"},
        {"title": "Inflation Reduction Act", "bill_id": "HR5376"},
    ]
    response = await client.get("/api/v1/regulators/congress/bills?query=infrastructure")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    assert body["data"][0]["bill_id"] == "HR3684"
    service_mock.assert_called_once_with("infrastructure")


@pytest.mark.asyncio
async def test_congress_bills_missing_query(client):
    """The query parameter is required; omitting it fails validation (422)."""
    response = await client.get("/api/v1/regulators/congress/bills")
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_empty(service_mock, client):
    """No matches yields 200 with an empty list."""
    service_mock.return_value = []
    response = await client.get("/api/v1/regulators/congress/bills?query=nonexistent")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("Congress API unavailable")
    response = await client.get("/api/v1/regulators/congress/bills?query=tax")
    assert response.status_code == 502
363
tests/test_routes_defi.py
Normal file
363
tests/test_routes_defi.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Tests for routes_defi.py - DeFi API routes.
|
||||
|
||||
TDD: these tests are written before implementation.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# GET /api/v1/defi/tvl/protocols
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_happy_path(service_mock, client):
    """Protocols returned by the service are passed through verbatim."""
    service_mock.return_value = [
        {
            "name": "Aave",
            "symbol": "AAVE",
            "tvl": 10_000_000_000.0,
            "chain": "Ethereum",
            "chains": ["Ethereum"],
            "category": "Lending",
            "change_1d": 0.5,
            "change_7d": -1.2,
        }
    ]
    response = await client.get("/api/v1/defi/tvl/protocols")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["name"] == "Aave"
    assert body["data"][0]["tvl"] == 10_000_000_000.0


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_empty(service_mock, client):
    """No protocols yields 200 with an empty list."""
    service_mock.return_value = []
    response = await client.get("/api/v1/defi/tvl/protocols")

    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("DefiLlama unavailable")
    response = await client.get("/api/v1/defi/tvl/protocols")

    assert response.status_code == 502
||||
# ---------------------------------------------------------------------------
# GET /api/v1/defi/tvl/chains
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_happy_path(service_mock, client):
    """Chain TVLs returned by the service are passed through verbatim."""
    service_mock.return_value = [
        {"name": "Ethereum", "tvl": 50_000_000_000.0, "tokenSymbol": "ETH"},
        {"name": "BSC", "tvl": 5_000_000_000.0, "tokenSymbol": "BNB"},
    ]
    response = await client.get("/api/v1/defi/tvl/chains")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    assert body["data"][0]["name"] == "Ethereum"
    assert body["data"][0]["tokenSymbol"] == "ETH"


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_empty(service_mock, client):
    """No chains yields 200 with an empty list."""
    service_mock.return_value = []
    response = await client.get("/api/v1/defi/tvl/chains")

    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_service_error_returns_502(service_mock, client):
    """Unexpected service failures surface as 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("upstream error")
    response = await client.get("/api/v1/defi/tvl/chains")

    assert response.status_code == 502
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/defi/tvl/{protocol}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_happy_path(service_mock, client):
    """Protocol TVL endpoint echoes the slug and the service's TVL figure."""
    service_mock.return_value = 10_000_000_000.0

    response = await client.get("/api/v1/defi/tvl/aave")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"]["protocol"] == "aave"
    assert body["data"]["tvl"] == 10_000_000_000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_not_found_returns_404(service_mock, client):
    """A None TVL from the service maps to HTTP 404."""
    service_mock.return_value = None

    response = await client.get("/api/v1/defi/tvl/nonexistent-protocol")

    assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("HTTP error")

    response = await client.get("/api/v1/defi/tvl/aave")

    assert response.status_code == 502
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_passes_slug_to_service(service_mock, client):
    """The path slug is forwarded verbatim to the service layer."""
    service_mock.return_value = 5_000_000_000.0

    await client.get("/api/v1/defi/tvl/uniswap-v3")

    service_mock.assert_called_once_with("uniswap-v3")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/defi/yields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_happy_path(service_mock, client):
    """Yield pools are returned unchanged inside the response envelope."""
    pool = {
        "pool": "0xabcd",
        "chain": "Ethereum",
        "project": "aave-v3",
        "symbol": "USDC",
        "tvlUsd": 1_000_000_000.0,
        "apy": 3.5,
        "apyBase": 3.0,
        "apyReward": 0.5,
    }
    service_mock.return_value = [pool]

    response = await client.get("/api/v1/defi/yields")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["pool"] == "0xabcd"
    assert body["data"][0]["apy"] == 3.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_chain_filter(service_mock, client):
    """The chain query param is forwarded; project stays None."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/yields?chain=Ethereum")

    assert response.status_code == 200
    service_mock.assert_called_once_with(chain="Ethereum", project=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_project_filter(service_mock, client):
    """The project query param is forwarded; chain stays None."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/yields?project=aave-v3")

    assert response.status_code == 200
    service_mock.assert_called_once_with(chain=None, project="aave-v3")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_chain_and_project_filter(service_mock, client):
    """Both filters are forwarded to the service together."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/yields?chain=Polygon&project=curve")

    assert response.status_code == 200
    service_mock.assert_called_once_with(chain="Polygon", project="curve")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_empty(service_mock, client):
    """No pools from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/yields")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("yields API down")

    response = await client.get("/api/v1/defi/yields")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/defi/stablecoins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_happy_path(service_mock, client):
    """Stablecoin records are returned unchanged inside the envelope."""
    coin = {
        "name": "Tether",
        "symbol": "USDT",
        "pegType": "peggedUSD",
        "circulating": 100_000_000_000.0,
        "price": 1.0,
    }
    service_mock.return_value = [coin]

    response = await client.get("/api/v1/defi/stablecoins")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["symbol"] == "USDT"
    assert body["data"][0]["circulating"] == 100_000_000_000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_empty(service_mock, client):
    """No stablecoins from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/stablecoins")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("stables API error")

    response = await client.get("/api/v1/defi/stablecoins")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/defi/volumes/dexs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_happy_path(service_mock, client):
    """DEX volume aggregate is returned unchanged inside the envelope."""
    service_mock.return_value = {
        "totalVolume24h": 5_000_000_000.0,
        "totalVolume7d": 30_000_000_000.0,
        "protocols": [
            {"name": "Uniswap", "volume24h": 2_000_000_000.0},
        ],
    }

    response = await client.get("/api/v1/defi/volumes/dexs")

    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    volumes = payload["data"]
    assert volumes["totalVolume24h"] == 5_000_000_000.0
    assert volumes["totalVolume7d"] == 30_000_000_000.0
    assert len(volumes["protocols"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_returns_502_when_service_returns_none(service_mock, client):
    """A None aggregate is treated as an upstream failure (502)."""
    service_mock.return_value = None

    response = await client.get("/api/v1/defi/volumes/dexs")

    assert response.status_code == 502
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("volume API error")

    response = await client.get("/api/v1/defi/volumes/dexs")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/defi/fees
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_happy_path(service_mock, client):
    """Protocol fee records are returned unchanged inside the envelope."""
    service_mock.return_value = [
        {"name": "Uniswap", "fees24h": 1_000_000.0, "revenue24h": 500_000.0},
        {"name": "Aave", "fees24h": 800_000.0, "revenue24h": 800_000.0},
    ]

    response = await client.get("/api/v1/defi/fees")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    first = body["data"][0]
    assert first["name"] == "Uniswap"
    assert first["fees24h"] == 1_000_000.0
    assert first["revenue24h"] == 500_000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_empty(service_mock, client):
    """No fee records from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/defi/fees")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("fees API error")

    response = await client.get("/api/v1/defi/fees")

    assert response.status_code == 502
|
||||
433
tests/test_routes_economy.py
Normal file
433
tests/test_routes_economy.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Tests for expanded economy routes."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Async HTTP client talking to the FastAPI app over in-process ASGI."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client
|
||||
|
||||
|
||||
# --- CPI ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_happy_path(service_mock, client):
    """Default CPI query targets the US and returns the service payload."""
    service_mock.return_value = [
        {"date": "2026-02-01", "value": 312.5, "country": "united_states"}
    ]

    response = await client.get("/api/v1/macro/cpi")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 312.5
    service_mock.assert_called_once_with(country="united_states")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_custom_country(service_mock, client):
    """The country query param is forwarded to the service."""
    service_mock.return_value = [{"date": "2026-02-01", "value": 120.0}]

    response = await client.get("/api/v1/macro/cpi?country=germany")

    assert response.status_code == 200
    service_mock.assert_called_once_with(country="germany")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_macro_cpi_invalid_country(client):
    """A country value failing validation is rejected with 422."""
    response = await client.get("/api/v1/macro/cpi?country=INVALID!!!COUNTRY")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/cpi")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("FRED down")

    response = await client.get("/api/v1/macro/cpi")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- GDP ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_default_real(service_mock, client):
    """GDP defaults to the real series when gdp_type is omitted."""
    service_mock.return_value = [{"date": "2026-01-01", "value": 22.5}]

    response = await client.get("/api/v1/macro/gdp")

    assert response.status_code == 200
    assert response.json()["success"] is True
    service_mock.assert_called_once_with(gdp_type="real")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_nominal(service_mock, client):
    """The nominal gdp_type is forwarded to the service."""
    service_mock.return_value = [{"date": "2026-01-01", "value": 28.3}]

    response = await client.get("/api/v1/macro/gdp?gdp_type=nominal")

    assert response.status_code == 200
    service_mock.assert_called_once_with(gdp_type="nominal")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_macro_gdp_invalid_type(client):
    """An unknown gdp_type value is rejected with 422."""
    response = await client.get("/api/v1/macro/gdp?gdp_type=invalid")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_forecast(service_mock, client):
    """The forecast gdp_type is forwarded to the service."""
    service_mock.return_value = [{"date": "2027-01-01", "value": 23.1}]

    response = await client.get("/api/v1/macro/gdp?gdp_type=forecast")

    assert response.status_code == 200
    service_mock.assert_called_once_with(gdp_type="forecast")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/gdp")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- Unemployment ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_happy_path(service_mock, client):
    """Default unemployment query targets the US and returns the payload."""
    service_mock.return_value = [{"date": "2026-02-01", "value": 3.7, "country": "united_states"}]

    response = await client.get("/api/v1/macro/unemployment")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 3.7
    service_mock.assert_called_once_with(country="united_states")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_custom_country(service_mock, client):
    """The country query param is forwarded to the service."""
    service_mock.return_value = [{"date": "2026-02-01", "value": 5.1}]

    response = await client.get("/api/v1/macro/unemployment?country=france")

    assert response.status_code == 200
    service_mock.assert_called_once_with(country="france")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/unemployment")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- PCE ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_happy_path(service_mock, client):
    """PCE endpoint returns the service payload inside the envelope."""
    service_mock.return_value = [{"date": "2026-02-01", "value": 2.8}]

    response = await client.get("/api/v1/macro/pce")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 2.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/pce")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("FRED unavailable")

    response = await client.get("/api/v1/macro/pce")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- Money Measures ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_money_measures", new_callable=AsyncMock)
async def test_macro_money_measures_happy_path(service_mock, client):
    """Money-measures endpoint returns M1/M2 records from the service."""
    service_mock.return_value = [{"date": "2026-02-01", "m1": 18200.0, "m2": 21000.0}]

    response = await client.get("/api/v1/macro/money-measures")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["m2"] == 21000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_money_measures", new_callable=AsyncMock)
async def test_macro_money_measures_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/money-measures")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- CLI ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_composite_leading_indicator", new_callable=AsyncMock)
async def test_macro_cli_happy_path(service_mock, client):
    """CLI endpoint defaults to the US and returns the service payload."""
    service_mock.return_value = [{"date": "2026-01-01", "value": 99.2, "country": "united_states"}]

    response = await client.get("/api/v1/macro/cli")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 99.2
    service_mock.assert_called_once_with(country="united_states")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_composite_leading_indicator", new_callable=AsyncMock)
async def test_macro_cli_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/cli")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- House Price Index ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_house_price_index", new_callable=AsyncMock)
async def test_macro_hpi_happy_path(service_mock, client):
    """House-price-index endpoint defaults to the US and returns the payload."""
    service_mock.return_value = [{"date": "2026-01-01", "value": 350.0, "country": "united_states"}]

    response = await client.get("/api/v1/macro/house-price-index")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 350.0
    service_mock.assert_called_once_with(country="united_states")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_house_price_index", new_callable=AsyncMock)
async def test_macro_hpi_empty(service_mock, client):
    """No observations from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/macro/house-price-index")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- FRED Regional ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_happy_path(service_mock, client):
    """series_id is forwarded; region defaults to None."""
    service_mock.return_value = [{"region": "CA", "value": 5.2}]

    response = await client.get("/api/v1/economy/fred-regional?series_id=CAUR")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["region"] == "CA"
    service_mock.assert_called_once_with(series_id="CAUR", region=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_with_region(mock_fn, client):
    """Both series_id and region query params are forwarded to the service.

    Fix: the URL previously contained the mojibake "CAUR®ion=state" — an
    HTML-entity-decoded "&reg" — so no `region` parameter was ever sent and
    the assert_called_once_with(region="state") check could not hold.
    """
    mock_fn.return_value = [{"region": "state", "value": 4.1}]
    resp = await client.get("/api/v1/economy/fred-regional?series_id=CAUR&region=state")
    assert resp.status_code == 200
    mock_fn.assert_called_once_with(series_id="CAUR", region="state")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_economy_fred_regional_missing_series_id(client):
    """Omitting the required series_id query param yields 422."""
    response = await client.get("/api/v1/economy/fred-regional")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_empty(service_mock, client):
    """An unknown series yields an empty data array, not an error."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/fred-regional?series_id=UNKNOWN")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- Primary Dealer Positioning ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_primary_dealer_positioning", new_callable=AsyncMock)
async def test_economy_primary_dealer_happy_path(service_mock, client):
    """Primary-dealer positioning records are returned inside the envelope."""
    service_mock.return_value = [{"date": "2026-03-12", "treasuries": 250000.0, "mbs": 80000.0}]

    response = await client.get("/api/v1/economy/primary-dealer-positioning")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["treasuries"] == 250000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_primary_dealer_positioning", new_callable=AsyncMock)
async def test_economy_primary_dealer_empty(service_mock, client):
    """No records from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/primary-dealer-positioning")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- FRED Search ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.fred_search", new_callable=AsyncMock)
async def test_economy_fred_search_happy_path(service_mock, client):
    """Search results flow through; '+' in the query decodes to a space."""
    service_mock.return_value = [
        {"id": "FEDFUNDS", "title": "Effective Federal Funds Rate", "frequency": "Monthly"}
    ]

    response = await client.get("/api/v1/economy/fred-search?query=federal+funds")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["id"] == "FEDFUNDS"
    service_mock.assert_called_once_with(query="federal funds")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_economy_fred_search_missing_query(client):
    """Omitting the required query param yields 422."""
    response = await client.get("/api/v1/economy/fred-search")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.fred_search", new_callable=AsyncMock)
async def test_economy_fred_search_empty(service_mock, client):
    """A query with no matches yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/fred-search?query=nothingtofind")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- Balance of Payments ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_balance_of_payments", new_callable=AsyncMock)
async def test_economy_bop_happy_path(service_mock, client):
    """Balance-of-payments records are returned inside the envelope."""
    service_mock.return_value = [{"date": "2026-01-01", "current_account": -200.0, "capital_account": 5.0}]

    response = await client.get("/api/v1/economy/balance-of-payments")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["current_account"] == -200.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_balance_of_payments", new_callable=AsyncMock)
async def test_economy_bop_empty(service_mock, client):
    """No records from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/balance-of-payments")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- Central Bank Holdings ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_central_bank_holdings", new_callable=AsyncMock)
async def test_economy_central_bank_holdings_happy_path(service_mock, client):
    """Central-bank holdings records are returned inside the envelope."""
    service_mock.return_value = [{"date": "2026-03-13", "treasuries": 4500000.0, "mbs": 2300000.0}]

    response = await client.get("/api/v1/economy/central-bank-holdings")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["treasuries"] == 4500000.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_central_bank_holdings", new_callable=AsyncMock)
async def test_economy_central_bank_holdings_empty(service_mock, client):
    """No records from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/central-bank-holdings")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- FOMC Documents ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_happy_path(service_mock, client):
    """FOMC documents flow through; the year filter defaults to None."""
    service_mock.return_value = [
        {"date": "2026-01-28", "type": "Minutes", "url": "https://federalreserve.gov/fomc"}
    ]

    response = await client.get("/api/v1/economy/fomc-documents")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["type"] == "Minutes"
    service_mock.assert_called_once_with(year=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_with_year(service_mock, client):
    """The year query param is forwarded as an int to the service."""
    service_mock.return_value = [{"date": "2024-01-30", "type": "Statement"}]

    response = await client.get("/api/v1/economy/fomc-documents?year=2024")

    assert response.status_code == 200
    service_mock.assert_called_once_with(year=2024)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_economy_fomc_invalid_year_too_low(client):
    """A year below the accepted range is rejected with 422."""
    response = await client.get("/api/v1/economy/fomc-documents?year=1999")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_economy_fomc_invalid_year_too_high(client):
    """A year above the accepted range is rejected with 422."""
    response = await client.get("/api/v1/economy/fomc-documents?year=2100")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_empty(service_mock, client):
    """No documents from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/economy/fomc-documents")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
328
tests/test_routes_fixed_income.py
Normal file
328
tests/test_routes_fixed_income.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Tests for fixed income routes."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Async HTTP client talking to the FastAPI app over in-process ASGI."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client
|
||||
|
||||
|
||||
# --- Treasury Rates ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_happy_path(service_mock, client):
    """Treasury rate rows are returned unchanged inside the envelope."""
    service_mock.return_value = [
        {"date": "2026-03-18", "week_4": 5.27, "month_3": 5.30, "year_2": 4.85, "year_10": 4.32, "year_30": 4.55}
    ]

    response = await client.get("/api/v1/fixed-income/treasury-rates")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["year_10"] == 4.32
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_empty(service_mock, client):
    """No rate rows from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/fixed-income/treasury-rates")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_service_error_returns_502(service_mock, client):
    """A service-layer exception surfaces as a 502 Bad Gateway."""
    service_mock.side_effect = RuntimeError("Federal Reserve API down")

    response = await client.get("/api/v1/fixed-income/treasury-rates")

    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- Yield Curve ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_happy_path(service_mock, client):
    """Yield-curve points flow through; the date filter defaults to None."""
    service_mock.return_value = [
        {"maturity": "3M", "rate": 5.30},
        {"maturity": "2Y", "rate": 4.85},
        {"maturity": "10Y", "rate": 4.32},
    ]

    response = await client.get("/api/v1/fixed-income/yield-curve")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 3
    service_mock.assert_called_once_with(date=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_with_date(service_mock, client):
    """The date query param is forwarded verbatim to the service."""
    service_mock.return_value = [{"maturity": "10Y", "rate": 3.80}]

    response = await client.get("/api/v1/fixed-income/yield-curve?date=2024-01-15")

    assert response.status_code == 200
    service_mock.assert_called_once_with(date="2024-01-15")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_yield_curve_invalid_date_format(client):
    """A date value failing format validation is rejected with 422."""
    response = await client.get("/api/v1/fixed-income/yield-curve?date=not-a-date")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_empty(service_mock, client):
    """No curve points from the service yields an empty data array."""
    service_mock.return_value = []

    response = await client.get("/api/v1/fixed-income/yield-curve")

    assert response.status_code == 200
    assert response.json()["data"] == []
|
||||
|
||||
|
||||
# --- Treasury Auctions ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
async def test_treasury_auctions_happy_path(service_mock, client):
    """Auction records are returned unchanged inside the envelope."""
    service_mock.return_value = [
        {"auction_date": "2026-03-10", "security_type": "Note", "security_term": "10-Year", "high_yield": 4.32, "bid_to_cover_ratio": 2.45}
    ]

    response = await client.get("/api/v1/fixed-income/treasury-auctions")

    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["bid_to_cover_ratio"] == 2.45
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
async def test_treasury_auctions_with_security_type(service_mock, client):
    """The security_type query param is forwarded to the service."""
    service_mock.return_value = [{"security_type": "Bill"}]

    response = await client.get("/api/v1/fixed-income/treasury-auctions?security_type=Bill")

    assert response.status_code == 200
    service_mock.assert_called_once_with(security_type="Bill")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
|
||||
async def test_treasury_auctions_empty(mock_fn, client):
|
||||
mock_fn.return_value = []
|
||||
resp = await client.get("/api/v1/fixed-income/treasury-auctions")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_treasury_auctions_invalid_security_type(client):
|
||||
resp = await client.get("/api/v1/fixed-income/treasury-auctions?security_type=DROP;TABLE")
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
# --- TIPS Yields ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_happy_path(mock_service, client):
    """TIPS yield rows come back unchanged in the success envelope."""
    mock_service.return_value = [
        {"date": "2026-03-18", "year_5": 2.10, "year_10": 2.25, "year_30": 2.40}
    ]
    response = await client.get("/api/v1/fixed-income/tips-yields")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["year_10"] == 2.25


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/tips-yields")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_service_error_returns_502(mock_service, client):
    """An unexpected service failure surfaces as a 502 upstream error."""
    mock_service.side_effect = RuntimeError("FRED unavailable")
    response = await client.get("/api/v1/fixed-income/tips-yields")
    assert response.status_code == 502
||||
# --- EFFR ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_effr", new_callable=AsyncMock)
async def test_effr_happy_path(mock_service, client):
    """EFFR rows (rate plus distribution percentiles) pass through untouched."""
    mock_service.return_value = [
        {"date": "2026-03-18", "rate": 5.33, "percentile_1": 5.31, "percentile_25": 5.32, "percentile_75": 5.33}
    ]
    response = await client.get("/api/v1/fixed-income/effr")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["rate"] == 5.33


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_effr", new_callable=AsyncMock)
async def test_effr_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/effr")
    assert response.status_code == 200
    assert response.json()["data"] == []
||||
# --- SOFR ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_sofr", new_callable=AsyncMock)
async def test_sofr_happy_path(mock_service, client):
    """SOFR rows (spot rate plus 30/90-day averages) pass through untouched."""
    mock_service.return_value = [
        {"date": "2026-03-18", "rate": 5.31, "average_30d": 5.31, "average_90d": 5.30}
    ]
    response = await client.get("/api/v1/fixed-income/sofr")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["rate"] == 5.31


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_sofr", new_callable=AsyncMock)
async def test_sofr_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/sofr")
    assert response.status_code == 200
    assert response.json()["data"] == []
||||
# --- HQM ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_hqm", new_callable=AsyncMock)
async def test_hqm_happy_path(mock_service, client):
    """HQM corporate-curve rows keyed by rating bucket pass through untouched."""
    mock_service.return_value = [
        {"date": "2026-02-01", "aaa": 5.10, "aa": 5.25, "a": 5.40}
    ]
    response = await client.get("/api/v1/fixed-income/hqm")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["aaa"] == 5.10


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_hqm", new_callable=AsyncMock)
async def test_hqm_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/hqm")
    assert response.status_code == 200
    assert response.json()["data"] == []
||||
# --- Commercial Paper ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_commercial_paper", new_callable=AsyncMock)
async def test_commercial_paper_happy_path(mock_service, client):
    """Commercial-paper rate rows pass through to the response body."""
    mock_service.return_value = [
        {"date": "2026-03-18", "maturity": "overnight", "financial": 5.28, "nonfinancial": 5.30}
    ]
    response = await client.get("/api/v1/fixed-income/commercial-paper")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert len(payload["data"]) == 1


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_commercial_paper", new_callable=AsyncMock)
async def test_commercial_paper_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/commercial-paper")
    assert response.status_code == 200
    assert response.json()["data"] == []
||||
# --- Spot Rates ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spot_rates", new_callable=AsyncMock)
async def test_spot_rates_happy_path(mock_service, client):
    """Spot-rate rows keyed by tenor pass through untouched."""
    mock_service.return_value = [
        {"date": "2026-03-01", "year_1": 5.50, "year_5": 5.20, "year_10": 5.10}
    ]
    response = await client.get("/api/v1/fixed-income/spot-rates")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["year_10"] == 5.10


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spot_rates", new_callable=AsyncMock)
async def test_spot_rates_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/spot-rates")
    assert response.status_code == 200
    assert response.json()["data"] == []
||||
# --- Spreads ---


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_default(mock_service, client):
    """With no query parameter the route defaults to the "tcm" series."""
    mock_service.return_value = [{"date": "2026-03-18", "spread": 1.10}]
    response = await client.get("/api/v1/fixed-income/spreads")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    mock_service.assert_called_once_with(series="tcm")


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_tcm_effr(mock_service, client):
    """An explicit series selection is forwarded to the service."""
    mock_service.return_value = [{"date": "2026-03-18", "spread": 0.02}]
    response = await client.get("/api/v1/fixed-income/spreads?series=tcm_effr")
    assert response.status_code == 200
    mock_service.assert_called_once_with(series="tcm_effr")


@pytest.mark.asyncio
async def test_spreads_invalid_series(client):
    """An unknown series name is rejected with 422 by validation."""
    response = await client.get("/api/v1/fixed-income/spreads?series=invalid")
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_empty(mock_service, client):
    """Empty service data -> 200 with an empty list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/fixed-income/spreads?series=treasury_effr")
    assert response.status_code == 200
    assert response.json()["data"] == []
493
tests/test_routes_portfolio.py
Normal file
493
tests/test_routes_portfolio.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""Tests for portfolio optimization routes (TDD - RED phase first)."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
# --- POST /api/v1/portfolio/optimize ---


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_happy_path(mock_service, client):
    """Valid payload -> 200 envelope with HRP weights, service called once."""
    mock_service.return_value = {
        "weights": {"AAPL": 0.35, "MSFT": 0.32, "GOOGL": 0.33},
        "method": "hrp",
    }
    response = await client.post(
        "/api/v1/portfolio/optimize",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 365},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"]["method"] == "hrp"
    assert "AAPL" in payload["data"]["weights"]
    mock_service.assert_called_once_with(["AAPL", "MSFT", "GOOGL"], days=365)


@pytest.mark.asyncio
async def test_portfolio_optimize_missing_symbols(client):
    """Omitting the required symbols field -> 422."""
    response = await client.post("/api/v1/portfolio/optimize", json={"days": 365})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_optimize_empty_symbols(client):
    """An empty symbols list -> 422."""
    response = await client.post(
        "/api/v1/portfolio/optimize", json={"symbols": [], "days": 365}
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_optimize_too_many_symbols(client):
    """More than 50 symbols exceeds the validation cap -> 422."""
    oversized = [f"SYM{i}" for i in range(51)]
    response = await client.post(
        "/api/v1/portfolio/optimize", json={"symbols": oversized, "days": 365}
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("Computation failed")
    response = await client.post(
        "/api/v1/portfolio/optimize",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_value_error_returns_400(mock_service, client):
    """ValueError from the service (bad input data) -> 400."""
    mock_service.side_effect = ValueError("No price data available")
    response = await client.post(
        "/api/v1/portfolio/optimize",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 400


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_default_days(mock_service, client):
    """When days is omitted the route applies its default of 365."""
    mock_service.return_value = {"weights": {"AAPL": 1.0}, "method": "hrp"}
    response = await client.post(
        "/api/v1/portfolio/optimize", json={"symbols": ["AAPL"]}
    )
    assert response.status_code == 200
    mock_service.assert_called_once_with(["AAPL"], days=365)
||||
# --- POST /api/v1/portfolio/correlation ---


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_happy_path(mock_service, client):
    """Valid payload -> 200 with the symbol list and correlation matrix."""
    mock_service.return_value = {
        "symbols": ["AAPL", "MSFT"],
        "matrix": [[1.0, 0.85], [0.85, 1.0]],
    }
    response = await client.post(
        "/api/v1/portfolio/correlation",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"]["symbols"] == ["AAPL", "MSFT"]
    assert payload["data"]["matrix"][0][0] == pytest.approx(1.0)
    mock_service.assert_called_once_with(["AAPL", "MSFT"], days=365)


@pytest.mark.asyncio
async def test_portfolio_correlation_missing_symbols(client):
    """Omitting the required symbols field -> 422."""
    response = await client.post("/api/v1/portfolio/correlation", json={"days": 365})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_correlation_empty_symbols(client):
    """An empty symbols list -> 422."""
    response = await client.post(
        "/api/v1/portfolio/correlation", json={"symbols": [], "days": 365}
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("Failed")
    response = await client.post(
        "/api/v1/portfolio/correlation",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_value_error_returns_400(mock_service, client):
    """ValueError from the service (bad input data) -> 400."""
    mock_service.side_effect = ValueError("No price data available")
    response = await client.post(
        "/api/v1/portfolio/correlation",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 400
|
||||
# --- POST /api/v1/portfolio/risk-parity ---


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_happy_path(mock_service, client):
    """Valid payload -> 200 with weights plus per-asset risk contributions."""
    mock_service.return_value = {
        "weights": {"AAPL": 0.35, "MSFT": 0.33, "GOOGL": 0.32},
        "risk_contributions": {"AAPL": 0.34, "MSFT": 0.33, "GOOGL": 0.33},
        "method": "risk_parity",
    }
    response = await client.post(
        "/api/v1/portfolio/risk-parity",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 365},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"]["method"] == "risk_parity"
    assert "risk_contributions" in payload["data"]
    mock_service.assert_called_once_with(["AAPL", "MSFT", "GOOGL"], days=365)


@pytest.mark.asyncio
async def test_portfolio_risk_parity_missing_symbols(client):
    """Omitting the required symbols field -> 422."""
    response = await client.post("/api/v1/portfolio/risk-parity", json={"days": 365})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_risk_parity_empty_symbols(client):
    """An empty symbols list -> 422."""
    response = await client.post(
        "/api/v1/portfolio/risk-parity", json={"symbols": [], "days": 365}
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("Failed")
    response = await client.post(
        "/api/v1/portfolio/risk-parity",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_value_error_returns_400(mock_service, client):
    """ValueError from the service (bad input data) -> 400."""
    mock_service.side_effect = ValueError("No price data available")
    response = await client.post(
        "/api/v1/portfolio/risk-parity",
        json={"symbols": ["AAPL", "MSFT"], "days": 365},
    )
    assert response.status_code == 400


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_default_days(mock_service, client):
    """When days is omitted the route applies its default of 365."""
    mock_service.return_value = {
        "weights": {"AAPL": 1.0},
        "risk_contributions": {"AAPL": 1.0},
        "method": "risk_parity",
    }
    response = await client.post(
        "/api/v1/portfolio/risk-parity", json={"symbols": ["AAPL"]}
    )
    assert response.status_code == 200
    mock_service.assert_called_once_with(["AAPL"], days=365)
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/v1/portfolio/cluster
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CLUSTER_RESULT = {
|
||||
"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
|
||||
"coordinates": [
|
||||
{"symbol": "AAPL", "x": 12.5, "y": -3.2, "cluster": 0},
|
||||
{"symbol": "MSFT", "x": 11.8, "y": -2.9, "cluster": 0},
|
||||
{"symbol": "GOOGL", "x": 10.1, "y": -1.5, "cluster": 0},
|
||||
{"symbol": "AMZN", "x": 9.5, "y": -0.8, "cluster": 0},
|
||||
{"symbol": "JPM", "x": -5.1, "y": 8.3, "cluster": 1},
|
||||
{"symbol": "BAC", "x": -4.9, "y": 7.9, "cluster": 1},
|
||||
],
|
||||
"clusters": {"0": ["AAPL", "MSFT", "GOOGL", "AMZN"], "1": ["JPM", "BAC"]},
|
||||
"method": "t-SNE + KMeans",
|
||||
"n_clusters": 2,
|
||||
"days": 180,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_happy_path(mock_service, client):
    """POST /cluster returns 200 with valid cluster result."""
    mock_service.return_value = _CLUSTER_RESULT
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], "days": 180},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"]["method"] == "t-SNE + KMeans"
    assert "coordinates" in payload["data"]
    assert "clusters" in payload["data"]
    mock_service.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, n_clusters=None
    )


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_with_custom_n_clusters(mock_service, client):
    """n_clusters is forwarded to service when provided."""
    mock_service.return_value = _CLUSTER_RESULT
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={
            "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
            "days": 180,
            "n_clusters": 3,
        },
    )
    assert response.status_code == 200
    mock_service.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, n_clusters=3
    )


@pytest.mark.asyncio
async def test_portfolio_cluster_too_few_symbols_returns_422(client):
    """Fewer than 3 symbols triggers Pydantic validation error (422)."""
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT"], "days": 180},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_cluster_missing_symbols_returns_422(client):
    """Missing symbols field returns 422."""
    response = await client.post("/api/v1/portfolio/cluster", json={"days": 180})
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_cluster_too_many_symbols_returns_422(client):
    """More than 50 symbols returns 422."""
    oversized = [f"SYM{i}" for i in range(51)]
    response = await client.post(
        "/api/v1/portfolio/cluster", json={"symbols": oversized, "days": 180}
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_cluster_days_below_minimum_returns_422(client):
    """days < 30 returns 422."""
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 10},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_cluster_n_clusters_below_minimum_returns_422(client):
    """n_clusters < 2 returns 422."""
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180, "n_clusters": 1},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_value_error_returns_400(mock_service, client):
    """ValueError from service returns 400."""
    mock_service.side_effect = ValueError("at least 3 symbols required")
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180},
    )
    assert response.status_code == 400


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_upstream_error_returns_502(mock_service, client):
    """Unexpected exception from service returns 502."""
    mock_service.side_effect = RuntimeError("upstream failure")
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_default_days(mock_service, client):
    """Default days=180 is used when not provided."""
    mock_service.return_value = _CLUSTER_RESULT
    response = await client.post(
        "/api/v1/portfolio/cluster",
        json={"symbols": ["AAPL", "MSFT", "GOOGL"]},
    )
    assert response.status_code == 200
    mock_service.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL"], days=180, n_clusters=None
    )
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/v1/portfolio/similar
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SIMILAR_RESULT = {
|
||||
"symbol": "AAPL",
|
||||
"most_similar": [
|
||||
{"symbol": "MSFT", "correlation": 0.85},
|
||||
{"symbol": "GOOGL", "correlation": 0.78},
|
||||
],
|
||||
"least_similar": [
|
||||
{"symbol": "JPM", "correlation": 0.32},
|
||||
{"symbol": "BAC", "correlation": 0.28},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_happy_path(mock_service, client):
    """POST /similar returns 200 with most_similar and least_similar."""
    mock_service.return_value = _SIMILAR_RESULT
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={
            "symbol": "AAPL",
            "universe": ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
            "days": 180,
            "top_n": 2,
        },
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"]["symbol"] == "AAPL"
    assert "most_similar" in payload["data"]
    assert "least_similar" in payload["data"]
    mock_service.assert_called_once_with(
        "AAPL",
        ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        days=180,
        top_n=2,
    )


@pytest.mark.asyncio
async def test_portfolio_similar_missing_symbol_returns_422(client):
    """Missing symbol field returns 422."""
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"universe": ["MSFT", "GOOGL"], "days": 180},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_similar_missing_universe_returns_422(client):
    """Missing universe field returns 422."""
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "days": 180},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_similar_universe_too_small_returns_422(client):
    """universe with fewer than 2 entries returns 422."""
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "universe": ["MSFT"], "days": 180},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_portfolio_similar_top_n_below_minimum_returns_422(client):
    """top_n < 1 returns 422."""
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180, "top_n": 0},
    )
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_value_error_returns_400(mock_service, client):
    """ValueError from service returns 400."""
    mock_service.side_effect = ValueError("AAPL not found in price data")
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180},
    )
    assert response.status_code == 400


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_upstream_error_returns_502(mock_service, client):
    """Unexpected exception from service returns 502."""
    mock_service.side_effect = RuntimeError("upstream failure")
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180},
    )
    assert response.status_code == 502


@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_default_top_n(mock_service, client):
    """Default top_n=5 is passed to service when not specified."""
    mock_service.return_value = _SIMILAR_RESULT
    response = await client.post(
        "/api/v1/portfolio/similar",
        json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL", "AMZN"]},
    )
    assert response.status_code == 200
    mock_service.assert_called_once_with("AAPL", ["MSFT", "GOOGL", "AMZN"], days=180, top_n=5)
228
tests/test_routes_regulators.py
Normal file
228
tests/test_routes_regulators.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Tests for regulatory data routes (CFTC, SEC)."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
# --- COT Report ---


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_happy_path(mock_service, client):
    """COT positioning rows pass through; symbol is forwarded to the service."""
    mock_service.return_value = [
        {
            "date": "2026-03-11",
            "symbol": "ES",
            "commercial_long": 250000,
            "commercial_short": 300000,
            "noncommercial_long": 180000,
            "noncommercial_short": 120000,
        }
    ]
    response = await client.get("/api/v1/regulators/cot?symbol=ES")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert payload["data"][0]["commercial_long"] == 250000
    mock_service.assert_called_once_with("ES")


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_empty(mock_service, client):
    """An unknown symbol yields 200 with an empty data list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/regulators/cot?symbol=UNKNOWN")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("CFTC unavailable")
    response = await client.get("/api/v1/regulators/cot?symbol=ES")
    assert response.status_code == 502


@pytest.mark.asyncio
async def test_cot_report_missing_symbol(client):
    """Omitting the required symbol parameter -> 422."""
    response = await client.get("/api/v1/regulators/cot")
    assert response.status_code == 422


@pytest.mark.asyncio
async def test_cot_report_invalid_symbol(client):
    """A symbol containing disallowed characters is rejected with 400."""
    response = await client.get("/api/v1/regulators/cot?symbol=DROP;TABLE")
    assert response.status_code == 400
||||
# --- COT Search ---


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_happy_path(mock_service, client):
    """Contract search results pass through unchanged."""
    mock_service.return_value = [
        {"code": "13874P", "name": "E-MINI S&P 500"},
        {"code": "13874A", "name": "S&P 500 CONSOLIDATED"},
    ]
    response = await client.get("/api/v1/regulators/cot/search?query=S%26P+500")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert len(payload["data"]) == 2
    assert payload["data"][0]["name"] == "E-MINI S&P 500"


@pytest.mark.asyncio
async def test_cot_search_missing_query(client):
    """Omitting the required query parameter -> 422."""
    response = await client.get("/api/v1/regulators/cot/search")
    assert response.status_code == 422


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_empty(mock_service, client):
    """No matches -> 200 with an empty data list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/regulators/cot/search?query=nonexistentfutures")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("CFTC search failed")
    response = await client.get("/api/v1/regulators/cot/search?query=gold")
    assert response.status_code == 502
||||
# --- SEC Litigation ---


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_happy_path(mock_service, client):
    """Litigation release entries pass through unchanged."""
    mock_service.return_value = [
        {
            "date": "2026-03-15",
            "title": "SEC Charges Former CEO with Fraud",
            "url": "https://sec.gov/litigation/lr/2026/lr-99999.htm",
            "summary": "The Commission charged...",
        }
    ]
    response = await client.get("/api/v1/regulators/sec/litigation")
    assert response.status_code == 200
    payload = response.json()
    assert payload["success"] is True
    assert len(payload["data"]) == 1
    assert "CEO" in payload["data"][0]["title"]


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_empty(mock_service, client):
    """No releases -> 200 with an empty data list."""
    mock_service.return_value = []
    response = await client.get("/api/v1/regulators/sec/litigation")
    assert response.status_code == 200
    assert response.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_service_error_returns_502(mock_service, client):
    """Unexpected service failure -> 502."""
    mock_service.side_effect = RuntimeError("SEC RSS feed unavailable")
    response = await client.get("/api/v1/regulators/sec/litigation")
    assert response.status_code == 502
||||
# --- SEC Institution Search ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_happy_path(search_mock, client):
    """Matching institutions are returned and the query string is forwarded verbatim."""
    search_mock.return_value = [
        {"name": "Vanguard Group Inc", "cik": "0000102909"},
        {"name": "BlackRock Inc", "cik": "0001364742"},
    ]
    resp = await client.get("/api/v1/regulators/sec/institutions?query=vanguard")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    assert body["data"][0]["name"] == "Vanguard Group Inc"
    search_mock.assert_called_once_with("vanguard")


@pytest.mark.asyncio
async def test_sec_institutions_missing_query(client):
    """The query parameter is mandatory; omitting it is a 422 validation error."""
    resp = await client.get("/api/v1/regulators/sec/institutions")
    assert resp.status_code == 422


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_empty(search_mock, client):
    """No matches still yields a 200 with an empty list."""
    search_mock.return_value = []
    resp = await client.get("/api/v1/regulators/sec/institutions?query=notarealfirm")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_service_error_returns_502(search_mock, client):
    """SEC backend failures map to 502."""
    search_mock.side_effect = RuntimeError("SEC API failed")
    resp = await client.get("/api/v1/regulators/sec/institutions?query=blackrock")
    assert resp.status_code == 502
||||
# --- SEC CIK Map ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_happy_path(cik_mock, client):
    """The symbol is forwarded to the service and its CIK mapping is returned."""
    cik_mock.return_value = [{"symbol": "AAPL", "cik": "0000320193"}]
    resp = await client.get("/api/v1/regulators/sec/cik-map/AAPL")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert body["data"][0]["cik"] == "0000320193"
    cik_mock.assert_called_once_with("AAPL")


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_not_found(cik_mock, client):
    """An unknown ticker yields 200 with an empty mapping, not a 404."""
    cik_mock.return_value = []
    resp = await client.get("/api/v1/regulators/sec/cik-map/XXXX")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_service_error_returns_502(cik_mock, client):
    """A failed SEC lookup maps to 502."""
    cik_mock.side_effect = RuntimeError("SEC lookup failed")
    resp = await client.get("/api/v1/regulators/sec/cik-map/AAPL")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_sec_cik_map_invalid_symbol(client):
    """Symbols with illegal characters are rejected with 400 before any lookup."""
    resp = await client.get("/api/v1/regulators/sec/cik-map/INVALID!!!")
    assert resp.status_code == 400
@@ -14,13 +14,17 @@ async def client():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_stock_sentiment(mock_av, mock_sentiment, mock_reddit, mock_upgrades, mock_recs, client):
    """Composite sentiment endpoint: all five sources are mocked and the response
    exposes the refactored composite_score/composite_label/details/source_scores shape.

    Decorators apply bottom-up, so the bottom @patch (alphavantage) binds to the
    first parameter (mock_av).
    """
    # Finnhub summary — one recent headline so details.finnhub_news is non-empty.
    mock_sentiment.return_value = {
        "symbol": "AAPL",
        "news_sentiment": {"bullish_percent": 0.7, "bearish_percent": 0.3},
        "recent_news": [{"headline": "Apple strong"}],
        "analyst_recommendations": [],
        "recent_upgrades_downgrades": [],
    }
    # Alpha Vantage news sentiment feeds details.news_sentiment.
    mock_av.return_value = {
        "configured": True,
        "symbol": "AAPL",
        "overall_sentiment": {"avg_score": 0.4, "label": "Bullish"},
        "articles": [],
    }
    mock_reddit.return_value = {"found": False, "symbol": "AAPL"}
    mock_upgrades.return_value = []
    mock_recs.return_value = []

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    data = resp.json()["data"]
    assert data["symbol"] == "AAPL"
    # New composite response shape
    assert "composite_score" in data
    assert "composite_label" in data
    assert "source_scores" in data
    assert "details" in data
    # AV news data accessible via details
    assert data["details"]["news_sentiment"]["overall_sentiment"]["label"] == "Bullish"
    # Finnhub news accessible via details
    assert len(data["details"]["finnhub_news"]) == 1
||||
@pytest.mark.asyncio
|
||||
|
||||
358
tests/test_routes_sentiment_social.py
Normal file
358
tests/test_routes_sentiment_social.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Tests for new sentiment routes: social sentiment, reddit, composite sentiment, trending."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient wired to the FastAPI app via ASGI (no real server)."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client
||||
# --- Social Sentiment (Finnhub) ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_happy_path(social_mock, client):
    """Configured Finnhub social data is returned unchanged in the envelope."""
    social_mock.return_value = {
        "configured": True,
        "symbol": "AAPL",
        "reddit_summary": {"total_mentions": 150, "avg_score": 0.55, "data_points": 5},
        "twitter_summary": {"total_mentions": 300, "avg_score": 0.40, "data_points": 8},
        "reddit": [{"mention": 30, "score": 0.5}],
        "twitter": [{"mention": 40, "score": 0.4}],
    }
    resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert body["data"]["symbol"] == "AAPL"
    assert body["data"]["reddit_summary"]["total_mentions"] == 150


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_not_configured(social_mock, client):
    """Missing API key is reported as configured=False, still a 200."""
    social_mock.return_value = {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY"}
    resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
    assert resp.status_code == 200
    assert resp.json()["data"]["configured"] is False


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_premium_required(social_mock, client):
    """A premium-gated Finnhub response passes the premium_required flag through."""
    social_mock.return_value = {
        "configured": True,
        "symbol": "AAPL",
        "premium_required": True,
        "reddit": [],
        "twitter": [],
    }
    resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
    assert resp.status_code == 200
    assert resp.json()["data"]["premium_required"] is True


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_service_error_returns_502(social_mock, client):
    """Finnhub failures surface as 502."""
    social_mock.side_effect = RuntimeError("Finnhub error")
    resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_stock_social_sentiment_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400 before hitting the service."""
    resp = await client.get("/api/v1/stock/INVALID!!!/social-sentiment")
    assert resp.status_code == 400
|
||||
|
||||
# --- Reddit Sentiment ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_found(reddit_mock, client):
    """A trending ticker's rank, mentions and deltas are returned as-is."""
    reddit_mock.return_value = {
        "symbol": "AAPL",
        "found": True,
        "rank": 3,
        "mentions_24h": 150,
        "mentions_24h_ago": 100,
        "mentions_change_pct": 50.0,
        "upvotes": 500,
        "rank_24h_ago": 5,
    }
    resp = await client.get("/api/v1/stock/AAPL/reddit-sentiment")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    payload = body["data"]
    assert payload["found"] is True
    assert payload["rank"] == 3
    assert payload["mentions_24h"] == 150
    assert payload["mentions_change_pct"] == 50.0


@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_not_found(reddit_mock, client):
    """A ticker outside the trending list reports found=False with a message."""
    reddit_mock.return_value = {
        "symbol": "OBSCURE",
        "found": False,
        "message": "OBSCURE not in Reddit top trending (not enough mentions)",
    }
    resp = await client.get("/api/v1/stock/OBSCURE/reddit-sentiment")
    assert resp.status_code == 200
    payload = resp.json()["data"]
    assert payload["found"] is False
    assert "not in Reddit" in payload["message"]


@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_service_error_returns_502(reddit_mock, client):
    """ApeWisdom outages map to 502."""
    reddit_mock.side_effect = RuntimeError("ApeWisdom down")
    resp = await client.get("/api/v1/stock/AAPL/reddit-sentiment")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_stock_reddit_sentiment_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400."""
    resp = await client.get("/api/v1/stock/BAD!!!/reddit-sentiment")
    assert resp.status_code == 400
|
||||
|
||||
# --- Reddit Trending ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_happy_path(trending_mock, client):
    """The trending list is returned in rank order, untouched."""
    trending_mock.return_value = [
        {"rank": 1, "symbol": "TSLA", "name": "Tesla", "mentions_24h": 500, "upvotes": 1200, "rank_24h_ago": 2, "mentions_24h_ago": 400},
        {"rank": 2, "symbol": "AAPL", "name": "Apple", "mentions_24h": 300, "upvotes": 800, "rank_24h_ago": 1, "mentions_24h_ago": 350},
        {"rank": 3, "symbol": "GME", "name": "GameStop", "mentions_24h": 200, "upvotes": 600, "rank_24h_ago": 3, "mentions_24h_ago": 180},
    ]
    resp = await client.get("/api/v1/discover/reddit-trending")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    rows = body["data"]
    assert len(rows) == 3
    assert rows[0]["symbol"] == "TSLA"
    assert rows[0]["rank"] == 1
    assert rows[1]["symbol"] == "AAPL"


@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_empty(trending_mock, client):
    """An empty trending list is a valid 200 response."""
    trending_mock.return_value = []
    resp = await client.get("/api/v1/discover/reddit-trending")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_service_error_returns_502(trending_mock, client):
    """ApeWisdom unavailability maps to 502."""
    trending_mock.side_effect = RuntimeError("ApeWisdom unavailable")
    resp = await client.get("/api/v1/discover/reddit-trending")
    assert resp.status_code == 502
|
||||
|
||||
# --- Composite /stock/{symbol}/sentiment (aggregation logic) ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_all_sources(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """With every source supplying data, all four source scores and the details
    payload appear in the aggregated response."""
    mock_av.return_value = {
        "configured": True,
        "symbol": "AAPL",
        "overall_sentiment": {"avg_score": 0.2, "label": "Bullish"},
        "articles": [],
    }
    mock_fh.return_value = {
        "symbol": "AAPL",
        "news_sentiment": {},
        "recent_news": [{"headline": "Apple rises", "source": "Reuters"}],
        "analyst_recommendations": [],
        "recent_upgrades_downgrades": [],
    }
    mock_reddit.return_value = {
        "symbol": "AAPL",
        "found": True,
        "rank": 2,
        "mentions_24h": 200,
        "mentions_24h_ago": 150,
        "mentions_change_pct": 33.3,
        "upvotes": 800,
    }
    mock_upgrades.return_value = [
        {"action": "up", "company": "Goldman"},
        {"action": "down", "company": "Morgan Stanley"},
        {"action": "init", "company": "JPMorgan"},
    ]
    mock_recs.return_value = [
        {"strongBuy": 10, "buy": 15, "hold": 5, "sell": 2, "strongSell": 1}
    ]

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    result = body["data"]
    assert result["symbol"] == "AAPL"
    assert result["composite_score"] is not None
    assert result["composite_label"] in ("Strong Bullish", "Bullish", "Neutral", "Bearish", "Strong Bearish")
    # Every contributing source should have produced a score.
    for source in ("news", "analysts", "upgrades", "reddit"):
        assert source in result["source_scores"]
    assert "details" in result


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_no_data_returns_unknown(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """When no source yields usable data, score is None and the label is Unknown."""
    mock_av.return_value = {}
    mock_fh.return_value = {}
    mock_reddit.return_value = {"found": False}
    mock_upgrades.return_value = []
    mock_recs.return_value = []

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    payload = resp.json()["data"]
    assert payload["composite_score"] is None
    assert payload["composite_label"] == "Unknown"
    assert payload["source_scores"] == {}
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_strong_bullish_label(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """Uniformly bullish inputs push the composite score to >= 0.5 → Strong Bullish."""
    mock_av.return_value = {"overall_sentiment": {"avg_score": 0.35}}
    mock_fh.return_value = {}
    mock_reddit.return_value = {"found": True, "mentions_24h": 500, "mentions_change_pct": 100.0}
    mock_upgrades.return_value = [{"action": "up"}, {"action": "up"}, {"action": "up"}]
    mock_recs.return_value = [{"strongBuy": 20, "buy": 10, "hold": 1, "sell": 0, "strongSell": 0}]

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    payload = resp.json()["data"]
    assert payload["composite_score"] >= 0.5
    assert payload["composite_label"] == "Strong Bullish"


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_bearish_label(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """Uniformly bearish inputs land in one of the bearish label buckets."""
    mock_av.return_value = {"overall_sentiment": {"avg_score": -0.3}}
    mock_fh.return_value = {}
    mock_reddit.return_value = {"found": True, "mentions_24h": 200, "mentions_change_pct": -70.0}
    mock_upgrades.return_value = [{"action": "down"}, {"action": "down"}, {"action": "down"}]
    mock_recs.return_value = [{"strongBuy": 0, "buy": 2, "hold": 5, "sell": 10, "strongSell": 5}]

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    assert resp.json()["data"]["composite_label"] in ("Bearish", "Strong Bearish")
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_one_source_failing_is_graceful(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """One source raising is tolerated (gather with return_exceptions=True) —
    the endpoint still answers 200 using the remaining sources."""
    mock_av.side_effect = RuntimeError("AV down")
    mock_fh.return_value = {}
    mock_reddit.return_value = {"found": False}
    mock_upgrades.return_value = []
    mock_recs.return_value = [{"strongBuy": 5, "buy": 5, "hold": 3, "sell": 1, "strongSell": 0}]

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    assert resp.json()["data"]["symbol"] == "AAPL"


@pytest.mark.asyncio
async def test_composite_sentiment_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400."""
    resp = await client.get("/api/v1/stock/INVALID!!!/sentiment")
    assert resp.status_code == 400


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_reddit_low_mentions_excluded(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """Below the 10-mention threshold, Reddit contributes no source score."""
    mock_av.return_value = {}
    mock_fh.return_value = {}
    mock_reddit.return_value = {"found": True, "mentions_24h": 5, "mentions_change_pct": 50.0}
    mock_upgrades.return_value = []
    mock_recs.return_value = []

    resp = await client.get("/api/v1/stock/AAPL/sentiment")
    assert resp.status_code == 200
    assert "reddit" not in resp.json()["data"]["source_scores"]


@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_details_structure(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
    """The details payload always exposes the five per-source sub-sections."""
    mock_av.return_value = {"overall_sentiment": {"avg_score": 0.1}}
    mock_fh.return_value = {"recent_news": [{"headline": "Test news"}]}
    mock_reddit.return_value = {"found": False}
    mock_upgrades.return_value = [{"action": "up"}, {"action": "up"}]
    mock_recs.return_value = []

    resp = await client.get("/api/v1/stock/MSFT/sentiment")
    assert resp.status_code == 200
    details = resp.json()["data"]["details"]
    for section in ("news_sentiment", "analyst_recommendations", "recent_upgrades", "reddit", "finnhub_news"):
        assert section in details
||||
176
tests/test_routes_shorts.py
Normal file
176
tests/test_routes_shorts.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Tests for shorts and dark pool routes."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient bound to the app through ASGITransport."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client
|
||||
|
||||
# --- Short Volume ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_happy_path(volume_mock, client):
    """Daily short-volume rows are returned inside the success envelope."""
    volume_mock.return_value = [
        {"date": "2026-03-18", "short_volume": 5000000, "short_exempt_volume": 10000, "total_volume": 20000000, "short_volume_percent": 0.25}
    ]
    resp = await client.get("/api/v1/stock/AAPL/shorts/volume")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert len(body["data"]) == 1
    assert body["data"][0]["short_volume"] == 5000000


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_empty(volume_mock, client):
    """No rows is still a 200 with an empty list."""
    volume_mock.return_value = []
    resp = await client.get("/api/v1/stock/GME/shorts/volume")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert body["data"] == []


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_service_error_returns_502(volume_mock, client):
    """Upstream data-provider failures map to 502."""
    volume_mock.side_effect = RuntimeError("stockgrid unavailable")
    resp = await client.get("/api/v1/stock/AAPL/shorts/volume")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_short_volume_invalid_symbol(client):
    """Symbols containing illegal characters are rejected with 400."""
    resp = await client.get("/api/v1/stock/AAPL;DROP/shorts/volume")
    assert resp.status_code == 400
|
||||
|
||||
# --- Fails To Deliver ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_happy_path(ftd_mock, client):
    """Fails-to-deliver records come back untouched."""
    ftd_mock.return_value = [
        {"date": "2026-03-01", "cusip": "037833100", "failure_quantity": 50000, "symbol": "AAPL", "price": 175.0}
    ]
    resp = await client.get("/api/v1/stock/AAPL/shorts/ftd")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    first = body["data"][0]
    assert first["symbol"] == "AAPL"
    assert first["failure_quantity"] == 50000


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_empty(ftd_mock, client):
    """No FTD records is a valid, empty 200 response."""
    ftd_mock.return_value = []
    resp = await client.get("/api/v1/stock/TSLA/shorts/ftd")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_service_error_returns_502(ftd_mock, client):
    """A failed SEC connection maps to 502."""
    ftd_mock.side_effect = RuntimeError("SEC connection failed")
    resp = await client.get("/api/v1/stock/AAPL/shorts/ftd")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_ftd_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400."""
    resp = await client.get("/api/v1/stock/BAD!!!/shorts/ftd")
    assert resp.status_code == 400
|
||||
|
||||
# --- Short Interest ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_happy_path(interest_mock, client):
    """Settlement-date short-interest rows pass through with days_to_cover intact."""
    interest_mock.return_value = [
        {"settlement_date": "2026-02-28", "symbol": "GME", "short_interest": 20000000, "days_to_cover": 3.5}
    ]
    resp = await client.get("/api/v1/stock/GME/shorts/interest")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    first = body["data"][0]
    assert first["short_interest"] == 20000000
    assert first["days_to_cover"] == 3.5


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_empty(interest_mock, client):
    """No data is a valid, empty 200 response."""
    interest_mock.return_value = []
    resp = await client.get("/api/v1/stock/NVDA/shorts/interest")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_service_error_returns_502(interest_mock, client):
    """FINRA unavailability maps to 502."""
    interest_mock.side_effect = RuntimeError("FINRA unavailable")
    resp = await client.get("/api/v1/stock/AAPL/shorts/interest")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_short_interest_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400."""
    resp = await client.get("/api/v1/stock/INVALID!!!/shorts/interest")
    assert resp.status_code == 400
|
||||
|
||||
# --- Dark Pool OTC ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_happy_path(otc_mock, client):
    """Dark-pool OTC share data is returned inside the success envelope."""
    otc_mock.return_value = [
        {"date": "2026-03-18", "symbol": "AAPL", "shares": 3000000, "percentage": 12.5}
    ]
    resp = await client.get("/api/v1/darkpool/AAPL/otc")
    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    assert body["data"][0]["percentage"] == 12.5


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_empty(otc_mock, client):
    """No OTC rows is a valid, empty 200 response."""
    otc_mock.return_value = []
    resp = await client.get("/api/v1/darkpool/TSLA/otc")
    assert resp.status_code == 200
    assert resp.json()["data"] == []


@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_service_error_returns_502(otc_mock, client):
    """A FINRA timeout maps to 502."""
    otc_mock.side_effect = RuntimeError("FINRA connection timeout")
    resp = await client.get("/api/v1/darkpool/AAPL/otc")
    assert resp.status_code == 502


@pytest.mark.asyncio
async def test_darkpool_otc_invalid_symbol(client):
    """Illegal ticker characters are rejected with 400."""
    resp = await client.get("/api/v1/darkpool/BAD!!!/otc")
    assert resp.status_code == 400
|
||||
189
tests/test_routes_surveys.py
Normal file
189
tests/test_routes_surveys.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for economy survey routes."""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.fixture
async def client():
    """Yield an httpx AsyncClient bound to the app through ASGITransport."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as http_client:
        yield http_client
|
||||
|
||||
# --- Michigan Consumer Sentiment ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
|
||||
async def test_survey_michigan_happy_path(mock_fn, client):
|
||||
mock_fn.return_value = [
|
||||
{"date": "2026-03-01", "consumer_sentiment": 76.5, "inflation_expectation_1yr": 3.1}
|
||||
]
|
||||
resp = await client.get("/api/v1/economy/surveys/michigan")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["success"] is True
|
||||
assert data["data"][0]["consumer_sentiment"] == 76.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
async def test_survey_michigan_empty(mock_fn, client):
    """An empty result set from the service still yields 200 with an empty list."""
    mock_fn.return_value = []
    response = await client.get("/api/v1/economy/surveys/michigan")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
async def test_survey_michigan_service_error_returns_502(mock_fn, client):
    """A service-layer exception surfaces to the caller as HTTP 502."""
    mock_fn.side_effect = RuntimeError("FRED unavailable")
    response = await client.get("/api/v1/economy/surveys/michigan")
    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- SLOOS ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_happy_path(mock_fn, client):
    """SLOOS endpoint wraps the service rows in a success envelope."""
    rows = [
        {"date": "2026-01-01", "c_i_tightening_pct": 25.0, "consumer_tightening_pct": 10.0}
    ]
    mock_fn.return_value = rows
    response = await client.get("/api/v1/economy/surveys/sloos")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["c_i_tightening_pct"] == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_empty(mock_fn, client):
    """An empty result set from the service still yields 200 with an empty list."""
    mock_fn.return_value = []
    response = await client.get("/api/v1/economy/surveys/sloos")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_service_error_returns_502(mock_fn, client):
    """A service-layer exception surfaces to the caller as HTTP 502."""
    mock_fn.side_effect = RuntimeError("FRED down")
    response = await client.get("/api/v1/economy/surveys/sloos")
    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- Nonfarm Payrolls ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_happy_path(mock_fn, client):
    """Nonfarm-payrolls endpoint wraps the service rows in a success envelope."""
    rows = [
        {"date": "2026-03-07", "value": 275000, "industry": "total_nonfarm"}
    ]
    mock_fn.return_value = rows
    response = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["value"] == 275000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_empty(mock_fn, client):
    """An empty result set from the service still yields 200 with an empty list."""
    mock_fn.return_value = []
    response = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_service_error_returns_502(mock_fn, client):
    """A service-layer exception surfaces to the caller as HTTP 502."""
    mock_fn.side_effect = RuntimeError("BLS unavailable")
    response = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- Empire State Manufacturing ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_happy_path(mock_fn, client):
    """Empire State endpoint wraps the service rows in a success envelope."""
    rows = [
        {"date": "2026-03-01", "general_business_conditions": -7.58}
    ]
    mock_fn.return_value = rows
    response = await client.get("/api/v1/economy/surveys/empire-state")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert body["data"][0]["general_business_conditions"] == -7.58
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_empty(mock_fn, client):
    """An empty result set from the service still yields 200 with an empty list."""
    mock_fn.return_value = []
    response = await client.get("/api/v1/economy/surveys/empire-state")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_service_error_returns_502(mock_fn, client):
    """A service-layer exception surfaces to the caller as HTTP 502."""
    mock_fn.side_effect = RuntimeError("FRED connection error")
    response = await client.get("/api/v1/economy/surveys/empire-state")
    assert response.status_code == 502
|
||||
|
||||
|
||||
# --- BLS Search ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_happy_path(mock_fn, client):
    """BLS search returns matching series and forwards the decoded query string."""
    series = [
        {"series_id": "CES0000000001", "series_title": "All employees, thousands, total nonfarm"},
        {"series_id": "CES1000000001", "series_title": "All employees, thousands, mining and logging"},
    ]
    mock_fn.return_value = series
    response = await client.get("/api/v1/economy/surveys/bls-search?query=nonfarm+payrolls")
    assert response.status_code == 200
    body = response.json()
    assert body["success"] is True
    assert len(body["data"]) == 2
    assert body["data"][0]["series_id"] == "CES0000000001"
    # '+' in the URL decodes to a space before the service layer is called.
    mock_fn.assert_called_once_with(query="nonfarm payrolls")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_survey_bls_search_missing_query(client):
    """Omitting the required `query` parameter triggers FastAPI validation (422)."""
    response = await client.get("/api/v1/economy/surveys/bls-search")
    assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_empty(mock_fn, client):
    """A query with no matches still yields 200 with an empty list."""
    mock_fn.return_value = []
    response = await client.get("/api/v1/economy/surveys/bls-search?query=nothingtofind")
    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_service_error_returns_502(mock_fn, client):
    """A service-layer exception surfaces to the caller as HTTP 502."""
    mock_fn.side_effect = RuntimeError("BLS API down")
    response = await client.get("/api/v1/economy/surveys/bls-search?query=wages")
    assert response.status_code == 502
|
||||
Reference in New Issue
Block a user