Compare commits

...

11 Commits

Author SHA1 Message Date
Yaojia Wang
ec005c91a9 chore: fix all ruff lint warnings
All checks were successful
continuous-integration/drone/push Build is passing
- Remove unused datetime imports from openbb_service, market_service,
  quantitative_service (now using obb_utils.days_ago)
- Remove unused variable 'maintains' in routes_sentiment
- Remove unused imports in test files
- Fix forward reference annotation in test helper
2026-03-19 23:19:08 +01:00
Yaojia Wang
0f7341b158 refactor: address architect review findings (6 items)
R1: Extend @safe to catch ValueError->400, simplify routes_backtest
    (eliminated 4 copies of duplicated try/except)
R2: Consolidate PROVIDER constant into obb_utils.py (single source)
R3: Add days_ago() helper to obb_utils.py, replace 8+ duplications
R4: Extract Reddit/ApeWisdom into reddit_service.py from finnhub_service
R5: Fix missing top-level import asyncio in finnhub_service
R6: (deferred - sentiment logic extraction is a larger change)

All 561 tests passing.
2026-03-19 23:15:00 +01:00
Yaojia Wang
37c46e76ae feat: add DeFi data via DefiLlama API (TDD)
7 new endpoints under /api/v1/defi/ (all free, no API key):
- GET /defi/tvl/protocols - top DeFi protocols by TVL
- GET /defi/tvl/chains - chain TVL rankings
- GET /defi/tvl/{protocol} - single protocol TVL
- GET /defi/yields - top yield pools (filter by chain/project)
- GET /defi/stablecoins - stablecoin market data
- GET /defi/volumes/dexs - DEX volume overview
- GET /defi/fees - protocol fees/revenue overview

Data source: DefiLlama API (free, no key needed)
58 new tests (33 service + 25 route). All 561 tests passing.
2026-03-19 23:03:01 +01:00
Yaojia Wang
4915f1bae4 feat: add t-SNE stock clustering and similarity search (TDD)
2 new endpoints:
- POST /portfolio/cluster - t-SNE + KMeans clustering by return
  similarity. Maps stocks to 2D coordinates with cluster labels.
- POST /portfolio/similar - find most/least similar stocks by
  return correlation against a target symbol.

Implementation:
- sklearn TSNE (method=exact) + KMeans with auto n_clusters
- Jitter handling for identical returns edge case
- 33 new tests (17 service unit + 16 route integration)
- All 503 tests passing
2026-03-19 22:53:27 +01:00
Yaojia Wang
9ee3ec9b4e feat: add A-share and HK stock data via AKShare (TDD)
5 new endpoints under /api/v1/cn/:
- GET /cn/a-share/{symbol}/quote - A-share real-time quote
- GET /cn/a-share/{symbol}/historical - A-share historical OHLCV
- GET /cn/a-share/search?query= - search A-shares by name
- GET /cn/hk/{symbol}/quote - HK stock real-time quote
- GET /cn/hk/{symbol}/historical - HK stock historical OHLCV

Features:
- Chinese column names auto-mapped to English
- Symbol validation: A-share ^[036]\d{5}$, HK ^\d{5}$
- qfq (forward-adjusted) prices by default
- 79 new tests (51 service unit + 28 route integration)
- All 470 tests passing
2026-03-19 22:44:30 +01:00
Yaojia Wang
5c7a0ee4c0 feat: add backtesting engine with 4 strategies (TDD)
Strategies:
- POST /backtest/sma-crossover - SMA crossover (short/long window)
- POST /backtest/rsi - RSI oversold/overbought signals
- POST /backtest/buy-and-hold - passive benchmark
- POST /backtest/momentum - multi-symbol momentum rotation

Returns: total_return, annualized_return, sharpe_ratio, max_drawdown,
win_rate, total_trades, equity_curve (last 20 points)

Implementation: pure pandas/numpy, no external backtesting libs.
Shared _compute_metrics helper across all strategies.
79 new tests (46 service unit + 33 route integration).
All 391 tests passing.
2026-03-19 22:35:00 +01:00
Yaojia Wang
42ba359c48 feat: add portfolio optimization and congress tracking (TDD)
Portfolio optimization (3 endpoints):
- POST /portfolio/optimize - HRP optimal weights via scipy clustering
- POST /portfolio/correlation - pairwise correlation matrix
- POST /portfolio/risk-parity - inverse-volatility risk parity weights

Congress tracking (2 endpoints):
- GET /regulators/congress/trades - congress member stock trades
- GET /regulators/congress/bills?query= - search congress bills

Implementation:
- portfolio_service.py: HRP with scipy fallback to inverse-vol
- congress_service.py: multi-provider fallback pattern
- 51 new tests (14 portfolio unit, 20 portfolio route, 12 congress
  unit, 7 congress route)
- All 312 tests passing
2026-03-19 22:27:03 +01:00
Yaojia Wang
27b131492f test: add 159 tests for all new modules
New test files (171 tests):
- test_routes_shorts.py (16) - short volume, FTD, interest, darkpool
- test_routes_fixed_income.py (34) - treasury, yield curve, SOFR, etc.
- test_routes_economy.py (44) - CPI, GDP, FRED search, Fed holdings
- test_routes_surveys.py (17) - Michigan, SLOOS, NFP, Empire State
- test_routes_regulators.py (20) - COT, SEC litigation, institutions
- test_finnhub_service_social.py (20) - social/reddit sentiment unit tests
- test_routes_sentiment_social.py (20) - social endpoints + composite

Updated:
- test_routes_sentiment.py - match new composite sentiment response shape

Total: 261 tests passing (was 102)
2026-03-19 22:12:27 +01:00
Yaojia Wang
ea72497587 docs: update README for sentiment aggregation and social endpoints
- Update endpoint count to 102
- Add social-sentiment, reddit-sentiment, reddit-trending endpoints
- Document composite sentiment scoring (4 sources, weighted)
- Add ApeWisdom to data sources table
2026-03-19 20:59:26 +01:00
Yaojia Wang
3c725c45fa feat: aggregate all sentiment sources into composite score
Redesign /stock/{symbol}/sentiment to combine 4 data sources with
weighted scoring:
- News sentiment (Alpha Vantage, 25%) - article-level bullish/bearish
- Analyst recommendations (Finnhub, 30%) - buy/sell ratio
- Upgrade/downgrade activity (yfinance, 20%) - recent actions
- Reddit buzz (ApeWisdom, 25%) - mention change trend

Returns composite_score (-1 to +1), composite_label, per-source
scores, and full detail data from each source.
2026-03-19 20:55:52 +01:00
Yaojia Wang
4eb06dd8e5 feat: add social media sentiment endpoints
- /stock/{symbol}/social-sentiment -- Finnhub Reddit+Twitter sentiment
  (requires premium, gracefully degrades)
- /stock/{symbol}/reddit-sentiment -- Reddit WSB/stocks/investing
  mentions, upvotes, rank via ApeWisdom (free, no key)
- /discover/reddit-trending -- Top 25 trending stocks on Reddit
  (free, no key)

ApeWisdom provides real-time Reddit data without API key.
Finnhub social-sentiment requires premium plan but endpoint
responds gracefully with premium_required flag.
2026-03-19 20:50:28 +01:00
39 changed files with 8202 additions and 46 deletions

View File

@@ -106,7 +106,7 @@ curl -X POST http://localhost:8000/api/v1/portfolio/analyze \
-d '{"holdings":[{"symbol":"AAPL","shares":100,"buy_in_price":150},{"symbol":"VOLV-B.ST","shares":50,"buy_in_price":250}]}'
```
## API Endpoints (99 total)
## API Endpoints (102 total)
### Health
@@ -130,15 +130,27 @@ curl -X POST http://localhost:8000/api/v1/portfolio/analyze \
| GET | `/api/v1/stock/{symbol}/filings?form_type=10-K` | SEC filings (10-K, 10-Q, 8-K) |
| GET | `/api/v1/search?query=` | Company search by name (SEC/NASDAQ) |
### Sentiment & Analyst Data (Finnhub + Alpha Vantage + yfinance)
### Sentiment & Analyst Data (Finnhub + Alpha Vantage + yfinance + Reddit)
| Method | Path | Description |
|--------|------|-------------|
| GET | `/api/v1/stock/{symbol}/sentiment` | Aggregated: news sentiment + recommendations + upgrades |
| GET | `/api/v1/stock/{symbol}/sentiment` | Composite sentiment score from all sources (-1 to +1) |
| GET | `/api/v1/stock/{symbol}/news-sentiment?limit=30` | News articles with per-ticker sentiment scores (Alpha Vantage) |
| GET | `/api/v1/stock/{symbol}/social-sentiment` | Social media sentiment from Reddit + Twitter (Finnhub) |
| GET | `/api/v1/stock/{symbol}/reddit-sentiment` | Reddit mentions, upvotes, rank (ApeWisdom, free) |
| GET | `/api/v1/stock/{symbol}/insider-trades` | Insider transactions via Finnhub |
| GET | `/api/v1/stock/{symbol}/recommendations` | Monthly analyst buy/hold/sell counts (Finnhub) |
| GET | `/api/v1/stock/{symbol}/upgrades` | Analyst upgrades/downgrades with price targets (yfinance) |
| GET | `/api/v1/discover/reddit-trending` | Top 25 trending stocks on Reddit (free) |
The `/sentiment` endpoint aggregates 4 sources into a weighted composite score:
| Source | Weight | Data |
|--------|--------|------|
| News (Alpha Vantage) | 25% | Article-level bullish/bearish scores |
| Analysts (Finnhub) | 30% | Buy/sell recommendation ratio |
| Upgrades (yfinance) | 20% | Recent upgrade/downgrade actions |
| Reddit (ApeWisdom) | 25% | 24h mention change trend |
### Technical Analysis (14 indicators, local computation, no key needed)
@@ -555,6 +567,7 @@ docker run -p 8000:8000 invest-api
| **Alpha Vantage** | Free | Yes (free registration) | News sentiment scores (bullish/bearish per ticker per article), 25 req/day |
| **FRED** | Free | Yes (free registration) | Fed rate, treasury yields, CPI, PCE, money supply, surveys, 800K+ economic series |
| **Federal Reserve** | Free | No | EFFR, SOFR, money measures, central bank holdings, primary dealer positions, FOMC documents |
| **ApeWisdom** | Free | No | Reddit stock mentions, upvotes, trending (WSB, r/stocks, r/investing) |
| **openbb-technical** | Free | No (local) | ATR, ADX, Stochastic, OBV, Ichimoku, Donchian, Aroon, CCI, Keltner, Fibonacci, A/D, VWAP, Volatility Cones, Relative Rotation |
| **openbb-quantitative** | Free | No (local) | Sharpe, Sortino, Omega ratios, CAPM, normality tests, unit root tests, rolling statistics |

185
akshare_service.py Normal file
View File

@@ -0,0 +1,185 @@
"""A-share and HK stock data service using AKShare."""
import asyncio
import logging
import re
from datetime import datetime, timedelta
from typing import Any
import akshare as ak
import pandas as pd
logger = logging.getLogger(__name__)
# --- Symbol validation patterns ---
_A_SHARE_PATTERN = re.compile(r"^[036]\d{5}$")
_HK_PATTERN = re.compile(r"^\d{5}$")
# --- Chinese column name mappings ---
_HIST_COLUMNS: dict[str, str] = {
"日期": "date",
"开盘": "open",
"收盘": "close",
"最高": "high",
"最低": "low",
"成交量": "volume",
"成交额": "turnover",
"振幅": "amplitude",
"涨跌幅": "change_percent",
"涨跌额": "change",
"换手率": "turnover_rate",
}
_QUOTE_COLUMNS: dict[str, str] = {
"代码": "symbol",
"名称": "name",
"最新价": "price",
"涨跌幅": "change_percent",
"涨跌额": "change",
"成交量": "volume",
"成交额": "turnover",
"今开": "open",
"最高": "high",
"最低": "low",
"昨收": "prev_close",
}
# --- Validation helpers ---
def validate_a_share_symbol(symbol: str) -> bool:
"""Return True if symbol matches A-share format (6 digits, starts with 0, 3, or 6)."""
return bool(_A_SHARE_PATTERN.match(symbol))
def validate_hk_symbol(symbol: str) -> bool:
"""Return True if symbol matches HK stock format (exactly 5 digits)."""
return bool(_HK_PATTERN.match(symbol))
# --- DataFrame parsers ---
def _parse_hist_df(df: pd.DataFrame) -> list[dict[str, Any]]:
"""Convert a Chinese-column historical DataFrame to a list of English-key dicts."""
if df.empty:
return []
df = df.rename(columns=_HIST_COLUMNS)
# Serialize date column to ISO string
if "date" in df.columns:
df["date"] = df["date"].astype(str)
return df.to_dict(orient="records")
def _parse_spot_row(df: pd.DataFrame, symbol: str) -> dict[str, Any] | None:
"""
Filter a spot quote DataFrame by symbol code column and return
an English-key dict for the matching row, or None if not found.
"""
if df.empty:
return None
code_col = "代码"
if code_col not in df.columns:
return None
matched = df[df[code_col] == symbol]
if matched.empty:
return None
row = matched.iloc[0]
result: dict[str, Any] = {}
for cn_key, en_key in _QUOTE_COLUMNS.items():
result[en_key] = row.get(cn_key)
return result
# --- Date helpers ---
def _date_range(days: int) -> tuple[str, str]:
"""Return (start_date, end_date) strings in YYYYMMDD format for the given window."""
end = datetime.now()
start = end - timedelta(days=days)
return start.strftime("%Y%m%d"), end.strftime("%Y%m%d")
# --- A-share service functions ---
async def get_a_share_quote(symbol: str) -> dict[str, Any] | None:
"""
Fetch real-time A-share quote for a single symbol.
Returns a dict with English keys, or None if the symbol is not found
in the spot market data. Propagates AKShare exceptions to the caller.
"""
df: pd.DataFrame = await asyncio.to_thread(ak.stock_zh_a_spot_em)
return _parse_spot_row(df, symbol)
async def get_a_share_historical(
symbol: str, *, days: int = 365
) -> list[dict[str, Any]]:
"""
Fetch daily OHLCV history for an A-share symbol with qfq (前复权) adjustment.
Propagates AKShare exceptions to the caller.
"""
start_date, end_date = _date_range(days)
df: pd.DataFrame = await asyncio.to_thread(
ak.stock_zh_a_hist,
symbol=symbol,
period="daily",
start_date=start_date,
end_date=end_date,
adjust="qfq",
)
return _parse_hist_df(df)
async def search_a_shares(query: str) -> list[dict[str, Any]]:
"""
Search A-share stocks by name substring.
Returns a list of {code, name} dicts. An empty query returns all stocks.
Propagates AKShare exceptions to the caller.
"""
df: pd.DataFrame = await asyncio.to_thread(ak.stock_info_a_code_name)
if query:
df = df[df["name"].str.contains(query, na=False)]
return df[["code", "name"]].to_dict(orient="records")
# --- HK stock service functions ---
async def get_hk_quote(symbol: str) -> dict[str, Any] | None:
"""
Fetch real-time HK stock quote for a single symbol.
Returns a dict with English keys, or None if the symbol is not found.
Propagates AKShare exceptions to the caller.
"""
df: pd.DataFrame = await asyncio.to_thread(ak.stock_hk_spot_em)
return _parse_spot_row(df, symbol)
async def get_hk_historical(
symbol: str, *, days: int = 365
) -> list[dict[str, Any]]:
"""
Fetch daily OHLCV history for a HK stock symbol with qfq adjustment.
Propagates AKShare exceptions to the caller.
"""
start_date, end_date = _date_range(days)
df: pd.DataFrame = await asyncio.to_thread(
ak.stock_hk_hist,
symbol=symbol,
period="daily",
start_date=start_date,
end_date=end_date,
adjust="qfq",
)
return _parse_hist_df(df)

372
backtest_service.py Normal file
View File

@@ -0,0 +1,372 @@
"""Backtesting engine using pure pandas/numpy - no external backtesting libraries."""
import logging
from typing import Any
import numpy as np
import pandas as pd
from obb_utils import fetch_historical
logger = logging.getLogger(__name__)
_EQUITY_CURVE_MAX_POINTS = 20
_MIN_BARS_FOR_SINGLE_POINT = 2
# ---------------------------------------------------------------------------
# Internal signal computation helpers
# ---------------------------------------------------------------------------
def _extract_closes(result: Any) -> pd.Series:
"""Pull close prices from an OBBject result into a float Series."""
bars = result.results
closes = [getattr(bar, "close", None) for bar in bars]
return pd.Series(closes, dtype=float).dropna().reset_index(drop=True)
def _compute_sma_signals(
prices: pd.Series, short_window: int, long_window: int
) -> pd.Series:
"""Return position series (1=long, 0=flat) from SMA crossover strategy.
Buy when short SMA crosses above long SMA; sell when it crosses below.
"""
short_ma = prices.rolling(short_window).mean()
long_ma = prices.rolling(long_window).mean()
# 1 where short > long, else 0; NaN before long_window filled with 0
signal = (short_ma > long_ma).astype(int)
signal.iloc[: long_window - 1] = 0
return signal
def _compute_rsi(prices: pd.Series, period: int) -> pd.Series:
"""Compute Wilder RSI for a price series."""
delta = prices.diff()
gain = delta.clip(lower=0)
loss = (-delta).clip(lower=0)
avg_gain = gain.ewm(alpha=1 / period, adjust=False).mean()
avg_loss = loss.ewm(alpha=1 / period, adjust=False).mean()
# When avg_loss == 0 and avg_gain > 0, RSI = 100; avoid division by zero.
rsi = pd.Series(np.where(
avg_loss == 0,
np.where(avg_gain == 0, 50.0, 100.0),
100 - (100 / (1 + avg_gain / avg_loss)),
), index=prices.index, dtype=float)
# Preserve NaN for the initial diff period
rsi[avg_gain.isna()] = np.nan
return rsi
def _compute_rsi_signals(
prices: pd.Series, period: int, oversold: float, overbought: float
) -> pd.Series:
"""Return position series (1=long, 0=flat) from RSI strategy.
Buy when RSI < oversold; sell when RSI > overbought.
"""
rsi = _compute_rsi(prices, period)
position = pd.Series(0, index=prices.index, dtype=int)
in_trade = False
for i in range(len(prices)):
rsi_val = rsi.iloc[i]
if pd.isna(rsi_val):
continue
if not in_trade and rsi_val < oversold:
in_trade = True
elif in_trade and rsi_val > overbought:
in_trade = False
if in_trade:
position.iloc[i] = 1
return position
# ---------------------------------------------------------------------------
# Shared metrics computation
# ---------------------------------------------------------------------------
def _compute_metrics(equity: pd.Series, trades: int) -> dict[str, Any]:
"""Compute standard backtest performance metrics from an equity curve.
Parameters
----------
equity:
Daily portfolio value series starting from initial_capital.
trades:
Number of completed round-trip trades.
Returns
-------
dict with keys: total_return, annualized_return, sharpe_ratio,
max_drawdown, win_rate, total_trades, equity_curve.
"""
n = len(equity)
initial = float(equity.iloc[0])
final = float(equity.iloc[-1])
total_return = (final - initial) / initial if initial != 0 else 0.0
trading_days = max(n - 1, 1)
annualized_return = (1 + total_return) ** (252 / trading_days) - 1
# Sharpe ratio (annualized, risk-free rate = 0)
sharpe_ratio: float | None = None
if n > 1:
daily_returns = equity.pct_change().dropna()
std = float(daily_returns.std())
if std > 0:
sharpe_ratio = float(daily_returns.mean() / std * np.sqrt(252))
# Maximum drawdown
rolling_max = equity.cummax()
drawdown = (equity - rolling_max) / rolling_max
max_drawdown = float(drawdown.min())
# Win rate: undefined when no trades
win_rate: float | None = None
if trades > 0:
# Approximate: compare each trade entry/exit pair captured in equity
win_rate = None # will be overridden by callers that track trades
# Equity curve - last N points as plain Python floats
last_n = equity.iloc[-_EQUITY_CURVE_MAX_POINTS:]
equity_curve = [round(float(v), 4) for v in last_n]
return {
"total_return": round(total_return, 6),
"annualized_return": round(annualized_return, 6),
"sharpe_ratio": round(sharpe_ratio, 6) if sharpe_ratio is not None else None,
"max_drawdown": round(max_drawdown, 6),
"win_rate": win_rate,
"total_trades": trades,
"equity_curve": equity_curve,
}
def _simulate_positions(
prices: pd.Series,
positions: pd.Series,
initial_capital: float,
) -> tuple[pd.Series, int, int]:
"""Simulate portfolio equity given a position series and prices.
Returns (equity_curve, total_trades, winning_trades).
A trade is a complete buy->sell round-trip.
"""
# Daily returns when in position
price_returns = prices.pct_change().fillna(0.0)
strategy_returns = positions.shift(1).fillna(0).astype(float) * price_returns
equity = initial_capital * (1 + strategy_returns).cumprod()
equity.iloc[0] = initial_capital
# Count round trips
trade_changes = positions.diff().abs()
entries = int((trade_changes == 1).sum())
exits = int((trade_changes == -1).sum())
total_trades = min(entries, exits) # only completed round trips
# Count wins: each completed trade where exit value > entry value
winning_trades = 0
in_trade = False
entry_price = 0.0
for i in range(len(positions)):
pos = int(positions.iloc[i])
price = float(prices.iloc[i])
if not in_trade and pos == 1:
in_trade = True
entry_price = price
elif in_trade and pos == 0:
in_trade = False
if price > entry_price:
winning_trades += 1
return equity, total_trades, winning_trades
# ---------------------------------------------------------------------------
# Public strategy functions
# ---------------------------------------------------------------------------
async def backtest_sma_crossover(
symbol: str,
short_window: int,
long_window: int,
days: int,
initial_capital: float,
) -> dict[str, Any]:
"""Run SMA crossover backtest for a single symbol."""
hist = await fetch_historical(symbol, days)
if hist is None:
raise ValueError(f"No historical data available for {symbol}")
prices = _extract_closes(hist)
if len(prices) <= long_window:
raise ValueError(
f"Insufficient data: need >{long_window} bars, got {len(prices)}"
)
positions = _compute_sma_signals(prices, short_window, long_window)
equity, total_trades, winning_trades = _simulate_positions(
prices, positions, initial_capital
)
result = _compute_metrics(equity, total_trades)
if total_trades > 0:
result["win_rate"] = round(winning_trades / total_trades, 6)
return result
async def backtest_rsi(
symbol: str,
period: int,
oversold: float,
overbought: float,
days: int,
initial_capital: float,
) -> dict[str, Any]:
"""Run RSI-based backtest for a single symbol."""
hist = await fetch_historical(symbol, days)
if hist is None:
raise ValueError(f"No historical data available for {symbol}")
prices = _extract_closes(hist)
if len(prices) <= period:
raise ValueError(
f"Insufficient data: need >{period} bars, got {len(prices)}"
)
positions = _compute_rsi_signals(prices, period, oversold, overbought)
equity, total_trades, winning_trades = _simulate_positions(
prices, positions, initial_capital
)
result = _compute_metrics(equity, total_trades)
if total_trades > 0:
result["win_rate"] = round(winning_trades / total_trades, 6)
return result
async def backtest_buy_and_hold(
symbol: str,
days: int,
initial_capital: float,
) -> dict[str, Any]:
"""Run a simple buy-and-hold backtest as a benchmark."""
hist = await fetch_historical(symbol, days)
if hist is None:
raise ValueError(f"No historical data available for {symbol}")
prices = _extract_closes(hist)
if len(prices) < _MIN_BARS_FOR_SINGLE_POINT:
raise ValueError(
f"Insufficient data: need at least 2 bars, got {len(prices)}"
)
# Always fully invested - position is 1 from day 0
positions = pd.Series(1, index=prices.index, dtype=int)
equity, _, _ = _simulate_positions(prices, positions, initial_capital)
result = _compute_metrics(equity, trades=1)
# Buy-and-hold: 1 trade, win_rate is whether final > initial
result["win_rate"] = 1.0 if result["total_return"] > 0 else 0.0
result["total_trades"] = 1
return result
async def backtest_momentum(
symbols: list[str],
lookback: int,
top_n: int,
rebalance_days: int,
days: int,
initial_capital: float,
) -> dict[str, Any]:
"""Run momentum strategy: every rebalance_days pick top_n symbols by lookback return."""
# Fetch all price series
price_map: dict[str, pd.Series] = {}
for sym in symbols:
hist = await fetch_historical(sym, days)
if hist is not None:
closes = _extract_closes(hist)
if len(closes) > lookback:
price_map[sym] = closes
if not price_map:
raise ValueError("No price data available for any of the requested symbols")
# Align all price series to the same length (min across symbols)
min_len = min(len(v) for v in price_map.values())
aligned = {sym: s.iloc[:min_len].reset_index(drop=True) for sym, s in price_map.items()}
n_bars = min_len
portfolio_value = initial_capital
equity_values: list[float] = [initial_capital]
allocation_history: list[dict[str, Any]] = []
total_trades = 0
current_symbols: list[str] = []
current_weights: list[float] = []
entry_prices: dict[str, float] = {}
winning_trades = 0
for bar in range(1, n_bars):
# Rebalance check
if bar % rebalance_days == 0 and bar >= lookback:
# Rank symbols by lookback-period return
returns: dict[str, float] = {}
for sym, prices in aligned.items():
if bar >= lookback:
ret = (prices.iloc[bar] / prices.iloc[bar - lookback]) - 1
returns[sym] = ret
sorted_syms = sorted(returns, key=returns.get, reverse=True) # type: ignore[arg-type]
selected = sorted_syms[:top_n]
weight = 1.0 / len(selected) if selected else 0.0
# Count closed positions as trades
for sym in current_symbols:
if sym in aligned:
exit_price = float(aligned[sym].iloc[bar])
entry_price = entry_prices.get(sym, exit_price)
total_trades += 1
if exit_price > entry_price:
winning_trades += 1
current_symbols = selected
current_weights = [weight] * len(selected)
entry_prices = {sym: float(aligned[sym].iloc[bar]) for sym in selected}
allocation_history.append({
"bar": bar,
"symbols": selected,
"weights": current_weights,
})
# Compute portfolio daily return
if current_symbols:
daily_ret = 0.0
for sym, w in zip(current_symbols, current_weights):
prev_bar = bar - 1
prev_price = float(aligned[sym].iloc[prev_bar])
curr_price = float(aligned[sym].iloc[bar])
if prev_price != 0:
daily_ret += w * (curr_price / prev_price - 1)
portfolio_value = portfolio_value * (1 + daily_ret)
equity_values.append(portfolio_value)
equity = pd.Series(equity_values, dtype=float)
result = _compute_metrics(equity, total_trades)
if total_trades > 0:
result["win_rate"] = round(winning_trades / total_trades, 6)
result["allocation_history"] = allocation_history
return result

68
congress_service.py Normal file
View File

@@ -0,0 +1,68 @@
"""Congress trading data: member trades and bill search."""
import asyncio
import logging
from typing import Any
from openbb import obb
from obb_utils import to_list
logger = logging.getLogger(__name__)
async def _try_obb_call(fn, *args, **kwargs) -> list[dict[str, Any]] | None:
"""Attempt a single OBB call and return to_list result, or None on failure."""
try:
result = await asyncio.to_thread(fn, *args, **kwargs)
return to_list(result)
except Exception as exc:
logger.debug("OBB call failed: %s", exc)
return None
def _get_congress_fn():
"""Resolve the congress trading OBB function safely."""
try:
return obb.regulators.government_us.congress_trading
except AttributeError:
logger.debug("obb.regulators.government_us.congress_trading not available")
return None
async def get_congress_trades() -> list[dict[str, Any]]:
"""Get recent US congress member stock trades.
Returns an empty list if the data provider is unavailable.
"""
fn = _get_congress_fn()
if fn is None:
return []
providers = ["quiverquant", "fmp"]
for provider in providers:
data = await _try_obb_call(fn, provider=provider)
if data is not None:
return data
logger.warning("All congress trades providers failed")
return []
async def search_congress_bills(query: str) -> list[dict[str, Any]]:
"""Search US congress bills by keyword.
Returns an empty list if the data provider is unavailable.
"""
fn = _get_congress_fn()
if fn is None:
return []
providers = ["quiverquant", "fmp"]
for provider in providers:
data = await _try_obb_call(fn, query, provider=provider)
if data is not None:
return data
logger.warning("All congress bills providers failed for query: %s", query)
return []

195
defi_service.py Normal file
View File

@@ -0,0 +1,195 @@
"""DeFi data service via DefiLlama API (no API key required)."""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
LLAMA_BASE = "https://api.llama.fi"
STABLES_BASE = "https://stablecoins.llama.fi"
YIELDS_BASE = "https://yields.llama.fi"
TIMEOUT = 15.0
async def get_top_protocols(limit: int = 20) -> list[dict[str, Any]]:
"""Fetch top DeFi protocols ranked by TVL from DefiLlama."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{LLAMA_BASE}/protocols")
resp.raise_for_status()
data = resp.json()
return [
{
"name": p.get("name"),
"symbol": p.get("symbol"),
"tvl": p.get("tvl"),
"chain": p.get("chain"),
"chains": p.get("chains", []),
"category": p.get("category"),
"change_1d": p.get("change_1d"),
"change_7d": p.get("change_7d"),
}
for p in data[:limit]
]
except Exception:
logger.exception("Failed to fetch top protocols from DefiLlama")
return []
async def get_chain_tvls() -> list[dict[str, Any]]:
"""Fetch TVL rankings for all chains from DefiLlama."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{LLAMA_BASE}/v2/chains")
resp.raise_for_status()
data = resp.json()
return [
{
"name": c.get("name"),
"tvl": c.get("tvl"),
"tokenSymbol": c.get("tokenSymbol"),
}
for c in data
]
except Exception:
logger.exception("Failed to fetch chain TVLs from DefiLlama")
return []
async def get_protocol_tvl(protocol: str) -> float | None:
"""Fetch current TVL for a specific protocol slug."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{LLAMA_BASE}/tvl/{protocol}")
resp.raise_for_status()
return resp.json()
except Exception:
logger.exception("Failed to fetch TVL for protocol %s", protocol)
return None
async def get_yield_pools(
chain: str | None = None,
project: str | None = None,
) -> list[dict[str, Any]]:
"""Fetch yield pools from DefiLlama, optionally filtered by chain and/or project.
Returns top 20 by TVL descending.
"""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{YIELDS_BASE}/pools")
resp.raise_for_status()
payload = resp.json()
pools: list[dict[str, Any]] = payload.get("data", [])
if chain is not None:
pools = [p for p in pools if p.get("chain") == chain]
if project is not None:
pools = [p for p in pools if p.get("project") == project]
pools = sorted(pools, key=lambda p: p.get("tvlUsd") or 0, reverse=True)[:20]
return [
{
"pool": p.get("pool"),
"chain": p.get("chain"),
"project": p.get("project"),
"symbol": p.get("symbol"),
"tvlUsd": p.get("tvlUsd"),
"apy": p.get("apy"),
"apyBase": p.get("apyBase"),
"apyReward": p.get("apyReward"),
}
for p in pools
]
except Exception:
logger.exception("Failed to fetch yield pools from DefiLlama")
return []
def _extract_circulating(asset: dict[str, Any]) -> float | None:
"""Extract the primary circulating supply value from a stablecoin asset dict."""
raw = asset.get("circulating")
if raw is None:
return None
if isinstance(raw, (int, float)):
return float(raw)
if isinstance(raw, dict):
# DefiLlama returns {"peggedUSD": <amount>, ...}
values = [v for v in raw.values() if isinstance(v, (int, float))]
return values[0] if values else None
return None
async def get_stablecoins(limit: int = 20) -> list[dict[str, Any]]:
"""Fetch top stablecoins by circulating supply from DefiLlama."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{STABLES_BASE}/stablecoins")
resp.raise_for_status()
payload = resp.json()
assets: list[dict[str, Any]] = payload.get("peggedAssets", [])
return [
{
"name": a.get("name"),
"symbol": a.get("symbol"),
"pegType": a.get("pegType"),
"circulating": _extract_circulating(a),
"price": a.get("price"),
}
for a in assets[:limit]
]
except Exception:
logger.exception("Failed to fetch stablecoins from DefiLlama")
return []
async def get_dex_volumes() -> dict[str, Any] | None:
"""Fetch DEX volume overview from DefiLlama."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{LLAMA_BASE}/overview/dexs")
resp.raise_for_status()
payload = resp.json()
protocols = [
{
"name": p.get("name"),
"volume24h": p.get("total24h"),
}
for p in payload.get("protocols", [])
]
return {
"totalVolume24h": payload.get("total24h"),
"totalVolume7d": payload.get("total7d"),
"protocols": protocols,
}
except Exception:
logger.exception("Failed to fetch DEX volumes from DefiLlama")
return None
async def get_protocol_fees() -> list[dict[str, Any]]:
"""Fetch protocol fees and revenue overview from DefiLlama."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(f"{LLAMA_BASE}/overview/fees")
resp.raise_for_status()
payload = resp.json()
return [
{
"name": p.get("name"),
"fees24h": p.get("total24h"),
"revenue24h": p.get("revenue24h"),
}
for p in payload.get("protocols", [])
]
except Exception:
logger.exception("Failed to fetch protocol fees from DefiLlama")
return []

View File

@@ -1,5 +1,6 @@
"""Finnhub API client for sentiment, insider trades, and analyst data."""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any
@@ -104,12 +105,67 @@ async def get_upgrade_downgrade(
return data if isinstance(data, list) else []
async def get_social_sentiment(symbol: str) -> dict[str, Any]:
"""Get social media sentiment from Reddit and Twitter.
Returns mention counts, positive/negative scores, and trends.
"""
if not _is_configured():
return {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY"}
start = (datetime.now() - timedelta(days=3)).strftime("%Y-%m-%d")
async with _client() as client:
resp = await client.get(
"/stock/social-sentiment",
params={"symbol": symbol, "from": start},
)
if resp.status_code in (403, 401):
logger.debug("social-sentiment requires premium, skipping")
return {"configured": True, "symbol": symbol, "premium_required": True, "reddit": [], "twitter": []}
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
return {"configured": True, "symbol": symbol, "reddit": [], "twitter": []}
reddit = data.get("reddit", [])
twitter = data.get("twitter", [])
# Compute summary stats
reddit_summary = _summarize_social(reddit) if reddit else None
twitter_summary = _summarize_social(twitter) if twitter else None
return {
"configured": True,
"symbol": symbol,
"reddit_summary": reddit_summary,
"twitter_summary": twitter_summary,
"reddit": reddit[-20:],
"twitter": twitter[-20:],
}
def _summarize_social(entries: list[dict[str, Any]]) -> dict[str, Any]:
"""Summarize social sentiment entries into aggregate stats."""
if not entries:
return {}
total_mentions = sum(e.get("mention", 0) for e in entries)
total_positive = sum(e.get("positiveScore", 0) for e in entries)
total_negative = sum(e.get("negativeScore", 0) for e in entries)
avg_score = sum(e.get("score", 0) for e in entries) / len(entries)
return {
"total_mentions": total_mentions,
"total_positive": total_positive,
"total_negative": total_negative,
"avg_score": round(avg_score, 4),
"data_points": len(entries),
}
# Reddit sentiment moved to reddit_service.py
async def get_sentiment_summary(symbol: str) -> dict[str, Any]:
"""Aggregate all sentiment data for a symbol into one response."""
if not _is_configured():
return {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY to enable sentiment data"}
import asyncio
news_sentiment, company_news, recommendations, upgrades = await asyncio.gather(
get_news_sentiment(symbol),
get_company_news(symbol, days=7),

View File

@@ -37,6 +37,10 @@ from routes_sentiment import router as sentiment_router # noqa: E402
from routes_shorts import router as shorts_router # noqa: E402
from routes_surveys import router as surveys_router # noqa: E402
from routes_technical import router as technical_router # noqa: E402
from routes_portfolio import router as portfolio_router # noqa: E402
from routes_backtest import router as backtest_router # noqa: E402
from routes_cn import router as cn_router # noqa: E402
from routes_defi import router as defi_router # noqa: E402
logging.basicConfig(
level=settings.log_level.upper(),
@@ -81,6 +85,10 @@ app.include_router(fixed_income_router)
app.include_router(economy_router)
app.include_router(surveys_router)
app.include_router(regulators_router)
app.include_router(portfolio_router)
app.include_router(backtest_router)
app.include_router(cn_router)
app.include_router(defi_router)
@app.get("/health", response_model=dict[str, str])

View File

@@ -2,17 +2,14 @@
import asyncio
import logging
from datetime import datetime, timezone, timedelta
from typing import Any
from openbb import obb
from obb_utils import to_list
from obb_utils import to_list, days_ago, PROVIDER
logger = logging.getLogger(__name__)
PROVIDER = "yfinance"
# --- ETF ---
@@ -30,7 +27,7 @@ async def get_etf_info(symbol: str) -> dict[str, Any]:
async def get_etf_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
"""Get ETF price history."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.etf.historical, symbol, start_date=start, provider=PROVIDER
@@ -66,7 +63,7 @@ async def get_available_indices() -> list[dict[str, Any]]:
async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
"""Get index price history."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.index.price.historical, symbol, start_date=start, provider=PROVIDER
@@ -82,7 +79,7 @@ async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, A
async def get_crypto_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
"""Get cryptocurrency price history."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.crypto.price.historical, symbol, start_date=start, provider=PROVIDER
@@ -110,7 +107,7 @@ async def get_currency_historical(
symbol: str, days: int = 365
) -> list[dict[str, Any]]:
"""Get forex price history (e.g., EURUSD)."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.currency.price.historical, symbol, start_date=start, provider=PROVIDER
@@ -140,7 +137,7 @@ async def get_futures_historical(
symbol: str, days: int = 365
) -> list[dict[str, Any]]:
"""Get futures price history."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.derivatives.futures.historical, symbol, start_date=start, provider=PROVIDER

View File

@@ -66,11 +66,16 @@ def first_or_empty(result: Any) -> dict[str, Any]:
return items[0] if items else {}
def days_ago(days: int) -> str:
"""Return a YYYY-MM-DD date string for N days ago (UTC)."""
return (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
async def fetch_historical(
symbol: str, days: int = 365, provider: str = PROVIDER,
) -> Any | None:
"""Fetch historical price data, returning the OBBject result or None."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
try:
result = await asyncio.to_thread(
obb.equity.price.historical, symbol, start_date=start, provider=provider,

View File

@@ -1,17 +1,14 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any
import yfinance as yf
from openbb import obb
from obb_utils import to_list, first_or_empty
from obb_utils import to_list, first_or_empty, days_ago, PROVIDER
logger = logging.getLogger(__name__)
PROVIDER = "yfinance"
async def get_quote(symbol: str) -> dict[str, Any]:
result = await asyncio.to_thread(
@@ -21,7 +18,7 @@ async def get_quote(symbol: str) -> dict[str, Any]:
async def get_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
start = days_ago(days)
result = await asyncio.to_thread(
obb.equity.price.historical,
symbol,

372
portfolio_service.py Normal file
View File

@@ -0,0 +1,372 @@
"""Portfolio optimization: HRP, correlation matrix, risk parity, t-SNE clustering."""
import asyncio
import logging
from math import isqrt
from typing import Any
import numpy as np
import pandas as pd
from obb_utils import fetch_historical
logger = logging.getLogger(__name__)
async def fetch_historical_prices(symbols: list[str], days: int = 365) -> pd.DataFrame:
"""Fetch closing prices for multiple symbols and return as a DataFrame.
Columns are symbol names; rows are dates. Symbols with no data are skipped.
"""
tasks = [fetch_historical(sym, days=days) for sym in symbols]
results = await asyncio.gather(*tasks)
price_series: dict[str, pd.Series] = {}
for sym, result in zip(symbols, results):
if result is None or result.results is None:
logger.warning("No historical data for %s, skipping", sym)
continue
rows = result.results
if not rows:
continue
dates = []
closes = []
for row in rows:
d = getattr(row, "date", None)
c = getattr(row, "close", None)
if d is not None and c is not None:
dates.append(str(d))
closes.append(float(c))
if dates:
price_series[sym] = pd.Series(closes, index=dates)
if not price_series:
return pd.DataFrame()
df = pd.DataFrame(price_series)
df = df.dropna(how="all")
return df
def _compute_returns(prices: pd.DataFrame) -> pd.DataFrame:
"""Compute daily log returns from a price DataFrame."""
return prices.pct_change().dropna()
def _inverse_volatility_weights(returns: pd.DataFrame) -> dict[str, float]:
"""Compute inverse-volatility weights."""
vols = returns.std()
inv_vols = 1.0 / vols
weights = inv_vols / inv_vols.sum()
return {sym: float(w) for sym, w in weights.items()}
def _hrp_weights(returns: pd.DataFrame) -> dict[str, float]:
"""Compute Hierarchical Risk Parity weights via scipy clustering.
Falls back to inverse-volatility if scipy is unavailable.
"""
symbols = list(returns.columns)
n = len(symbols)
if n == 1:
return {symbols[0]: 1.0}
try:
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import squareform
corr = returns.corr().fillna(0).values
# Convert correlation to distance: d = sqrt(0.5 * (1 - corr))
dist = np.sqrt(np.clip(0.5 * (1 - corr), 0, 1))
np.fill_diagonal(dist, 0.0)
condensed = squareform(dist)
link = linkage(condensed, method="single")
order = leaves_list(link)
sorted_symbols = [symbols[i] for i in order]
except ImportError:
logger.warning("scipy not available; using inverse-volatility for HRP")
return _inverse_volatility_weights(returns)
cov = returns.cov().values
def _bisect_weights(items: list[str]) -> dict[str, float]:
if len(items) == 1:
return {items[0]: 1.0}
mid = len(items) // 2
left_items = items[:mid]
right_items = items[mid:]
left_idx = [sorted_symbols.index(s) for s in left_items]
right_idx = [sorted_symbols.index(s) for s in right_items]
def _cluster_var(idx: list[int]) -> float:
sub_cov = cov[np.ix_(idx, idx)]
w = np.ones(len(idx)) / len(idx)
return float(w @ sub_cov @ w)
v_left = _cluster_var(left_idx)
v_right = _cluster_var(right_idx)
total = v_left + v_right
alpha = 1.0 - v_left / total if total > 0 else 0.5
w_left = _bisect_weights(left_items)
w_right = _bisect_weights(right_items)
result = {}
for sym, w in w_left.items():
result[sym] = w * (1.0 - alpha)
for sym, w in w_right.items():
result[sym] = w * alpha
return result
raw = _bisect_weights(sorted_symbols)
total = sum(raw.values())
return {sym: float(w / total) for sym, w in raw.items()}
async def optimize_hrp(symbols: list[str], days: int = 365) -> dict[str, Any]:
"""Compute Hierarchical Risk Parity portfolio weights.
Args:
symbols: List of ticker symbols (1-50).
days: Number of historical days to use.
Returns:
Dict with keys ``weights`` (symbol -> float) and ``method``.
Raises:
ValueError: If symbols is empty or no price data is available.
"""
if not symbols:
raise ValueError("symbols must not be empty")
prices = await fetch_historical_prices(symbols, days=days)
if prices.empty:
raise ValueError("No price data available for the given symbols")
returns = _compute_returns(prices)
weights = _hrp_weights(returns)
return {"weights": weights, "method": "hrp"}
async def compute_correlation(
symbols: list[str], days: int = 365
) -> dict[str, Any]:
"""Compute correlation matrix for a list of symbols.
Args:
symbols: List of ticker symbols (1-50).
days: Number of historical days to use.
Returns:
Dict with keys ``symbols`` (list) and ``matrix`` (list of lists).
Raises:
ValueError: If symbols is empty or no price data is available.
"""
if not symbols:
raise ValueError("symbols must not be empty")
prices = await fetch_historical_prices(symbols, days=days)
if prices.empty:
raise ValueError("No price data available for the given symbols")
returns = _compute_returns(prices)
available = list(returns.columns)
corr = returns.corr().fillna(0)
matrix = corr.values.tolist()
return {"symbols": available, "matrix": matrix}
async def compute_risk_parity(
symbols: list[str], days: int = 365
) -> dict[str, Any]:
"""Compute equal risk contribution (inverse-volatility) weights.
Args:
symbols: List of ticker symbols (1-50).
days: Number of historical days to use.
Returns:
Dict with keys ``weights``, ``risk_contributions``, and ``method``.
Raises:
ValueError: If symbols is empty or no price data is available.
"""
if not symbols:
raise ValueError("symbols must not be empty")
prices = await fetch_historical_prices(symbols, days=days)
if prices.empty:
raise ValueError("No price data available for the given symbols")
returns = _compute_returns(prices)
weights = _inverse_volatility_weights(returns)
# Risk contributions: w_i * sigma_i / sum(w_j * sigma_j)
vols = returns.std()
weighted_risk = {sym: weights[sym] * float(vols[sym]) for sym in weights}
total_risk = sum(weighted_risk.values())
if total_risk > 0:
risk_contributions = {sym: v / total_risk for sym, v in weighted_risk.items()}
else:
n = len(weights)
risk_contributions = {sym: 1.0 / n for sym in weights}
return {
"weights": weights,
"risk_contributions": risk_contributions,
"method": "risk_parity",
}
def _auto_n_clusters(n: int) -> int:
"""Return a sensible default cluster count: max(2, floor(sqrt(n)))."""
return max(2, isqrt(n))
def _run_tsne_kmeans(
returns_matrix: np.ndarray, n_clusters: int
) -> tuple[np.ndarray, np.ndarray]:
"""Run t-SNE then KMeans on a (n_symbols x n_days) returns matrix.
Returns (coords, labels) where coords has shape (n_symbols, 2).
CPU-heavy: caller must wrap in asyncio.to_thread.
"""
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
n_samples = returns_matrix.shape[0]
perplexity = min(5, n_samples - 1)
# Add tiny noise to prevent numerical singularity when returns are identical
rng = np.random.default_rng(42)
jittered = returns_matrix + rng.normal(0, 1e-10, returns_matrix.shape)
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, method="exact")
coords = tsne.fit_transform(jittered)
km = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")
labels = km.fit_predict(coords)
return coords, labels
async def cluster_stocks(
symbols: list[str],
days: int = 180,
n_clusters: int | None = None,
) -> dict[str, Any]:
"""Cluster stocks by return similarity using t-SNE + KMeans.
Args:
symbols: List of ticker symbols. Minimum 3, maximum 50.
days: Number of historical trading days to use.
n_clusters: Number of clusters. Defaults to floor(sqrt(n_symbols)).
Returns:
Dict with keys ``symbols``, ``coordinates``, ``clusters``,
``method``, ``n_clusters``, and ``days``.
Raises:
ValueError: Fewer than 3 symbols, or no price data available.
"""
if len(symbols) < 3:
raise ValueError("cluster_stocks requires at least 3 symbols")
prices = await fetch_historical_prices(symbols, days=days)
if prices.empty:
raise ValueError("No price data available for the given symbols")
returns = _compute_returns(prices)
available = list(returns.columns)
n = len(available)
k = n_clusters if n_clusters is not None else _auto_n_clusters(n)
# Build (n_symbols x n_days) matrix; fill NaN with column mean
matrix = returns[available].T.fillna(0).values.astype(float)
coords, labels = await asyncio.to_thread(_run_tsne_kmeans, matrix, k)
coordinates = [
{
"symbol": sym,
"x": float(coords[i, 0]),
"y": float(coords[i, 1]),
"cluster": int(labels[i]),
}
for i, sym in enumerate(available)
]
clusters: dict[str, list[str]] = {}
for sym, label in zip(available, labels):
key = str(int(label))
clusters.setdefault(key, []).append(sym)
return {
"symbols": available,
"coordinates": coordinates,
"clusters": clusters,
"method": "t-SNE + KMeans",
"n_clusters": k,
"days": days,
}
async def find_similar_stocks(
symbol: str,
universe: list[str],
days: int = 180,
top_n: int = 5,
) -> dict[str, Any]:
"""Find stocks most/least similar to a target by return correlation.
Args:
symbol: Target ticker symbol.
universe: List of candidate symbols to compare against.
days: Number of historical trading days to use.
top_n: Number of most- and least-similar stocks to return.
Returns:
Dict with keys ``symbol``, ``most_similar``, ``least_similar``.
Raises:
ValueError: No price data available, or target symbol missing from data.
"""
all_symbols = [symbol] + [s for s in universe if s != symbol]
prices = await fetch_historical_prices(all_symbols, days=days)
if prices.empty:
raise ValueError("No price data available for the given symbols")
if symbol not in prices.columns:
raise ValueError(
f"{symbol} not found in price data; it may have no available history"
)
returns = _compute_returns(prices)
target_returns = returns[symbol]
peers = [s for s in universe if s in returns.columns and s != symbol]
correlations: list[dict[str, Any]] = []
for peer in peers:
corr_val = float(target_returns.corr(returns[peer]))
if not np.isnan(corr_val):
correlations.append({"symbol": peer, "correlation": corr_val})
correlations.sort(key=lambda e: e["correlation"], reverse=True)
n = min(top_n, len(correlations))
most_similar = correlations[:n]
least_similar = sorted(correlations, key=lambda e: e["correlation"])[:n]
return {
"symbol": symbol,
"most_similar": most_similar,
"least_similar": least_similar,
}

View File

@@ -10,6 +10,7 @@ dependencies = [
"pydantic-settings",
"httpx",
"curl_cffi==0.7.4",
"akshare",
]
[project.optional-dependencies]

View File

@@ -2,17 +2,14 @@
import asyncio
import logging
from datetime import datetime, timezone, timedelta
from typing import Any
from openbb import obb
from obb_utils import extract_single, safe_last, fetch_historical, to_list
from obb_utils import extract_single, safe_last, fetch_historical, to_list, days_ago, PROVIDER
logger = logging.getLogger(__name__)
PROVIDER = "yfinance"
# Need 252+ trading days for default window; 730 calendar days is safe
PERF_DAYS = 730
TARGET = "close"
@@ -22,7 +19,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any
"""Calculate Sharpe ratio, summary stats, and volatility for a symbol."""
# Need at least 252 trading days for Sharpe window
fetch_days = max(days, PERF_DAYS)
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
start = days_ago(fetch_days)
try:
hist = await asyncio.to_thread(
@@ -64,7 +61,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any
async def get_capm(symbol: str) -> dict[str, Any]:
"""Calculate CAPM metrics: beta, alpha, systematic/idiosyncratic risk."""
start = (datetime.now(tz=timezone.utc) - timedelta(days=PERF_DAYS)).strftime("%Y-%m-%d")
start = days_ago(PERF_DAYS)
try:
hist = await asyncio.to_thread(
@@ -85,7 +82,7 @@ async def get_capm(symbol: str) -> dict[str, Any]:
async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]:
"""Run normality tests (Jarque-Bera, Shapiro-Wilk, etc.) on returns."""
fetch_days = max(days, PERF_DAYS)
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
start = days_ago(fetch_days)
try:
hist = await asyncio.to_thread(
@@ -106,7 +103,7 @@ async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]:
async def get_unitroot_test(symbol: str, days: int = 365) -> dict[str, Any]:
"""Run unit root tests (ADF, KPSS) for stationarity."""
fetch_days = max(days, PERF_DAYS)
start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d")
start = days_ago(fetch_days)
try:
hist = await asyncio.to_thread(

80
reddit_service.py Normal file
View File

@@ -0,0 +1,80 @@
"""Reddit stock sentiment via ApeWisdom API (free, no key needed)."""
import logging
from typing import Any
import httpx
logger = logging.getLogger(__name__)
APEWISDOM_URL = "https://apewisdom.io/api/v1.0/filter/all-stocks/page/1"
TIMEOUT = 10.0
async def get_reddit_sentiment(symbol: str) -> dict[str, Any]:
"""Get Reddit sentiment for a symbol.
Tracks mentions and upvotes across r/wallstreetbets, r/stocks, r/investing.
"""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(APEWISDOM_URL)
resp.raise_for_status()
data = resp.json()
results = data.get("results", [])
match = next(
(r for r in results if r.get("ticker", "").upper() == symbol.upper()),
None,
)
if match is None:
return {
"symbol": symbol,
"found": False,
"message": f"{symbol} not in Reddit top trending (not enough mentions)",
}
mentions_prev = match.get("mentions_24h_ago", 0)
mentions_now = match.get("mentions", 0)
change_pct = (
round((mentions_now - mentions_prev) / mentions_prev * 100, 1)
if mentions_prev > 0
else None
)
return {
"symbol": symbol,
"found": True,
"rank": match.get("rank"),
"mentions_24h": mentions_now,
"mentions_24h_ago": mentions_prev,
"mentions_change_pct": change_pct,
"upvotes": match.get("upvotes"),
"rank_24h_ago": match.get("rank_24h_ago"),
}
except Exception:
logger.warning("Reddit sentiment failed for %s", symbol, exc_info=True)
return {"symbol": symbol, "error": "Failed to fetch Reddit sentiment"}
async def get_reddit_trending() -> list[dict[str, Any]]:
"""Get top trending stocks on Reddit (free, no key)."""
try:
async with httpx.AsyncClient(timeout=TIMEOUT) as client:
resp = await client.get(APEWISDOM_URL)
resp.raise_for_status()
data = resp.json()
return [
{
"rank": r.get("rank"),
"symbol": r.get("ticker"),
"name": r.get("name"),
"mentions_24h": r.get("mentions"),
"upvotes": r.get("upvotes"),
"rank_24h_ago": r.get("rank_24h_ago"),
"mentions_24h_ago": r.get("mentions_24h_ago"),
}
for r in data.get("results", [])[:25]
]
except Exception:
logger.warning("Reddit trending failed", exc_info=True)
return []

View File

@@ -23,13 +23,19 @@ def validate_symbol(symbol: str) -> str:
def safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
"""Decorator to catch upstream errors and return 502."""
"""Decorator to catch upstream errors and return 502.
ValueError is caught separately and returned as 400 (bad request).
All other non-HTTP exceptions become 502 (upstream error).
"""
@functools.wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return await fn(*args, **kwargs)
except HTTPException:
raise
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
except Exception:
logger.exception("Upstream data error")
raise HTTPException(

106
routes_backtest.py Normal file
View File

@@ -0,0 +1,106 @@
"""Routes for backtesting strategies."""
from fastapi import APIRouter
from pydantic import BaseModel, Field
import backtest_service
from models import ApiResponse
from route_utils import safe
router = APIRouter(prefix="/api/v1/backtest", tags=["backtest"])
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class BacktestRequest(BaseModel):
symbol: str = Field(..., min_length=1, max_length=20)
days: int = Field(default=365, ge=30, le=3650)
initial_capital: float = Field(default=10000.0, gt=0, le=1_000_000_000)
class SMARequest(BacktestRequest):
short_window: int = Field(default=20, ge=5, le=100)
long_window: int = Field(default=50, ge=10, le=400)
class RSIRequest(BacktestRequest):
period: int = Field(default=14, ge=2, le=50)
oversold: float = Field(default=30.0, ge=1, le=49)
overbought: float = Field(default=70.0, ge=51, le=99)
class BuyAndHoldRequest(BacktestRequest):
pass
class MomentumRequest(BaseModel):
symbols: list[str] = Field(..., min_length=2, max_length=20)
lookback: int = Field(default=60, ge=5, le=252)
top_n: int = Field(default=2, ge=1)
rebalance_days: int = Field(default=30, ge=5, le=252)
days: int = Field(default=365, ge=60, le=3650)
initial_capital: float = Field(default=10000.0, gt=0)
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
@router.post("/sma-crossover", response_model=ApiResponse)
@safe
async def sma_crossover(req: SMARequest) -> ApiResponse:
"""SMA crossover strategy: buy when short SMA crosses above long SMA."""
result = await backtest_service.backtest_sma_crossover(
req.symbol,
short_window=req.short_window,
long_window=req.long_window,
days=req.days,
initial_capital=req.initial_capital,
)
return ApiResponse(data=result)
@router.post("/rsi", response_model=ApiResponse)
@safe
async def rsi_strategy(req: RSIRequest) -> ApiResponse:
"""RSI strategy: buy when RSI < oversold, sell when RSI > overbought."""
result = await backtest_service.backtest_rsi(
req.symbol,
period=req.period,
oversold=req.oversold,
overbought=req.overbought,
days=req.days,
initial_capital=req.initial_capital,
)
return ApiResponse(data=result)
@router.post("/buy-and-hold", response_model=ApiResponse)
@safe
async def buy_and_hold(req: BuyAndHoldRequest) -> ApiResponse:
"""Buy-and-hold benchmark: buy on day 1, hold through end of period."""
result = await backtest_service.backtest_buy_and_hold(
req.symbol,
days=req.days,
initial_capital=req.initial_capital,
)
return ApiResponse(data=result)
@router.post("/momentum", response_model=ApiResponse)
@safe
async def momentum_strategy(req: MomentumRequest) -> ApiResponse:
"""Momentum strategy: every rebalance_days pick top_n by lookback return."""
result = await backtest_service.backtest_momentum(
symbols=req.symbols,
lookback=req.lookback,
top_n=req.top_n,
rebalance_days=req.rebalance_days,
days=req.days,
initial_capital=req.initial_capital,
)
return ApiResponse(data=result)

105
routes_cn.py Normal file
View File

@@ -0,0 +1,105 @@
"""Routes for A-share (China) and Hong Kong stock market data via AKShare."""
from fastapi import APIRouter, HTTPException, Path, Query
from models import ApiResponse
from route_utils import safe
import akshare_service
router = APIRouter(prefix="/api/v1/cn", tags=["China & HK Markets"])
# --- Validation helpers ---
def _validate_a_share(symbol: str) -> str:
"""Validate A-share symbol format; raise 400 on failure."""
if not akshare_service.validate_a_share_symbol(symbol):
raise HTTPException(
status_code=400,
detail=(
f"Invalid A-share symbol '{symbol}'. "
"Must be 6 digits starting with 0, 3, or 6 (e.g. 000001, 300001, 600519)."
),
)
return symbol
def _validate_hk(symbol: str) -> str:
"""Validate HK stock symbol format; raise 400 on failure."""
if not akshare_service.validate_hk_symbol(symbol):
raise HTTPException(
status_code=400,
detail=(
f"Invalid HK symbol '{symbol}'. "
"Must be exactly 5 digits (e.g. 00700, 09988)."
),
)
return symbol
# --- A-share routes ---
# NOTE: /a-share/search MUST be registered before /a-share/{symbol} to avoid shadowing.
@router.get("/a-share/search", response_model=ApiResponse)
@safe
async def a_share_search(
query: str = Query(..., description="Stock name to search for (partial match)"),
) -> ApiResponse:
"""Search A-share stocks by name (partial match)."""
data = await akshare_service.search_a_shares(query)
return ApiResponse(data=data)
@router.get("/a-share/{symbol}/quote", response_model=ApiResponse)
@safe
async def a_share_quote(
symbol: str = Path(..., min_length=6, max_length=6),
) -> ApiResponse:
"""Get real-time A-share quote (沪深 real-time price)."""
symbol = _validate_a_share(symbol)
data = await akshare_service.get_a_share_quote(symbol)
if data is None:
raise HTTPException(status_code=404, detail=f"A-share symbol '{symbol}' not found.")
return ApiResponse(data=data)
@router.get("/a-share/{symbol}/historical", response_model=ApiResponse)
@safe
async def a_share_historical(
symbol: str = Path(..., min_length=6, max_length=6),
days: int = Query(default=365, ge=1, le=3650),
) -> ApiResponse:
"""Get A-share daily OHLCV history with qfq (前复权) adjustment."""
symbol = _validate_a_share(symbol)
data = await akshare_service.get_a_share_historical(symbol, days=days)
return ApiResponse(data=data)
# --- HK stock routes ---
@router.get("/hk/{symbol}/quote", response_model=ApiResponse)
@safe
async def hk_quote(
symbol: str = Path(..., min_length=5, max_length=5),
) -> ApiResponse:
"""Get real-time HK stock quote (港股 real-time price)."""
symbol = _validate_hk(symbol)
data = await akshare_service.get_hk_quote(symbol)
if data is None:
raise HTTPException(status_code=404, detail=f"HK symbol '{symbol}' not found.")
return ApiResponse(data=data)
@router.get("/hk/{symbol}/historical", response_model=ApiResponse)
@safe
async def hk_historical(
symbol: str = Path(..., min_length=5, max_length=5),
days: int = Query(default=365, ge=1, le=3650),
) -> ApiResponse:
"""Get HK stock daily OHLCV history with qfq adjustment."""
symbol = _validate_hk(symbol)
data = await akshare_service.get_hk_historical(symbol, days=days)
return ApiResponse(data=data)

75
routes_defi.py Normal file
View File

@@ -0,0 +1,75 @@
"""DeFi data routes via DefiLlama API."""
from fastapi import APIRouter, HTTPException, Query
import defi_service
from models import ApiResponse
from route_utils import safe
router = APIRouter(prefix="/api/v1/defi")
@router.get("/tvl/protocols", response_model=ApiResponse)
@safe
async def tvl_protocols() -> ApiResponse:
"""Get top DeFi protocols ranked by TVL."""
data = await defi_service.get_top_protocols()
return ApiResponse(data=data)
@router.get("/tvl/chains", response_model=ApiResponse)
@safe
async def tvl_chains() -> ApiResponse:
"""Get TVL rankings for all chains."""
data = await defi_service.get_chain_tvls()
return ApiResponse(data=data)
@router.get("/tvl/{protocol}", response_model=ApiResponse)
@safe
async def protocol_tvl(protocol: str) -> ApiResponse:
"""Get current TVL for a specific protocol slug."""
tvl = await defi_service.get_protocol_tvl(protocol)
if tvl is None:
raise HTTPException(status_code=404, detail=f"Protocol '{protocol}' not found")
return ApiResponse(data={"protocol": protocol, "tvl": tvl})
@router.get("/yields", response_model=ApiResponse)
@safe
async def yield_pools(
chain: str | None = Query(default=None, description="Filter by chain name"),
project: str | None = Query(default=None, description="Filter by project name"),
) -> ApiResponse:
"""Get top yield pools, optionally filtered by chain and/or project."""
data = await defi_service.get_yield_pools(chain=chain, project=project)
return ApiResponse(data=data)
@router.get("/stablecoins", response_model=ApiResponse)
@safe
async def stablecoins() -> ApiResponse:
"""Get top stablecoins by circulating supply."""
data = await defi_service.get_stablecoins()
return ApiResponse(data=data)
@router.get("/volumes/dexs", response_model=ApiResponse)
@safe
async def dex_volumes() -> ApiResponse:
"""Get DEX volume overview including top protocols."""
data = await defi_service.get_dex_volumes()
if data is None:
raise HTTPException(
status_code=502,
detail="Failed to fetch DEX volume data from DefiLlama",
)
return ApiResponse(data=data)
@router.get("/fees", response_model=ApiResponse)
@safe
async def protocol_fees() -> ApiResponse:
"""Get protocol fees and revenue overview."""
data = await defi_service.get_protocol_fees()
return ApiResponse(data=data)

96
routes_portfolio.py Normal file
View File

@@ -0,0 +1,96 @@
"""Routes for portfolio optimization (HRP, correlation, risk parity)."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from models import ApiResponse
from route_utils import safe
import portfolio_service
router = APIRouter(prefix="/api/v1/portfolio")
class PortfolioOptimizeRequest(BaseModel):
symbols: list[str] = Field(..., min_length=1, max_length=50)
days: int = Field(default=365, ge=1, le=3650)
@router.post("/optimize", response_model=ApiResponse)
@safe
async def portfolio_optimize(request: PortfolioOptimizeRequest):
"""Compute HRP optimal weights for a list of symbols."""
try:
result = await portfolio_service.optimize_hrp(
request.symbols, days=request.days
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return ApiResponse(data=result)
@router.post("/correlation", response_model=ApiResponse)
@safe
async def portfolio_correlation(request: PortfolioOptimizeRequest):
"""Compute correlation matrix for a list of symbols."""
try:
result = await portfolio_service.compute_correlation(
request.symbols, days=request.days
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return ApiResponse(data=result)
@router.post("/risk-parity", response_model=ApiResponse)
@safe
async def portfolio_risk_parity(request: PortfolioOptimizeRequest):
"""Compute equal risk contribution weights for a list of symbols."""
try:
result = await portfolio_service.compute_risk_parity(
request.symbols, days=request.days
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return ApiResponse(data=result)
class ClusterRequest(BaseModel):
symbols: list[str] = Field(..., min_length=3, max_length=50)
days: int = Field(default=180, ge=30, le=3650)
n_clusters: int | None = Field(default=None, ge=2, le=20)
class SimilarRequest(BaseModel):
symbol: str = Field(..., min_length=1, max_length=20)
universe: list[str] = Field(..., min_length=2, max_length=50)
days: int = Field(default=180, ge=30, le=3650)
top_n: int = Field(default=5, ge=1, le=20)
@router.post("/cluster", response_model=ApiResponse)
@safe
async def portfolio_cluster(request: ClusterRequest):
"""Cluster stocks by return similarity using t-SNE + KMeans."""
try:
result = await portfolio_service.cluster_stocks(
request.symbols, days=request.days, n_clusters=request.n_clusters
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return ApiResponse(data=result)
@router.post("/similar", response_model=ApiResponse)
@safe
async def portfolio_similar(request: SimilarRequest):
"""Find stocks most/least similar to a target by return correlation."""
try:
result = await portfolio_service.find_similar_stocks(
request.symbol,
request.universe,
days=request.days,
top_n=request.top_n,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
return ApiResponse(data=result)

View File

@@ -1,10 +1,11 @@
"""Routes for regulatory data (CFTC, SEC)."""
"""Routes for regulatory data (CFTC, SEC, Congress)."""
from fastapi import APIRouter, Path, Query
from models import ApiResponse
from route_utils import safe, validate_symbol
import regulators_service
import congress_service
router = APIRouter(prefix="/api/v1/regulators")
@@ -49,3 +50,22 @@ async def sec_cik_map(symbol: str = Path(..., min_length=1, max_length=20)):
symbol = validate_symbol(symbol)
data = await regulators_service.get_cik_map(symbol)
return ApiResponse(data=data)
# --- Congress Trading ---
@router.get("/congress/trades", response_model=ApiResponse)
@safe
async def congress_trades():
"""Recent US congress member stock trades."""
data = await congress_service.get_congress_trades()
return ApiResponse(data=data)
@router.get("/congress/bills", response_model=ApiResponse)
@safe
async def congress_bills(query: str = Query(..., min_length=1, max_length=200)):
"""Search US congress bills by keyword."""
data = await congress_service.search_congress_bills(query)
return ApiResponse(data=data)

View File

@@ -9,6 +9,7 @@ from route_utils import safe, validate_symbol
import alphavantage_service
import finnhub_service
import openbb_service
import reddit_service
import logging
@@ -23,22 +24,111 @@ router = APIRouter(prefix="/api/v1")
@router.get("/stock/{symbol}/sentiment", response_model=ApiResponse)
@safe
async def stock_sentiment(symbol: str = Path(..., min_length=1, max_length=20)):
"""Get aggregated sentiment: Alpha Vantage news sentiment + Finnhub analyst data."""
"""Aggregated sentiment from all sources with composite score.
Combines: Alpha Vantage news sentiment, Finnhub analyst data,
Reddit mentions, and analyst upgrades into a single composite score.
Score range: -1.0 (extreme bearish) to +1.0 (extreme bullish).
"""
symbol = validate_symbol(symbol)
finnhub_data, av_data = await asyncio.gather(
finnhub_service.get_sentiment_summary(symbol),
# Fetch all sources in parallel
av_data, finnhub_data, reddit_data, upgrades_data, recs_data = await asyncio.gather(
alphavantage_service.get_news_sentiment(symbol, limit=20),
finnhub_service.get_sentiment_summary(symbol),
reddit_service.get_reddit_sentiment(symbol),
openbb_service.get_upgrades_downgrades(symbol, limit=10),
finnhub_service.get_recommendation_trends(symbol),
return_exceptions=True,
)
if isinstance(finnhub_data, BaseException):
logger.exception("Finnhub error", exc_info=finnhub_data)
finnhub_data = {}
if isinstance(av_data, BaseException):
logger.exception("Alpha Vantage error", exc_info=av_data)
av_data = {}
data = {**finnhub_data, "alpha_vantage_sentiment": av_data}
return ApiResponse(data=data)
def _safe(result, default):
return default if isinstance(result, BaseException) else result
av_data = _safe(av_data, {})
finnhub_data = _safe(finnhub_data, {})
reddit_data = _safe(reddit_data, {})
upgrades_data = _safe(upgrades_data, [])
recs_data = _safe(recs_data, [])
# --- Score each source ---
scores: list[tuple[str, float, float]] = [] # (source, score, weight)
# 1. News sentiment (Alpha Vantage): avg_score ranges ~-0.35 to +0.35
if isinstance(av_data, dict) and av_data.get("overall_sentiment"):
av_score = av_data["overall_sentiment"].get("avg_score")
if av_score is not None:
# Normalize to -1..+1 (AV scores are typically -0.35 to +0.35)
normalized = max(-1.0, min(1.0, av_score * 2.5))
scores.append(("news", round(normalized, 3), 0.25))
# 2. Analyst recommendations (Finnhub): buy/sell ratio
if isinstance(recs_data, list) and recs_data:
latest = recs_data[0]
total = sum(latest.get(k, 0) for k in ("strongBuy", "buy", "hold", "sell", "strongSell"))
if total > 0:
bullish = latest.get("strongBuy", 0) + latest.get("buy", 0)
bearish = latest.get("sell", 0) + latest.get("strongSell", 0)
ratio = (bullish - bearish) / total # -1 to +1
scores.append(("analysts", round(ratio, 3), 0.30))
# 3. Analyst upgrades vs downgrades (yfinance)
if isinstance(upgrades_data, list) and upgrades_data:
ups = sum(1 for u in upgrades_data if u.get("action") in ("up", "init"))
downs = sum(1 for u in upgrades_data if u.get("action") == "down")
if len(upgrades_data) > 0:
upgrade_score = (ups - downs) / len(upgrades_data)
scores.append(("upgrades", round(upgrade_score, 3), 0.20))
# 4. Reddit buzz (ApeWisdom)
if isinstance(reddit_data, dict) and reddit_data.get("found"):
mentions = reddit_data.get("mentions_24h", 0)
change = reddit_data.get("mentions_change_pct")
if change is not None and mentions > 10:
# Positive change = bullish buzz, capped at +/- 1
reddit_score = max(-1.0, min(1.0, change / 100))
scores.append(("reddit", round(reddit_score, 3), 0.25))
# --- Compute weighted composite ---
if scores:
total_weight = sum(w for _, _, w in scores)
composite = sum(s * w for _, s, w in scores) / total_weight
composite = round(composite, 3)
else:
composite = None
# Map to label
if composite is None:
label = "Unknown"
elif composite >= 0.5:
label = "Strong Bullish"
elif composite >= 0.15:
label = "Bullish"
elif composite >= -0.15:
label = "Neutral"
elif composite >= -0.5:
label = "Bearish"
else:
label = "Strong Bearish"
return ApiResponse(data={
"symbol": symbol,
"composite_score": composite,
"composite_label": label,
"source_scores": {name: score for name, score, _ in scores},
"source_weights": {name: weight for name, _, weight in scores},
"details": {
"news_sentiment": av_data if isinstance(av_data, dict) else {},
"analyst_recommendations": recs_data[0] if isinstance(recs_data, list) and recs_data else {},
"recent_upgrades": upgrades_data[:5] if isinstance(upgrades_data, list) else [],
"reddit": reddit_data if isinstance(reddit_data, dict) else {},
"finnhub_news": (
finnhub_data.get("recent_news", [])[:5]
if isinstance(finnhub_data, dict)
else []
),
},
})
@router.get("/stock/{symbol}/news-sentiment", response_model=ApiResponse)
@@ -101,3 +191,33 @@ async def stock_upgrades(symbol: str = Path(..., min_length=1, max_length=20)):
symbol = validate_symbol(symbol)
data = await openbb_service.get_upgrades_downgrades(symbol)
return ApiResponse(data=data)
@router.get("/stock/{symbol}/social-sentiment", response_model=ApiResponse)
@safe
async def stock_social_sentiment(
symbol: str = Path(..., min_length=1, max_length=20),
):
"""Social media sentiment from Reddit and Twitter (Finnhub)."""
symbol = validate_symbol(symbol)
data = await finnhub_service.get_social_sentiment(symbol)
return ApiResponse(data=data)
@router.get("/stock/{symbol}/reddit-sentiment", response_model=ApiResponse)
@safe
async def stock_reddit_sentiment(
symbol: str = Path(..., min_length=1, max_length=20),
):
"""Reddit sentiment: mentions, upvotes, rank on WSB/stocks/investing (free, no key)."""
symbol = validate_symbol(symbol)
data = await reddit_service.get_reddit_sentiment(symbol)
return ApiResponse(data=data)
@router.get("/discover/reddit-trending", response_model=ApiResponse)
@safe
async def reddit_trending():
"""Top 25 trending stocks on Reddit (WSB, r/stocks, r/investing). Free, no key."""
data = await reddit_service.get_reddit_trending()
return ApiResponse(data=data)

View File

@@ -0,0 +1,442 @@
"""Unit tests for akshare_service.py - written FIRST (TDD RED phase)."""
from unittest.mock import patch
import pandas as pd
import pytest
import akshare_service
# --- Helpers ---
def _make_hist_df(rows: int = 3) -> pd.DataFrame:
"""Return a minimal historical DataFrame with Chinese column names."""
dates = pd.date_range("2026-01-01", periods=rows, freq="D")
return pd.DataFrame(
{
"日期": dates,
"开盘": [10.0] * rows,
"收盘": [10.5] * rows,
"最高": [11.0] * rows,
"最低": [9.5] * rows,
"成交量": [1_000_000] * rows,
"成交额": [10_500_000.0] * rows,
"振幅": [1.5] * rows,
"涨跌幅": [0.5] * rows,
"涨跌额": [0.05] * rows,
"换手率": [0.3] * rows,
}
)
def _make_spot_df(code: str = "000001", name: str = "平安银行") -> pd.DataFrame:
"""Return a minimal real-time quote DataFrame with Chinese column names."""
return pd.DataFrame(
[
{
"代码": code,
"名称": name,
"最新价": 12.34,
"涨跌幅": 1.23,
"涨跌额": 0.15,
"成交量": 500_000,
"成交额": 6_170_000.0,
"今开": 12.10,
"最高": 12.50,
"最低": 12.00,
"昨收": 12.19,
}
]
)
def _make_code_name_df() -> pd.DataFrame:
"""Return a minimal code/name mapping DataFrame."""
return pd.DataFrame(
[
{"code": "000001", "name": "平安银行"},
{"code": "600519", "name": "贵州茅台"},
{"code": "000002", "name": "万科A"},
]
)
# ============================================================
# Symbol validation
# ============================================================
class TestValidateAShare:
def test_valid_starts_with_0(self):
assert akshare_service.validate_a_share_symbol("000001") is True
def test_valid_starts_with_3(self):
assert akshare_service.validate_a_share_symbol("300001") is True
def test_valid_starts_with_6(self):
assert akshare_service.validate_a_share_symbol("600519") is True
def test_invalid_starts_with_1(self):
assert akshare_service.validate_a_share_symbol("100001") is False
def test_invalid_too_short(self):
assert akshare_service.validate_a_share_symbol("00001") is False
def test_invalid_too_long(self):
assert akshare_service.validate_a_share_symbol("0000011") is False
def test_invalid_letters(self):
assert akshare_service.validate_a_share_symbol("00000A") is False
def test_invalid_empty(self):
assert akshare_service.validate_a_share_symbol("") is False
class TestValidateHKSymbol:
def test_valid_five_digits(self):
assert akshare_service.validate_hk_symbol("00700") is True
def test_valid_all_nines(self):
assert akshare_service.validate_hk_symbol("99999") is True
def test_invalid_too_short(self):
assert akshare_service.validate_hk_symbol("0070") is False
def test_invalid_too_long(self):
assert akshare_service.validate_hk_symbol("007000") is False
def test_invalid_letters(self):
assert akshare_service.validate_hk_symbol("0070A") is False
def test_invalid_empty(self):
assert akshare_service.validate_hk_symbol("") is False
# ============================================================
# _parse_hist_df
# ============================================================
class TestParseHistDf:
def test_returns_list_of_dicts(self):
df = _make_hist_df(2)
result = akshare_service._parse_hist_df(df)
assert isinstance(result, list)
assert len(result) == 2
def test_keys_are_english(self):
df = _make_hist_df(1)
result = akshare_service._parse_hist_df(df)
row = result[0]
assert "date" in row
assert "open" in row
assert "close" in row
assert "high" in row
assert "low" in row
assert "volume" in row
assert "turnover" in row
assert "change_percent" in row
assert "turnover_rate" in row
def test_no_chinese_keys_remain(self):
df = _make_hist_df(1)
result = akshare_service._parse_hist_df(df)
row = result[0]
for key in row:
assert not any(ord(c) > 127 for c in key), f"Non-ASCII key found: {key}"
def test_date_is_string(self):
df = _make_hist_df(1)
result = akshare_service._parse_hist_df(df)
assert isinstance(result[0]["date"], str)
def test_values_are_correct(self):
df = _make_hist_df(1)
result = akshare_service._parse_hist_df(df)
assert result[0]["open"] == pytest.approx(10.0)
assert result[0]["close"] == pytest.approx(10.5)
def test_empty_df_returns_empty_list(self):
df = pd.DataFrame()
result = akshare_service._parse_hist_df(df)
assert result == []
# ============================================================
# _parse_spot_row
# ============================================================
class TestParseSpotRow:
def test_returns_dict_with_english_keys(self):
df = _make_spot_df("000001", "平安银行")
result = akshare_service._parse_spot_row(df, "000001")
assert result is not None
assert "symbol" in result
assert "name" in result
assert "price" in result
def test_correct_symbol_extracted(self):
df = _make_spot_df("000001")
result = akshare_service._parse_spot_row(df, "000001")
assert result["symbol"] == "000001"
def test_returns_none_when_symbol_not_found(self):
df = _make_spot_df("000001")
result = akshare_service._parse_spot_row(df, "999999")
assert result is None
def test_price_value_correct(self):
df = _make_spot_df("600519")
df["代码"] = "600519"
result = akshare_service._parse_spot_row(df, "600519")
assert result["price"] == pytest.approx(12.34)
def test_all_quote_fields_present(self):
df = _make_spot_df("000001")
result = akshare_service._parse_spot_row(df, "000001")
expected_keys = {
"symbol", "name", "price", "change", "change_percent",
"volume", "turnover", "open", "high", "low", "prev_close",
}
assert expected_keys.issubset(set(result.keys()))
# ============================================================
# get_a_share_quote
# ============================================================
class TestGetAShareQuote:
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_spot_em")
async def test_returns_quote_dict(self, mock_spot):
mock_spot.return_value = _make_spot_df("000001")
result = await akshare_service.get_a_share_quote("000001")
assert result is not None
assert result["symbol"] == "000001"
assert result["name"] == "平安银行"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_spot_em")
async def test_returns_none_for_unknown_symbol(self, mock_spot):
mock_spot.return_value = _make_spot_df("000001")
result = await akshare_service.get_a_share_quote("999999")
assert result is None
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_spot_em")
async def test_propagates_exception(self, mock_spot):
mock_spot.side_effect = RuntimeError("AKShare unavailable")
with pytest.raises(RuntimeError):
await akshare_service.get_a_share_quote("000001")
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_spot_em")
async def test_akshare_called_once(self, mock_spot):
mock_spot.return_value = _make_spot_df("000001")
await akshare_service.get_a_share_quote("000001")
mock_spot.assert_called_once()
# ============================================================
# get_a_share_historical
# ============================================================
class TestGetAShareHistorical:
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_returns_list_of_bars(self, mock_hist):
mock_hist.return_value = _make_hist_df(5)
result = await akshare_service.get_a_share_historical("000001", days=30)
assert isinstance(result, list)
assert len(result) == 5
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_bars_have_english_keys(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
result = await akshare_service.get_a_share_historical("000001", days=30)
assert "date" in result[0]
assert "open" in result[0]
assert "close" in result[0]
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_called_with_correct_symbol(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
await akshare_service.get_a_share_historical("600519", days=90)
call_kwargs = mock_hist.call_args
assert call_kwargs.kwargs.get("symbol") == "600519" or call_kwargs.args[0] == "600519"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_adjust_is_qfq(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
await akshare_service.get_a_share_historical("000001", days=30)
call_kwargs = mock_hist.call_args
assert call_kwargs.kwargs.get("adjust") == "qfq"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_empty_df_returns_empty_list(self, mock_hist):
mock_hist.return_value = pd.DataFrame()
result = await akshare_service.get_a_share_historical("000001", days=30)
assert result == []
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_zh_a_hist")
async def test_propagates_exception(self, mock_hist):
mock_hist.side_effect = RuntimeError("network error")
with pytest.raises(RuntimeError):
await akshare_service.get_a_share_historical("000001", days=30)
# ============================================================
# search_a_shares
# ============================================================
class TestSearchAShares:
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_returns_matching_results(self, mock_codes):
mock_codes.return_value = _make_code_name_df()
result = await akshare_service.search_a_shares("平安")
assert len(result) == 1
assert result[0]["code"] == "000001"
assert result[0]["name"] == "平安银行"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_returns_empty_list_when_no_match(self, mock_codes):
mock_codes.return_value = _make_code_name_df()
result = await akshare_service.search_a_shares("NONEXISTENT")
assert result == []
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_returns_multiple_matches(self, mock_codes):
mock_codes.return_value = _make_code_name_df()
result = await akshare_service.search_a_shares("")
assert len(result) == 1
assert result[0]["code"] == "000002"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_result_has_code_and_name_keys(self, mock_codes):
mock_codes.return_value = _make_code_name_df()
result = await akshare_service.search_a_shares("茅台")
assert len(result) == 1
assert "code" in result[0]
assert "name" in result[0]
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_propagates_exception(self, mock_codes):
mock_codes.side_effect = RuntimeError("timeout")
with pytest.raises(RuntimeError):
await akshare_service.search_a_shares("平安")
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_info_a_code_name")
async def test_empty_query_returns_all(self, mock_codes):
mock_codes.return_value = _make_code_name_df()
result = await akshare_service.search_a_shares("")
assert len(result) == 3
# ============================================================
# get_hk_quote
# ============================================================
class TestGetHKQuote:
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_spot_em")
async def test_returns_quote_dict(self, mock_spot):
mock_spot.return_value = _make_spot_df("00700", "腾讯控股")
result = await akshare_service.get_hk_quote("00700")
assert result is not None
assert result["symbol"] == "00700"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_spot_em")
async def test_returns_none_for_unknown_symbol(self, mock_spot):
mock_spot.return_value = _make_spot_df("00700")
result = await akshare_service.get_hk_quote("99999")
assert result is None
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_spot_em")
async def test_propagates_exception(self, mock_spot):
mock_spot.side_effect = RuntimeError("AKShare unavailable")
with pytest.raises(RuntimeError):
await akshare_service.get_hk_quote("00700")
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_spot_em")
async def test_all_fields_present(self, mock_spot):
mock_spot.return_value = _make_spot_df("00700", "腾讯控股")
result = await akshare_service.get_hk_quote("00700")
expected_keys = {
"symbol", "name", "price", "change", "change_percent",
"volume", "turnover", "open", "high", "low", "prev_close",
}
assert expected_keys.issubset(set(result.keys()))
# ============================================================
# get_hk_historical
# ============================================================
class TestGetHKHistorical:
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_returns_list_of_bars(self, mock_hist):
mock_hist.return_value = _make_hist_df(4)
result = await akshare_service.get_hk_historical("00700", days=30)
assert isinstance(result, list)
assert len(result) == 4
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_bars_have_english_keys(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
result = await akshare_service.get_hk_historical("00700", days=30)
assert "date" in result[0]
assert "close" in result[0]
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_called_with_correct_symbol(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
await akshare_service.get_hk_historical("09988", days=90)
call_kwargs = mock_hist.call_args
assert call_kwargs.kwargs.get("symbol") == "09988" or call_kwargs.args[0] == "09988"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_adjust_is_qfq(self, mock_hist):
mock_hist.return_value = _make_hist_df(1)
await akshare_service.get_hk_historical("00700", days=30)
call_kwargs = mock_hist.call_args
assert call_kwargs.kwargs.get("adjust") == "qfq"
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_empty_df_returns_empty_list(self, mock_hist):
mock_hist.return_value = pd.DataFrame()
result = await akshare_service.get_hk_historical("00700", days=30)
assert result == []
@pytest.mark.asyncio
@patch("akshare_service.ak.stock_hk_hist")
async def test_propagates_exception(self, mock_hist):
mock_hist.side_effect = RuntimeError("network error")
with pytest.raises(RuntimeError):
await akshare_service.get_hk_historical("00700", days=30)

View File

@@ -0,0 +1,627 @@
"""Unit tests for backtest_service - written FIRST (TDD RED phase)."""
import numpy as np
import pandas as pd
import pytest
import backtest_service
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_equity(values: list[float]) -> pd.Series:
"""Build a simple equity-curve Series from a list of values."""
return pd.Series(values, dtype=float)
def _rising_prices(n: int = 100, start: float = 100.0, step: float = 1.0) -> pd.Series:
"""Linearly rising price series."""
return pd.Series([start + i * step for i in range(n)], dtype=float)
def _flat_prices(n: int = 100, price: float = 100.0) -> pd.Series:
"""Flat price series - no movement."""
return pd.Series([price] * n, dtype=float)
def _oscillating_prices(n: int = 200, period: int = 40) -> pd.Series:
"""Sinusoidal price series to generate crossover signals."""
t = np.arange(n)
prices = 100 + 20 * np.sin(2 * np.pi * t / period)
return pd.Series(prices, dtype=float)
# ---------------------------------------------------------------------------
# _compute_metrics tests
# ---------------------------------------------------------------------------
class TestComputeMetrics:
def test_total_return_positive(self):
equity = _make_equity([10000, 11000, 12000])
result = backtest_service._compute_metrics(equity, trades=1)
assert result["total_return"] == pytest.approx(0.2, abs=1e-6)
def test_total_return_negative(self):
equity = _make_equity([10000, 9000, 8000])
result = backtest_service._compute_metrics(equity, trades=1)
assert result["total_return"] == pytest.approx(-0.2, abs=1e-6)
def test_total_return_zero_on_flat(self):
equity = _make_equity([10000, 10000, 10000])
result = backtest_service._compute_metrics(equity, trades=0)
assert result["total_return"] == pytest.approx(0.0, abs=1e-6)
def test_annualized_return_shape(self):
# 252 daily bars => 1 trading year; 10000 -> 11000 = +10% annualized
values = [10000 * (1.0 + 0.1 / 252) ** i for i in range(253)]
equity = _make_equity(values)
result = backtest_service._compute_metrics(equity, trades=5)
# Should be close to 10% annualized
assert result["annualized_return"] == pytest.approx(0.1, abs=0.01)
def test_sharpe_ratio_positive_drift(self):
# Steadily rising equity with small daily increments -> positive Sharpe
values = [10000 + i * 10 for i in range(252)]
equity = _make_equity(values)
result = backtest_service._compute_metrics(equity, trades=5)
assert result["sharpe_ratio"] > 0
def test_sharpe_ratio_none_on_single_point(self):
equity = _make_equity([10000])
result = backtest_service._compute_metrics(equity, trades=0)
assert result["sharpe_ratio"] is None
def test_sharpe_ratio_none_on_zero_std(self):
# Perfectly flat equity => std = 0, Sharpe undefined
equity = _make_equity([10000] * 50)
result = backtest_service._compute_metrics(equity, trades=0)
assert result["sharpe_ratio"] is None
def test_max_drawdown_known_value(self):
# Peak 12000, trough 8000 => drawdown = (8000-12000)/12000 = -1/3
equity = _make_equity([10000, 12000, 8000, 9000])
result = backtest_service._compute_metrics(equity, trades=2)
assert result["max_drawdown"] == pytest.approx(-1 / 3, abs=1e-6)
def test_max_drawdown_zero_on_monotone_rise(self):
equity = _make_equity([10000, 11000, 12000, 13000])
result = backtest_service._compute_metrics(equity, trades=1)
assert result["max_drawdown"] == pytest.approx(0.0, abs=1e-6)
def test_total_trades_propagated(self):
equity = _make_equity([10000, 11000])
result = backtest_service._compute_metrics(equity, trades=7)
assert result["total_trades"] == 7
def test_win_rate_zero_trades(self):
equity = _make_equity([10000, 10000])
result = backtest_service._compute_metrics(equity, trades=0)
assert result["win_rate"] is None
def test_equity_curve_last_20_points(self):
values = list(range(100, 160)) # 60 points
equity = _make_equity(values)
result = backtest_service._compute_metrics(equity, trades=10)
assert len(result["equity_curve"]) == 20
assert result["equity_curve"][-1] == pytest.approx(159.0, abs=1e-6)
def test_equity_curve_shorter_than_20(self):
values = [10000, 11000, 12000]
equity = _make_equity(values)
result = backtest_service._compute_metrics(equity, trades=1)
assert len(result["equity_curve"]) == 3
def test_result_keys_present(self):
equity = _make_equity([10000, 11000])
result = backtest_service._compute_metrics(equity, trades=1)
expected_keys = {
"total_return",
"annualized_return",
"sharpe_ratio",
"max_drawdown",
"win_rate",
"total_trades",
"equity_curve",
}
assert expected_keys.issubset(result.keys())
# ---------------------------------------------------------------------------
# _compute_sma_signals tests
# ---------------------------------------------------------------------------
class TestComputeSmaSignals:
def test_returns_series_with_position_column(self):
prices = _oscillating_prices(200, period=40)
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
assert isinstance(positions, pd.Series)
assert len(positions) == len(prices)
def test_positions_are_zero_or_one(self):
prices = _oscillating_prices(200, period=40)
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
unique_vals = set(positions.dropna().unique())
assert unique_vals.issubset({0, 1})
def test_no_position_before_long_window(self):
prices = _oscillating_prices(200, period=40)
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
# Before long_window-1 data points, positions should be 0
assert (positions.iloc[: 19] == 0).all()
def test_generates_at_least_one_signal_on_oscillating(self):
prices = _oscillating_prices(300, period=60)
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
# Should flip between 0 and 1 at least once on oscillating data
changes = positions.diff().abs().sum()
assert changes > 0
def test_flat_prices_produce_no_signals(self):
prices = _flat_prices(100)
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
# After warm-up both SMAs equal price; short never strictly above long
assert (positions == 0).all()
# ---------------------------------------------------------------------------
# _compute_rsi tests
# ---------------------------------------------------------------------------
class TestComputeRsi:
def test_rsi_length(self):
prices = _rising_prices(50)
rsi = backtest_service._compute_rsi(prices, period=14)
assert len(rsi) == len(prices)
def test_rsi_range(self):
prices = _oscillating_prices(100, period=20)
rsi = backtest_service._compute_rsi(prices, period=14)
valid = rsi.dropna()
assert (valid >= 0).all()
assert (valid <= 100).all()
def test_rsi_rising_prices_high(self):
# Monotonically rising prices => RSI should be high (>= 70)
prices = _rising_prices(80, step=1.0)
rsi = backtest_service._compute_rsi(prices, period=14)
# After warm-up period, RSI should be very high
assert rsi.iloc[-1] >= 70
def test_rsi_falling_prices_low(self):
# Monotonically falling prices => RSI should be low (<= 30)
prices = pd.Series([100 - i * 0.8 for i in range(80)], dtype=float)
rsi = backtest_service._compute_rsi(prices, period=14)
assert rsi.iloc[-1] <= 30
# ---------------------------------------------------------------------------
# _compute_rsi_signals tests
# ---------------------------------------------------------------------------
class TestComputeRsiSignals:
def test_returns_series(self):
prices = _oscillating_prices(200, period=40)
positions = backtest_service._compute_rsi_signals(
prices, period=14, oversold=30, overbought=70
)
assert isinstance(positions, pd.Series)
assert len(positions) == len(prices)
def test_positions_are_zero_or_one(self):
prices = _oscillating_prices(200, period=40)
positions = backtest_service._compute_rsi_signals(
prices, period=14, oversold=30, overbought=70
)
unique_vals = set(positions.dropna().unique())
assert unique_vals.issubset({0, 1})
# ---------------------------------------------------------------------------
# backtest_sma_crossover tests (async integration of service layer)
# ---------------------------------------------------------------------------
class TestBacktestSmaCrossover:
@pytest.fixture
def mock_hist(self, monkeypatch):
"""Patch fetch_historical to return a synthetic OBBject-like result."""
prices = _oscillating_prices(300, period=60).tolist()
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
return FakeResult()
@pytest.mark.asyncio
async def test_returns_all_required_keys(self, mock_hist):
result = await backtest_service.backtest_sma_crossover(
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
)
required = {
"total_return",
"annualized_return",
"sharpe_ratio",
"max_drawdown",
"win_rate",
"total_trades",
"equity_curve",
}
assert required.issubset(result.keys())
@pytest.mark.asyncio
async def test_equity_curve_max_20_points(self, mock_hist):
result = await backtest_service.backtest_sma_crossover(
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
)
assert len(result["equity_curve"]) <= 20
@pytest.mark.asyncio
async def test_raises_value_error_on_no_data(self, monkeypatch):
async def fake_fetch(symbol, days, **kwargs):
return None
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="No historical data"):
await backtest_service.backtest_sma_crossover(
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
)
@pytest.mark.asyncio
async def test_initial_capital_reflected_in_equity(self, mock_hist):
result = await backtest_service.backtest_sma_crossover(
"AAPL", short_window=5, long_window=20, days=365, initial_capital=50000
)
# equity_curve values should be in range related to 50000 initial capital
assert result["equity_curve"][0] > 0
# ---------------------------------------------------------------------------
# backtest_rsi tests
# ---------------------------------------------------------------------------
class TestBacktestRsi:
@pytest.fixture
def mock_hist(self, monkeypatch):
prices = _oscillating_prices(300, period=60).tolist()
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
@pytest.mark.asyncio
async def test_returns_all_required_keys(self, mock_hist):
result = await backtest_service.backtest_rsi(
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
)
required = {
"total_return",
"annualized_return",
"sharpe_ratio",
"max_drawdown",
"win_rate",
"total_trades",
"equity_curve",
}
assert required.issubset(result.keys())
@pytest.mark.asyncio
async def test_equity_curve_max_20_points(self, mock_hist):
result = await backtest_service.backtest_rsi(
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
)
assert len(result["equity_curve"]) <= 20
@pytest.mark.asyncio
async def test_raises_value_error_on_no_data(self, monkeypatch):
async def fake_fetch(symbol, days, **kwargs):
return None
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="No historical data"):
await backtest_service.backtest_rsi(
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
)
# ---------------------------------------------------------------------------
# backtest_buy_and_hold tests
# ---------------------------------------------------------------------------
class TestBacktestBuyAndHold:
@pytest.fixture
def mock_hist_rising(self, monkeypatch):
prices = _rising_prices(252, start=100.0, step=1.0).tolist()
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
@pytest.mark.asyncio
async def test_returns_all_required_keys(self, mock_hist_rising):
result = await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)
required = {
"total_return",
"annualized_return",
"sharpe_ratio",
"max_drawdown",
"win_rate",
"total_trades",
"equity_curve",
}
assert required.issubset(result.keys())
@pytest.mark.asyncio
async def test_total_trades_always_one(self, mock_hist_rising):
result = await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)
assert result["total_trades"] == 1
@pytest.mark.asyncio
async def test_rising_prices_positive_return(self, mock_hist_rising):
result = await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)
assert result["total_return"] > 0
@pytest.mark.asyncio
async def test_known_return_value(self, monkeypatch):
# 100 -> 200: 100% total return
prices = [100.0, 200.0]
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
result = await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)
assert result["total_return"] == pytest.approx(1.0, abs=1e-6)
assert result["equity_curve"][-1] == pytest.approx(20000.0, abs=1e-6)
@pytest.mark.asyncio
async def test_raises_value_error_on_no_data(self, monkeypatch):
async def fake_fetch(symbol, days, **kwargs):
return None
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="No historical data"):
await backtest_service.backtest_buy_and_hold("AAPL", days=365, initial_capital=10000)
@pytest.mark.asyncio
async def test_flat_prices_zero_return(self, monkeypatch):
prices = _flat_prices(50).tolist()
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
result = await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)
assert result["total_return"] == pytest.approx(0.0, abs=1e-6)
# ---------------------------------------------------------------------------
# backtest_momentum tests
# ---------------------------------------------------------------------------
class TestBacktestMomentum:
@pytest.fixture
def mock_multi_hist(self, monkeypatch):
"""Three symbols with different return profiles."""
aapl_prices = _rising_prices(200, start=100.0, step=2.0).tolist()
msft_prices = _rising_prices(200, start=100.0, step=0.5).tolist()
googl_prices = _flat_prices(200, price=150.0).tolist()
price_map = {
"AAPL": aapl_prices,
"MSFT": msft_prices,
"GOOGL": googl_prices,
}
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
def __init__(self, prices):
self.results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult(price_map[symbol])
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
@pytest.mark.asyncio
async def test_returns_all_required_keys(self, mock_multi_hist):
result = await backtest_service.backtest_momentum(
symbols=["AAPL", "MSFT", "GOOGL"],
lookback=20,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000,
)
required = {
"total_return",
"annualized_return",
"sharpe_ratio",
"max_drawdown",
"win_rate",
"total_trades",
"equity_curve",
"allocation_history",
}
assert required.issubset(result.keys())
@pytest.mark.asyncio
async def test_allocation_history_is_list(self, mock_multi_hist):
result = await backtest_service.backtest_momentum(
symbols=["AAPL", "MSFT", "GOOGL"],
lookback=20,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000,
)
assert isinstance(result["allocation_history"], list)
@pytest.mark.asyncio
async def test_top_n_respected_in_allocations(self, mock_multi_hist):
result = await backtest_service.backtest_momentum(
symbols=["AAPL", "MSFT", "GOOGL"],
lookback=20,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000,
)
for entry in result["allocation_history"]:
assert len(entry["symbols"]) <= 2
@pytest.mark.asyncio
async def test_raises_value_error_on_no_data(self, monkeypatch):
async def fake_fetch(symbol, days, **kwargs):
return None
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="No price data"):
await backtest_service.backtest_momentum(
symbols=["AAPL", "MSFT"],
lookback=20,
top_n=1,
rebalance_days=30,
days=365,
initial_capital=10000,
)
@pytest.mark.asyncio
async def test_equity_curve_max_20_points(self, mock_multi_hist):
result = await backtest_service.backtest_momentum(
symbols=["AAPL", "MSFT", "GOOGL"],
lookback=20,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000,
)
assert len(result["equity_curve"]) <= 20
# ---------------------------------------------------------------------------
# Edge case: insufficient data
# ---------------------------------------------------------------------------
class TestEdgeCases:
@pytest.mark.asyncio
async def test_sma_crossover_insufficient_bars_raises(self, monkeypatch):
"""Fewer bars than long_window should raise ValueError."""
prices = [100.0, 101.0, 102.0] # Only 3 bars
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="Insufficient data"):
await backtest_service.backtest_sma_crossover(
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
)
@pytest.mark.asyncio
async def test_rsi_insufficient_bars_raises(self, monkeypatch):
prices = [100.0, 101.0]
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="Insufficient data"):
await backtest_service.backtest_rsi(
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
)
@pytest.mark.asyncio
async def test_buy_and_hold_single_bar_raises(self, monkeypatch):
prices = [100.0]
class FakeBar:
def __init__(self, close):
self.close = close
class FakeResult:
results = [FakeBar(p) for p in prices]
async def fake_fetch(symbol, days, **kwargs):
return FakeResult()
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
with pytest.raises(ValueError, match="Insufficient data"):
await backtest_service.backtest_buy_and_hold(
"AAPL", days=365, initial_capital=10000
)

View File

@@ -0,0 +1,187 @@
"""Tests for congress trading service (TDD - RED phase first)."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# --- get_congress_trades ---
@pytest.mark.asyncio
async def test_get_congress_trades_happy_path():
"""Returns list of trade dicts when OBB call succeeds."""
expected = [
{
"representative": "Nancy Pelosi",
"ticker": "NVDA",
"transaction_date": "2024-01-15",
"transaction_type": "Purchase",
"amount": "$1,000,001-$5,000,000",
}
]
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=expected):
result = await congress_service.get_congress_trades()
assert isinstance(result, list)
assert len(result) == 1
assert result[0]["representative"] == "Nancy Pelosi"
@pytest.mark.asyncio
async def test_get_congress_trades_returns_empty_when_fn_not_available():
"""Returns empty list when OBB congress function is not available."""
import congress_service
with patch.object(congress_service, "_get_congress_fn", return_value=None):
result = await congress_service.get_congress_trades()
assert result == []
@pytest.mark.asyncio
async def test_get_congress_trades_returns_empty_on_all_provider_failures():
"""Returns empty list when all providers fail (_try_obb_call returns None)."""
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=None):
result = await congress_service.get_congress_trades()
assert result == []
@pytest.mark.asyncio
async def test_get_congress_trades_empty_list_result():
"""Returns empty list when _try_obb_call returns empty list."""
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=[]):
result = await congress_service.get_congress_trades()
assert result == []
# --- _get_congress_fn ---
def test_get_congress_fn_returns_none_when_attribute_missing():
"""Returns None gracefully when obb.regulators.government_us is absent."""
import congress_service
mock_obb = MagicMock(spec=[]) # spec with no attributes
with patch.object(congress_service, "obb", mock_obb):
result = congress_service._get_congress_fn()
assert result is None
def test_get_congress_fn_returns_callable_when_available():
"""Returns the congress_trading callable when attribute exists."""
import congress_service
mock_fn = MagicMock()
mock_obb = MagicMock()
mock_obb.regulators.government_us.congress_trading = mock_fn
with patch.object(congress_service, "obb", mock_obb):
result = congress_service._get_congress_fn()
assert result is mock_fn
# --- _try_obb_call ---
@pytest.mark.asyncio
async def test_try_obb_call_returns_list_on_success():
"""_try_obb_call converts OBBject result to list via to_list."""
import congress_service
mock_result = MagicMock()
expected = [{"ticker": "AAPL"}]
with patch.object(congress_service, "to_list", return_value=expected), \
patch("congress_service.asyncio.to_thread", new_callable=AsyncMock, return_value=mock_result):
result = await congress_service._try_obb_call(MagicMock())
assert result == expected
@pytest.mark.asyncio
async def test_try_obb_call_returns_none_on_exception():
"""_try_obb_call returns None when asyncio.to_thread raises."""
import congress_service
with patch("congress_service.asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("fail")):
result = await congress_service._try_obb_call(MagicMock())
assert result is None
# --- search_congress_bills ---
@pytest.mark.asyncio
async def test_search_congress_bills_happy_path():
"""Returns list of bill dicts when OBB call succeeds."""
expected = [
{"title": "Infrastructure Investment and Jobs Act", "bill_id": "HR3684"},
{"title": "Inflation Reduction Act", "bill_id": "HR5376"},
]
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=expected):
result = await congress_service.search_congress_bills("infrastructure")
assert isinstance(result, list)
assert len(result) == 2
assert result[0]["bill_id"] == "HR3684"
@pytest.mark.asyncio
async def test_search_congress_bills_returns_empty_when_fn_not_available():
"""Returns empty list when OBB function is not available."""
import congress_service
with patch.object(congress_service, "_get_congress_fn", return_value=None):
result = await congress_service.search_congress_bills("taxes")
assert result == []
@pytest.mark.asyncio
async def test_search_congress_bills_returns_empty_on_failure():
"""Returns empty list when all providers fail."""
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=None):
result = await congress_service.search_congress_bills("taxes")
assert result == []
@pytest.mark.asyncio
async def test_search_congress_bills_empty_results():
"""Returns empty list when _try_obb_call returns empty list."""
import congress_service
mock_fn = MagicMock()
with patch.object(congress_service, "_get_congress_fn", return_value=mock_fn), \
patch.object(congress_service, "_try_obb_call", new_callable=AsyncMock, return_value=[]):
result = await congress_service.search_congress_bills("nonexistent")
assert result == []

655
tests/test_defi_service.py Normal file
View File

@@ -0,0 +1,655 @@
"""Tests for defi_service.py - DefiLlama API integration.
TDD: these tests are written before implementation.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from defi_service import (
get_chain_tvls,
get_dex_volumes,
get_protocol_fees,
get_protocol_tvl,
get_stablecoins,
get_top_protocols,
get_yield_pools,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SAMPLE_PROTOCOLS = [
{
"name": "Aave",
"symbol": "AAVE",
"tvl": 10_000_000_000.0,
"chain": "Ethereum",
"chains": ["Ethereum", "Polygon"],
"category": "Lending",
"change_1d": 0.5,
"change_7d": -1.2,
},
{
"name": "Uniswap",
"symbol": "UNI",
"tvl": 8_000_000_000.0,
"chain": "Ethereum",
"chains": ["Ethereum"],
"category": "DEX",
"change_1d": 1.0,
"change_7d": 2.0,
},
]
SAMPLE_CHAINS = [
{"name": "Ethereum", "tvl": 50_000_000_000.0, "tokenSymbol": "ETH"},
{"name": "BSC", "tvl": 5_000_000_000.0, "tokenSymbol": "BNB"},
]
SAMPLE_POOLS = [
{
"pool": "0xabcd",
"chain": "Ethereum",
"project": "aave-v3",
"symbol": "USDC",
"tvlUsd": 1_000_000_000.0,
"apy": 3.5,
"apyBase": 3.0,
"apyReward": 0.5,
},
{
"pool": "0x1234",
"chain": "Polygon",
"project": "curve",
"symbol": "DAI",
"tvlUsd": 500_000_000.0,
"apy": 4.2,
"apyBase": 2.5,
"apyReward": 1.7,
},
]
SAMPLE_STABLECOINS_RESPONSE = {
"peggedAssets": [
{
"name": "Tether",
"symbol": "USDT",
"pegType": "peggedUSD",
"circulating": {"peggedUSD": 100_000_000_000.0},
"price": 1.0,
},
{
"name": "USD Coin",
"symbol": "USDC",
"pegType": "peggedUSD",
"circulating": {"peggedUSD": 40_000_000_000.0},
"price": 1.0,
},
]
}
SAMPLE_DEX_RESPONSE = {
"total24h": 5_000_000_000.0,
"total7d": 30_000_000_000.0,
"protocols": [
{"name": "Uniswap", "total24h": 2_000_000_000.0},
{"name": "Curve", "total24h": 500_000_000.0},
],
}
SAMPLE_FEES_RESPONSE = {
"protocols": [
{"name": "Uniswap", "total24h": 1_000_000.0, "revenue24h": 500_000.0},
{"name": "Aave", "total24h": 800_000.0, "revenue24h": 800_000.0},
],
}
def _make_mock_client(json_return):
"""Build a fully configured AsyncMock for httpx.AsyncClient context manager.
Uses MagicMock for the response object so that resp.json() (a sync method
in httpx) returns the value directly rather than a coroutine.
"""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = json_return
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
return mock_client
# ---------------------------------------------------------------------------
# get_top_protocols
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_top_20(mock_client_cls):
raw = [
{
"name": f"Protocol{i}",
"symbol": f"P{i}",
"tvl": float(100 - i),
"chain": "Ethereum",
"chains": ["Ethereum"],
"category": "Lending",
"change_1d": 0.1,
"change_7d": 0.2,
}
for i in range(30)
]
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_top_protocols()
assert len(result) == 20
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_correct_fields(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(SAMPLE_PROTOCOLS)
result = await get_top_protocols()
assert len(result) == 2
first = result[0]
assert first["name"] == "Aave"
assert first["symbol"] == "AAVE"
assert first["tvl"] == 10_000_000_000.0
assert first["chain"] == "Ethereum"
assert first["chains"] == ["Ethereum", "Polygon"]
assert first["category"] == "Lending"
assert first["change_1d"] == 0.5
assert first["change_7d"] == -1.2
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_respects_custom_limit(mock_client_cls):
raw = [
{"name": f"P{i}", "symbol": f"S{i}", "tvl": float(i), "chain": "ETH",
"chains": [], "category": "DEX", "change_1d": 0.0, "change_7d": 0.0}
for i in range(25)
]
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_top_protocols(limit=5)
assert len(result) == 5
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_handles_missing_fields(mock_client_cls):
raw = [{"name": "Sparse"}] # missing all optional fields
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_top_protocols()
assert len(result) == 1
assert result[0]["name"] == "Sparse"
assert result[0]["tvl"] is None
assert result[0]["symbol"] is None
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_returns_empty_on_http_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("HTTP 500")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_top_protocols()
assert result == []
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_top_protocols_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client([])
mock_client_cls.return_value = mock_client
await get_top_protocols()
mock_client.get.assert_called_once_with("https://api.llama.fi/protocols")
# ---------------------------------------------------------------------------
# get_chain_tvls
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_returns_all_chains(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(SAMPLE_CHAINS)
result = await get_chain_tvls()
assert len(result) == 2
assert result[0]["name"] == "Ethereum"
assert result[0]["tvl"] == 50_000_000_000.0
assert result[0]["tokenSymbol"] == "ETH"
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_handles_missing_token_symbol(mock_client_cls):
raw = [{"name": "SomeChain", "tvl": 100.0}]
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_chain_tvls()
assert result[0]["tokenSymbol"] is None
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_returns_empty_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("timeout")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_chain_tvls()
assert result == []
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_chain_tvls_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client([])
mock_client_cls.return_value = mock_client
await get_chain_tvls()
mock_client.get.assert_called_once_with("https://api.llama.fi/v2/chains")
# ---------------------------------------------------------------------------
# get_protocol_tvl
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_returns_numeric_value(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(10_000_000_000.0)
result = await get_protocol_tvl("aave")
assert result == 10_000_000_000.0
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client(1234.0)
mock_client_cls.return_value = mock_client
await get_protocol_tvl("uniswap")
mock_client.get.assert_called_once_with("https://api.llama.fi/tvl/uniswap")
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_tvl_returns_none_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("404 not found")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_protocol_tvl("nonexistent")
assert result is None
# ---------------------------------------------------------------------------
# get_yield_pools
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_top_20_by_tvl(mock_client_cls):
# Create 30 pools with decreasing TVL
raw_data = {
"data": [
{
"pool": f"pool{i}",
"chain": "Ethereum",
"project": "aave",
"symbol": "USDC",
"tvlUsd": float(1000 - i),
"apy": 3.0,
"apyBase": 2.5,
"apyReward": 0.5,
}
for i in range(30)
]
}
mock_client_cls.return_value = _make_mock_client(raw_data)
result = await get_yield_pools()
assert len(result) == 20
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_correct_fields(mock_client_cls):
mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})
result = await get_yield_pools()
assert len(result) == 2
first = result[0]
assert first["pool"] == "0xabcd"
assert first["chain"] == "Ethereum"
assert first["project"] == "aave-v3"
assert first["symbol"] == "USDC"
assert first["tvlUsd"] == 1_000_000_000.0
assert first["apy"] == 3.5
assert first["apyBase"] == 3.0
assert first["apyReward"] == 0.5
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_chain(mock_client_cls):
mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})
result = await get_yield_pools(chain="Ethereum")
assert len(result) == 1
assert result[0]["chain"] == "Ethereum"
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_project(mock_client_cls):
mock_client_cls.return_value = _make_mock_client({"data": SAMPLE_POOLS})
result = await get_yield_pools(project="curve")
assert len(result) == 1
assert result[0]["project"] == "curve"
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_filters_by_chain_and_project(mock_client_cls):
pools = SAMPLE_POOLS + [
{
"pool": "0xzzzz",
"chain": "Ethereum",
"project": "curve",
"symbol": "USDT",
"tvlUsd": 200_000_000.0,
"apy": 2.0,
"apyBase": 2.0,
"apyReward": 0.0,
}
]
mock_client_cls.return_value = _make_mock_client({"data": pools})
result = await get_yield_pools(chain="Ethereum", project="curve")
assert len(result) == 1
assert result[0]["pool"] == "0xzzzz"
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_returns_empty_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("Connection error")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_yield_pools()
assert result == []
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client({"data": []})
mock_client_cls.return_value = mock_client
await get_yield_pools()
mock_client.get.assert_called_once_with("https://yields.llama.fi/pools")
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_yield_pools_sorts_by_tvl_descending(mock_client_cls):
unsorted_pools = [
{
"pool": "low", "chain": "Ethereum", "project": "aave",
"symbol": "DAI", "tvlUsd": 100.0, "apy": 1.0,
"apyBase": 1.0, "apyReward": 0.0,
},
{
"pool": "high", "chain": "Ethereum", "project": "aave",
"symbol": "USDC", "tvlUsd": 9000.0, "apy": 2.0,
"apyBase": 2.0, "apyReward": 0.0,
},
]
mock_client_cls.return_value = _make_mock_client({"data": unsorted_pools})
result = await get_yield_pools()
assert result[0]["pool"] == "high"
assert result[1]["pool"] == "low"
# ---------------------------------------------------------------------------
# get_stablecoins
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_correct_fields(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(SAMPLE_STABLECOINS_RESPONSE)
result = await get_stablecoins()
assert len(result) == 2
first = result[0]
assert first["name"] == "Tether"
assert first["symbol"] == "USDT"
assert first["pegType"] == "peggedUSD"
assert first["circulating"] == 100_000_000_000.0
assert first["price"] == 1.0
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_top_20(mock_client_cls):
assets = [
{
"name": f"Stable{i}",
"symbol": f"S{i}",
"pegType": "peggedUSD",
"circulating": {"peggedUSD": float(1000 - i)},
"price": 1.0,
}
for i in range(25)
]
mock_client_cls.return_value = _make_mock_client({"peggedAssets": assets})
result = await get_stablecoins()
assert len(result) == 20
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_handles_missing_circulating(mock_client_cls):
raw = {"peggedAssets": [{"name": "NoCirc", "symbol": "NC", "pegType": "peggedUSD", "price": 1.0}]}
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_stablecoins()
assert result[0]["circulating"] is None
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_returns_empty_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("timeout")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_stablecoins()
assert result == []
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_stablecoins_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client({"peggedAssets": []})
mock_client_cls.return_value = mock_client
await get_stablecoins()
mock_client.get.assert_called_once_with("https://stablecoins.llama.fi/stablecoins")
# ---------------------------------------------------------------------------
# get_dex_volumes
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_returns_correct_structure(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(SAMPLE_DEX_RESPONSE)
result = await get_dex_volumes()
assert result["totalVolume24h"] == 5_000_000_000.0
assert result["totalVolume7d"] == 30_000_000_000.0
assert len(result["protocols"]) == 2
assert result["protocols"][0]["name"] == "Uniswap"
assert result["protocols"][0]["volume24h"] == 2_000_000_000.0
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_returns_none_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("server error")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_dex_volumes()
assert result is None
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_dex_volumes_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client({"total24h": 0, "total7d": 0, "protocols": []})
mock_client_cls.return_value = mock_client
await get_dex_volumes()
mock_client.get.assert_called_once_with("https://api.llama.fi/overview/dexs")
# ---------------------------------------------------------------------------
# get_protocol_fees
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_returns_correct_structure(mock_client_cls):
mock_client_cls.return_value = _make_mock_client(SAMPLE_FEES_RESPONSE)
result = await get_protocol_fees()
assert len(result) == 2
first = result[0]
assert first["name"] == "Uniswap"
assert first["fees24h"] == 1_000_000.0
assert first["revenue24h"] == 500_000.0
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_returns_empty_on_error(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status.side_effect = Exception("connection reset")
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_resp)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
result = await get_protocol_fees()
assert result == []
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_calls_correct_url(mock_client_cls):
mock_client = _make_mock_client({"protocols": []})
mock_client_cls.return_value = mock_client
await get_protocol_fees()
mock_client.get.assert_called_once_with("https://api.llama.fi/overview/fees")
@pytest.mark.asyncio
@patch("defi_service.httpx.AsyncClient")
async def test_get_protocol_fees_handles_missing_revenue(mock_client_cls):
raw = {"protocols": [{"name": "SomeProtocol", "total24h": 500_000.0}]}
mock_client_cls.return_value = _make_mock_client(raw)
result = await get_protocol_fees()
assert result[0]["revenue24h"] is None

View File

@@ -0,0 +1,368 @@
"""Tests for the new social sentiment functions in finnhub_service."""
from unittest.mock import patch, AsyncMock, MagicMock
import pytest
import finnhub_service
import reddit_service
# --- get_social_sentiment ---
@pytest.mark.asyncio
@patch("finnhub_service.settings")
async def test_social_sentiment_not_configured(mock_settings):
mock_settings.finnhub_api_key = ""
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["configured"] is False
assert "INVEST_API_FINNHUB_API_KEY" in result["message"]
@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_premium_required_403(mock_client_cls, mock_settings):
mock_settings.finnhub_api_key = "test_key"
mock_resp = MagicMock()
mock_resp.status_code = 403
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["configured"] is True
assert result["premium_required"] is True
assert result["reddit"] == []
assert result["twitter"] == []
@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_premium_required_401(mock_client_cls, mock_settings):
mock_settings.finnhub_api_key = "test_key"
mock_resp = MagicMock()
mock_resp.status_code = 401
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["premium_required"] is True
@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_success_with_data(mock_client_cls, mock_settings):
mock_settings.finnhub_api_key = "test_key"
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"reddit": [
{"mention": 50, "positiveScore": 30, "negativeScore": 10, "score": 0.5},
{"mention": 30, "positiveScore": 20, "negativeScore": 5, "score": 0.6},
],
"twitter": [
{"mention": 100, "positiveScore": 60, "negativeScore": 20, "score": 0.4},
],
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["configured"] is True
assert result["symbol"] == "AAPL"
assert result["reddit_summary"]["total_mentions"] == 80
assert result["reddit_summary"]["data_points"] == 2
assert result["twitter_summary"]["total_mentions"] == 100
assert len(result["reddit"]) == 2
assert len(result["twitter"]) == 1
@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_empty_lists(mock_client_cls, mock_settings):
mock_settings.finnhub_api_key = "test_key"
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {"reddit": [], "twitter": []}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["configured"] is True
assert result["reddit_summary"] is None
assert result["twitter_summary"] is None
@pytest.mark.asyncio
@patch("finnhub_service.settings")
@patch("finnhub_service.httpx.AsyncClient")
async def test_social_sentiment_non_dict_response(mock_client_cls, mock_settings):
mock_settings.finnhub_api_key = "test_key"
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = "unexpected string response"
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await finnhub_service.get_social_sentiment("AAPL")
assert result["reddit"] == []
assert result["twitter"] == []
# --- _summarize_social ---
def test_summarize_social_empty():
result = finnhub_service._summarize_social([])
assert result == {}
def test_summarize_social_single_entry():
entries = [{"mention": 10, "positiveScore": 7, "negativeScore": 2, "score": 0.5}]
result = finnhub_service._summarize_social(entries)
assert result["total_mentions"] == 10
assert result["total_positive"] == 7
assert result["total_negative"] == 2
assert result["avg_score"] == 0.5
assert result["data_points"] == 1
def test_summarize_social_multiple_entries():
entries = [
{"mention": 100, "positiveScore": 60, "negativeScore": 20, "score": 0.4},
{"mention": 50, "positiveScore": 30, "negativeScore": 10, "score": 0.6},
]
result = finnhub_service._summarize_social(entries)
assert result["total_mentions"] == 150
assert result["total_positive"] == 90
assert result["total_negative"] == 30
assert result["avg_score"] == 0.5
assert result["data_points"] == 2
def test_summarize_social_missing_fields():
entries = [{"mention": 5}]
result = finnhub_service._summarize_social(entries)
assert result["total_mentions"] == 5
assert result["total_positive"] == 0
assert result["total_negative"] == 0
# --- get_reddit_sentiment ---
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_symbol_found(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": "AAPL", "name": "Apple Inc", "rank": 3, "mentions": 150, "mentions_24h_ago": 100, "upvotes": 500, "rank_24h_ago": 5},
{"ticker": "TSLA", "name": "Tesla", "rank": 1, "mentions": 300, "mentions_24h_ago": 280, "upvotes": 900, "rank_24h_ago": 2},
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_sentiment("AAPL")
assert result["found"] is True
assert result["symbol"] == "AAPL"
assert result["rank"] == 3
assert result["mentions_24h"] == 150
assert result["mentions_24h_ago"] == 100
assert result["mentions_change_pct"] == 50.0
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_symbol_not_found(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": "TSLA", "rank": 1, "mentions": 300, "mentions_24h_ago": 280}
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_sentiment("AAPL")
assert result["found"] is False
assert result["symbol"] == "AAPL"
assert "not in Reddit" in result["message"]
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_zero_mentions_prev(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": "AAPL", "rank": 1, "mentions": 50, "mentions_24h_ago": 0, "upvotes": 200, "rank_24h_ago": None}
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_sentiment("AAPL")
assert result["found"] is True
assert result["mentions_change_pct"] is None # division by zero handled
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_api_failure(mock_client_cls):
mock_client = AsyncMock()
mock_client.get.side_effect = Exception("Connection error")
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_sentiment("AAPL")
assert result["symbol"] == "AAPL"
assert "error" in result
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_sentiment_case_insensitive(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": "aapl", "rank": 1, "mentions": 100, "mentions_24h_ago": 80, "upvotes": 400, "rank_24h_ago": 2}
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_sentiment("AAPL")
assert result["found"] is True
# --- get_reddit_trending ---
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_happy_path(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": "TSLA", "name": "Tesla", "rank": 1, "mentions": 500, "upvotes": 1000, "rank_24h_ago": 2, "mentions_24h_ago": 400},
{"ticker": "AAPL", "name": "Apple", "rank": 2, "mentions": 300, "upvotes": 700, "rank_24h_ago": 1, "mentions_24h_ago": 350},
{"ticker": "GME", "name": "GameStop", "rank": 3, "mentions": 200, "upvotes": 500, "rank_24h_ago": 3, "mentions_24h_ago": 180},
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_trending()
assert len(result) == 3
assert result[0]["symbol"] == "TSLA"
assert result[0]["rank"] == 1
assert result[1]["symbol"] == "AAPL"
assert "mentions_24h" in result[0]
assert "upvotes" in result[0]
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_limits_to_25(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {
"results": [
{"ticker": f"SYM{i}", "rank": i + 1, "mentions": 100 - i, "upvotes": 50, "rank_24h_ago": i, "mentions_24h_ago": 80}
for i in range(30)
]
}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_trending()
assert len(result) == 25
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_empty_results(mock_client_cls):
mock_resp = MagicMock()
mock_resp.raise_for_status = MagicMock()
mock_resp.json.return_value = {"results": []}
mock_client = AsyncMock()
mock_client.get.return_value = mock_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_trending()
assert result == []
@pytest.mark.asyncio
@patch("reddit_service.httpx.AsyncClient")
async def test_reddit_trending_api_failure(mock_client_cls):
mock_client = AsyncMock()
mock_client.get.side_effect = Exception("ApeWisdom down")
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client_cls.return_value = mock_client
result = await reddit_service.get_reddit_trending()
assert result == []

View File

@@ -0,0 +1,559 @@
"""Tests for portfolio optimization service (TDD - RED phase first)."""
from unittest.mock import AsyncMock, patch
import pytest
# --- HRP Optimization ---
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_happy_path(mock_fetch):
"""HRP returns weights that sum to ~1.0 for valid symbols."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{
"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
"MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
"GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
}
)
import portfolio_service
result = await portfolio_service.optimize_hrp(
["AAPL", "MSFT", "GOOGL"], days=365
)
assert result["method"] == "hrp"
assert set(result["weights"].keys()) == {"AAPL", "MSFT", "GOOGL"}
total = sum(result["weights"].values())
assert abs(total - 1.0) < 0.01
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_single_symbol(mock_fetch):
"""Single symbol gets weight of 1.0."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0]}
)
import portfolio_service
result = await portfolio_service.optimize_hrp(["AAPL"], days=365)
assert result["weights"]["AAPL"] == pytest.approx(1.0, abs=0.01)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_hrp_optimize_no_data_raises(mock_fetch):
"""Raises ValueError when no price data is available."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame()
import portfolio_service
with pytest.raises(ValueError, match="No price data"):
await portfolio_service.optimize_hrp(["AAPL", "MSFT"], days=365)
@pytest.mark.asyncio
async def test_hrp_optimize_empty_symbols_raises():
"""Raises ValueError for empty symbol list."""
import portfolio_service
with pytest.raises(ValueError, match="symbols"):
await portfolio_service.optimize_hrp([], days=365)
# --- Correlation Matrix ---
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_matrix_happy_path(mock_fetch):
"""Correlation matrix has 1.0 on diagonal and valid shape."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{
"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
"MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
"GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
}
)
import portfolio_service
result = await portfolio_service.compute_correlation(
["AAPL", "MSFT", "GOOGL"], days=365
)
assert result["symbols"] == ["AAPL", "MSFT", "GOOGL"]
matrix = result["matrix"]
assert len(matrix) == 3
assert len(matrix[0]) == 3
# Diagonal should be 1.0
for i in range(3):
assert abs(matrix[i][i] - 1.0) < 0.01
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_matrix_two_symbols(mock_fetch):
"""Two-symbol correlation is symmetric."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{
"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
"MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
}
)
import portfolio_service
result = await portfolio_service.compute_correlation(["AAPL", "MSFT"], days=365)
matrix = result["matrix"]
# Symmetric: matrix[0][1] == matrix[1][0]
assert abs(matrix[0][1] - matrix[1][0]) < 0.001
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_correlation_no_data_raises(mock_fetch):
"""Raises ValueError when no data is returned."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame()
import portfolio_service
with pytest.raises(ValueError, match="No price data"):
await portfolio_service.compute_correlation(["AAPL", "MSFT"], days=365)
@pytest.mark.asyncio
async def test_correlation_empty_symbols_raises():
"""Raises ValueError for empty symbol list."""
import portfolio_service
with pytest.raises(ValueError, match="symbols"):
await portfolio_service.compute_correlation([], days=365)
# --- Risk Parity ---
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_happy_path(mock_fetch):
"""Risk parity returns weights and risk_contributions summing to ~1.0."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{
"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0],
"MSFT": [300.0, 302.0, 298.0, 305.0, 307.0],
"GOOGL": [2800.0, 2820.0, 2790.0, 2830.0, 2850.0],
}
)
import portfolio_service
result = await portfolio_service.compute_risk_parity(
["AAPL", "MSFT", "GOOGL"], days=365
)
assert result["method"] == "risk_parity"
assert set(result["weights"].keys()) == {"AAPL", "MSFT", "GOOGL"}
assert set(result["risk_contributions"].keys()) == {"AAPL", "MSFT", "GOOGL"}
total_w = sum(result["weights"].values())
assert abs(total_w - 1.0) < 0.01
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_single_symbol(mock_fetch):
"""Single symbol gets weight 1.0 and risk_contribution 1.0."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame(
{"AAPL": [150.0, 151.0, 149.0, 152.0, 153.0]}
)
import portfolio_service
result = await portfolio_service.compute_risk_parity(["AAPL"], days=365)
assert result["weights"]["AAPL"] == pytest.approx(1.0, abs=0.01)
assert result["risk_contributions"]["AAPL"] == pytest.approx(1.0, abs=0.01)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_risk_parity_no_data_raises(mock_fetch):
"""Raises ValueError when no price data is available."""
import pandas as pd
mock_fetch.return_value = pd.DataFrame()
import portfolio_service
with pytest.raises(ValueError, match="No price data"):
await portfolio_service.compute_risk_parity(["AAPL", "MSFT"], days=365)
@pytest.mark.asyncio
async def test_risk_parity_empty_symbols_raises():
"""Raises ValueError for empty symbol list."""
import portfolio_service
with pytest.raises(ValueError, match="symbols"):
await portfolio_service.compute_risk_parity([], days=365)
# --- fetch_historical_prices helper ---
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical")
async def test_fetch_historical_prices_returns_dataframe(mock_fetch_hist):
"""fetch_historical_prices assembles a price DataFrame from OBBject results."""
import pandas as pd
from unittest.mock import MagicMock
mock_result = MagicMock()
mock_result.results = [
MagicMock(date="2024-01-01", close=150.0),
MagicMock(date="2024-01-02", close=151.0),
]
mock_fetch_hist.return_value = mock_result
import portfolio_service
df = await portfolio_service.fetch_historical_prices(["AAPL"], days=30)
assert isinstance(df, pd.DataFrame)
assert "AAPL" in df.columns
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical")
async def test_fetch_historical_prices_skips_none(mock_fetch_hist):
"""fetch_historical_prices returns empty DataFrame when all fetches fail."""
import pandas as pd
mock_fetch_hist.return_value = None
import portfolio_service
df = await portfolio_service.fetch_historical_prices(["AAPL", "MSFT"], days=30)
assert isinstance(df, pd.DataFrame)
assert df.empty
# ---------------------------------------------------------------------------
# cluster_stocks
# ---------------------------------------------------------------------------
def _make_prices(symbols: list[str], n_days: int = 60):
"""Build a deterministic price DataFrame with enough rows for t-SNE."""
import numpy as np
import pandas as pd
rng = np.random.default_rng(42)
data = {}
for sym in symbols:
prices = 100.0 + np.cumsum(rng.normal(0, 1, n_days))
data[sym] = prices
return pd.DataFrame(data)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_happy_path(mock_fetch):
"""cluster_stocks returns valid structure for 6 symbols."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.cluster_stocks(symbols, days=180)
assert result["method"] == "t-SNE + KMeans"
assert result["days"] == 180
assert set(result["symbols"]) == set(symbols)
coords = result["coordinates"]
assert len(coords) == len(symbols)
for c in coords:
assert "symbol" in c
assert "x" in c
assert "y" in c
assert "cluster" in c
assert isinstance(c["x"], float)
assert isinstance(c["y"], float)
assert isinstance(c["cluster"], int)
clusters = result["clusters"]
assert isinstance(clusters, dict)
all_in_clusters = []
for members in clusters.values():
all_in_clusters.extend(members)
assert set(all_in_clusters) == set(symbols)
assert "n_clusters" in result
assert result["n_clusters"] >= 2
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_custom_n_clusters(mock_fetch):
"""Custom n_clusters is respected in the output."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.cluster_stocks(symbols, days=180, n_clusters=3)
assert result["n_clusters"] == 3
assert len(result["clusters"]) == 3
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_minimum_three_symbols(mock_fetch):
"""cluster_stocks works correctly with exactly 3 symbols (minimum)."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.cluster_stocks(symbols, days=180)
assert len(result["coordinates"]) == 3
assert set(result["symbols"]) == set(symbols)
@pytest.mark.asyncio
async def test_cluster_stocks_too_few_symbols_raises():
"""cluster_stocks raises ValueError when fewer than 3 symbols are provided."""
import portfolio_service
with pytest.raises(ValueError, match="at least 3"):
await portfolio_service.cluster_stocks(["AAPL", "MSFT"], days=180)
@pytest.mark.asyncio
async def test_cluster_stocks_empty_symbols_raises():
"""cluster_stocks raises ValueError for empty symbol list."""
import portfolio_service
with pytest.raises(ValueError, match="at least 3"):
await portfolio_service.cluster_stocks([], days=180)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_no_data_raises(mock_fetch):
"""cluster_stocks raises ValueError when fetch returns empty DataFrame."""
import pandas as pd
import portfolio_service
mock_fetch.return_value = pd.DataFrame()
with pytest.raises(ValueError, match="No price data"):
await portfolio_service.cluster_stocks(["AAPL", "MSFT", "GOOGL"], days=180)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_identical_returns_still_works(mock_fetch):
"""t-SNE should not raise even when all symbols have identical returns."""
import pandas as pd
import portfolio_service
# All columns identical — edge case for t-SNE
flat = pd.DataFrame(
{
"AAPL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
"MSFT": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
"GOOGL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
}
)
mock_fetch.return_value = flat
result = await portfolio_service.cluster_stocks(
["AAPL", "MSFT", "GOOGL"], days=180
)
assert len(result["coordinates"]) == 3
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_coordinates_are_floats(mock_fetch):
"""x and y coordinates must be Python floats (JSON-serializable)."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.cluster_stocks(symbols, days=180)
for c in result["coordinates"]:
assert type(c["x"]) is float
assert type(c["y"]) is float
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_clusters_key_is_str(mock_fetch):
"""clusters dict keys must be strings (JSON object keys)."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.cluster_stocks(symbols, days=180)
for key in result["clusters"]:
assert isinstance(key, str), f"Expected str key, got {type(key)}"
# ---------------------------------------------------------------------------
# find_similar_stocks
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_happy_path(mock_fetch):
"""most_similar is sorted descending by correlation; least_similar ascending."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.find_similar_stocks(
"AAPL", ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, top_n=3
)
assert result["symbol"] == "AAPL"
most = result["most_similar"]
least = result["least_similar"]
assert len(most) <= 3
assert len(least) <= 3
# most_similar sorted descending
corrs_most = [e["correlation"] for e in most]
assert corrs_most == sorted(corrs_most, reverse=True)
# least_similar sorted ascending
corrs_least = [e["correlation"] for e in least]
assert corrs_least == sorted(corrs_least)
# Each entry has symbol and correlation
for entry in most + least:
assert "symbol" in entry
assert "correlation" in entry
assert isinstance(entry["correlation"], float)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_top_n_larger_than_universe(mock_fetch):
"""top_n larger than universe size is handled gracefully (returns all)."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.find_similar_stocks(
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=10
)
# Should return at most len(universe) entries, not crash
assert len(result["most_similar"]) <= 2
assert len(result["least_similar"]) <= 2
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_overlap_with_most_and_least(mock_fetch):
"""most_similar and least_similar should not contain the target symbol."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.find_similar_stocks(
"AAPL", ["MSFT", "GOOGL", "AMZN", "JPM"], days=180, top_n=2
)
all_symbols = [e["symbol"] for e in result["most_similar"] + result["least_similar"]]
assert "AAPL" not in all_symbols
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_data_raises(mock_fetch):
"""find_similar_stocks raises ValueError when no price data is returned."""
import pandas as pd
import portfolio_service
mock_fetch.return_value = pd.DataFrame()
with pytest.raises(ValueError, match="No price data"):
await portfolio_service.find_similar_stocks(
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_target_not_in_data_raises(mock_fetch):
"""find_similar_stocks raises ValueError when target symbol has no data."""
import portfolio_service
# Only universe symbols have data, not the target
mock_fetch.return_value = _make_prices(["MSFT", "GOOGL"])
with pytest.raises(ValueError, match="AAPL"):
await portfolio_service.find_similar_stocks(
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_default_top_n(mock_fetch):
"""Default top_n=5 returns at most 5 entries in most_similar."""
import portfolio_service
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"]
mock_fetch.return_value = _make_prices(symbols)
result = await portfolio_service.find_similar_stocks(
"AAPL",
["MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"],
days=180,
)
assert len(result["most_similar"]) <= 5
assert len(result["least_similar"]) <= 5

View File

@@ -0,0 +1,443 @@
"""Integration tests for backtest routes - written FIRST (TDD RED phase)."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# Shared mock response used across strategy tests
MOCK_BACKTEST_RESULT = {
"total_return": 0.15,
"annualized_return": 0.14,
"sharpe_ratio": 1.2,
"max_drawdown": -0.08,
"win_rate": 0.6,
"total_trades": 10,
"equity_curve": [10000 + i * 75 for i in range(20)],
}
MOCK_MOMENTUM_RESULT = {
**MOCK_BACKTEST_RESULT,
"allocation_history": [
{"date": "2024-01-01", "symbols": ["AAPL", "MSFT"], "weights": [0.5, 0.5]},
],
}
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/sma-crossover
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_happy_path(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={
"symbol": "AAPL",
"short_window": 20,
"long_window": 50,
"days": 365,
"initial_capital": 10000,
},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["total_return"] == pytest.approx(0.15)
assert data["data"]["total_trades"] == 10
assert len(data["data"]["equity_curve"]) == 20
mock_fn.assert_called_once_with(
"AAPL",
short_window=20,
long_window=50,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_default_values(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "MSFT"},
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(
"MSFT",
short_window=20,
long_window=50,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
async def test_sma_crossover_missing_symbol(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"short_window": 20, "long_window": 50},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_sma_crossover_short_window_too_small(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "short_window": 2, "long_window": 50},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_sma_crossover_long_window_too_large(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "short_window": 20, "long_window": 500},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_sma_crossover_days_too_few(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "days": 5},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_sma_crossover_days_too_many(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "days": 9999},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_sma_crossover_capital_zero(client):
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "initial_capital": 0},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Data fetch failed")
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
async def test_sma_crossover_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No historical data")
resp = await client.post(
"/api/v1/backtest/sma-crossover",
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/rsi
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_happy_path(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post(
"/api/v1/backtest/rsi",
json={
"symbol": "AAPL",
"period": 14,
"oversold": 30,
"overbought": 70,
"days": 365,
"initial_capital": 10000,
},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["sharpe_ratio"] == pytest.approx(1.2)
mock_fn.assert_called_once_with(
"AAPL",
period=14,
oversold=30.0,
overbought=70.0,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_default_values(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
assert resp.status_code == 200
mock_fn.assert_called_once_with(
"AAPL",
period=14,
oversold=30.0,
overbought=70.0,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
async def test_rsi_missing_symbol(client):
resp = await client.post("/api/v1/backtest/rsi", json={"period": 14})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_rsi_period_too_small(client):
resp = await client.post(
"/api/v1/backtest/rsi",
json={"symbol": "AAPL", "period": 1},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_rsi_oversold_too_high(client):
# oversold must be < 50
resp = await client.post(
"/api/v1/backtest/rsi",
json={"symbol": "AAPL", "oversold": 55, "overbought": 70},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_rsi_overbought_too_low(client):
# overbought must be > 50
resp = await client.post(
"/api/v1/backtest/rsi",
json={"symbol": "AAPL", "oversold": 30, "overbought": 45},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("upstream error")
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
async def test_rsi_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("Insufficient data")
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/buy-and-hold
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_happy_path(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post(
"/api/v1/backtest/buy-and-hold",
json={"symbol": "AAPL", "days": 365, "initial_capital": 10000},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["total_return"] == pytest.approx(0.15)
mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_default_values(mock_fn, client):
mock_fn.return_value = MOCK_BACKTEST_RESULT
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
assert resp.status_code == 200
mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)
@pytest.mark.asyncio
async def test_buy_and_hold_missing_symbol(client):
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"days": 365})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_buy_and_hold_days_too_few(client):
resp = await client.post(
"/api/v1/backtest/buy-and-hold",
json={"symbol": "AAPL", "days": 10},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("provider down")
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
async def test_buy_and_hold_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No historical data")
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# POST /api/v1/backtest/momentum
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_happy_path(mock_fn, client):
mock_fn.return_value = MOCK_MOMENTUM_RESULT
resp = await client.post(
"/api/v1/backtest/momentum",
json={
"symbols": ["AAPL", "MSFT", "GOOGL"],
"lookback": 60,
"top_n": 2,
"rebalance_days": 30,
"days": 365,
"initial_capital": 10000,
},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert "allocation_history" in data["data"]
assert data["data"]["total_trades"] == 10
mock_fn.assert_called_once_with(
symbols=["AAPL", "MSFT", "GOOGL"],
lookback=60,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_default_values(mock_fn, client):
mock_fn.return_value = MOCK_MOMENTUM_RESULT
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL", "MSFT"]},
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(
symbols=["AAPL", "MSFT"],
lookback=60,
top_n=2,
rebalance_days=30,
days=365,
initial_capital=10000.0,
)
@pytest.mark.asyncio
async def test_momentum_missing_symbols(client):
resp = await client.post("/api/v1/backtest/momentum", json={"lookback": 60})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_momentum_too_few_symbols(client):
# min_length=2 on symbols list
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL"]},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_momentum_too_many_symbols(client):
symbols = [f"SYM{i}" for i in range(25)]
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": symbols},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_momentum_lookback_too_small(client):
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL", "MSFT"], "lookback": 2},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_momentum_days_too_few(client):
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL", "MSFT"], "days": 10},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("provider down")
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL", "MSFT"]},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
async def test_momentum_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No price data available")
resp = await client.post(
"/api/v1/backtest/momentum",
json={"symbols": ["AAPL", "MSFT"]},
)
assert resp.status_code == 400

321
tests/test_routes_cn.py Normal file
View File

@@ -0,0 +1,321 @@
"""Integration tests for routes_cn.py - written FIRST (TDD RED phase)."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# ============================================================
# A-share quote GET /api/v1/cn/a-share/{symbol}/quote
# ============================================================
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_happy_path(mock_fn, client):
mock_fn.return_value = {
"symbol": "000001",
"name": "平安银行",
"price": 12.34,
"change": 0.15,
"change_percent": 1.23,
"volume": 500_000,
"turnover": 6_170_000.0,
"open": 12.10,
"high": 12.50,
"low": 12.00,
"prev_close": 12.19,
}
resp = await client.get("/api/v1/cn/a-share/000001/quote")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["symbol"] == "000001"
assert data["data"]["name"] == "平安银行"
assert data["data"]["price"] == pytest.approx(12.34)
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_not_found_returns_404(mock_fn, client):
mock_fn.return_value = None
resp = await client.get("/api/v1/cn/a-share/000001/quote")
assert resp.status_code == 404
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock)
async def test_a_share_quote_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("AKShare down")
resp = await client.get("/api/v1/cn/a-share/000001/quote")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_a_share_quote_invalid_symbol_returns_400(client):
# symbol starting with 1 is invalid for A-shares
resp = await client.get("/api/v1/cn/a-share/100001/quote")
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_a_share_quote_non_numeric_symbol_returns_400(client):
resp = await client.get("/api/v1/cn/a-share/ABCDEF/quote")
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_a_share_quote_too_short_returns_422(client):
resp = await client.get("/api/v1/cn/a-share/00001/quote")
assert resp.status_code == 422
# ============================================================
# A-share historical GET /api/v1/cn/a-share/{symbol}/historical
# ============================================================
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"date": "2026-01-01",
"open": 10.0,
"close": 10.5,
"high": 11.0,
"low": 9.5,
"volume": 1_000_000,
"turnover": 10_500_000.0,
"change_percent": 0.5,
"turnover_rate": 0.3,
}
]
resp = await client.get("/api/v1/cn/a-share/000001/historical?days=30")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert isinstance(data["data"], list)
assert len(data["data"]) == 1
assert data["data"][0]["date"] == "2026-01-01"
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_default_days(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/cn/a-share/600519/historical")
assert resp.status_code == 200
mock_fn.assert_called_once_with("600519", days=365)
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_empty_returns_200(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/cn/a-share/000001/historical")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock)
async def test_a_share_historical_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("AKShare down")
resp = await client.get("/api/v1/cn/a-share/000001/historical")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_a_share_historical_invalid_symbol_returns_400(client):
resp = await client.get("/api/v1/cn/a-share/100001/historical")
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_a_share_historical_days_out_of_range_returns_422(client):
resp = await client.get("/api/v1/cn/a-share/000001/historical?days=0")
assert resp.status_code == 422
# ============================================================
# A-share search GET /api/v1/cn/a-share/search
# ============================================================
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_happy_path(mock_fn, client):
mock_fn.return_value = [
{"code": "000001", "name": "平安银行"},
{"code": "000002", "name": "平安地产"},
]
resp = await client.get("/api/v1/cn/a-share/search?query=平安")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["code"] == "000001"
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_empty_results(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/cn/a-share/search?query=NOMATCH")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock)
async def test_a_share_search_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("AKShare down")
resp = await client.get("/api/v1/cn/a-share/search?query=test")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_a_share_search_missing_query_returns_422(client):
resp = await client.get("/api/v1/cn/a-share/search")
assert resp.status_code == 422
# ============================================================
# HK quote GET /api/v1/cn/hk/{symbol}/quote
# ============================================================
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_happy_path(mock_fn, client):
mock_fn.return_value = {
"symbol": "00700",
"name": "腾讯控股",
"price": 380.0,
"change": 5.0,
"change_percent": 1.33,
"volume": 10_000_000,
"turnover": 3_800_000_000.0,
"open": 375.0,
"high": 385.0,
"low": 374.0,
"prev_close": 375.0,
}
resp = await client.get("/api/v1/cn/hk/00700/quote")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["symbol"] == "00700"
assert data["data"]["price"] == pytest.approx(380.0)
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_not_found_returns_404(mock_fn, client):
mock_fn.return_value = None
resp = await client.get("/api/v1/cn/hk/00700/quote")
assert resp.status_code == 404
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock)
async def test_hk_quote_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("AKShare down")
resp = await client.get("/api/v1/cn/hk/00700/quote")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_hk_quote_invalid_symbol_letters_returns_400(client):
resp = await client.get("/api/v1/cn/hk/ABCDE/quote")
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_hk_quote_too_short_returns_422(client):
resp = await client.get("/api/v1/cn/hk/0070/quote")
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_hk_quote_too_long_returns_422(client):
resp = await client.get("/api/v1/cn/hk/007000/quote")
assert resp.status_code == 422
# ============================================================
# HK historical GET /api/v1/cn/hk/{symbol}/historical
# ============================================================
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"date": "2026-01-01",
"open": 375.0,
"close": 380.0,
"high": 385.0,
"low": 374.0,
"volume": 10_000_000,
"turnover": 3_800_000_000.0,
"change_percent": 1.33,
"turnover_rate": 0.5,
}
]
resp = await client.get("/api/v1/cn/hk/00700/historical?days=90")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert isinstance(data["data"], list)
assert data["data"][0]["close"] == pytest.approx(380.0)
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_default_days(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/cn/hk/09988/historical")
assert resp.status_code == 200
mock_fn.assert_called_once_with("09988", days=365)
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_empty_returns_200(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/cn/hk/00700/historical")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock)
async def test_hk_historical_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("AKShare down")
resp = await client.get("/api/v1/cn/hk/00700/historical")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_hk_historical_invalid_symbol_returns_400(client):
resp = await client.get("/api/v1/cn/hk/ABCDE/historical")
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_hk_historical_days_out_of_range_returns_422(client):
resp = await client.get("/api/v1/cn/hk/00700/historical?days=0")
assert resp.status_code == 422

View File

@@ -0,0 +1,98 @@
"""Tests for congress trading routes (TDD - RED phase first)."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- GET /api/v1/regulators/congress/trades ---
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"representative": "Nancy Pelosi",
"ticker": "NVDA",
"transaction_date": "2024-01-15",
"transaction_type": "Purchase",
"amount": "$1,000,001-$5,000,000",
}
]
resp = await client.get("/api/v1/regulators/congress/trades")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["representative"] == "Nancy Pelosi"
mock_fn.assert_called_once()
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/congress/trades")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.get_congress_trades", new_callable=AsyncMock)
async def test_congress_trades_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Data provider unavailable")
resp = await client.get("/api/v1/regulators/congress/trades")
assert resp.status_code == 502
# --- GET /api/v1/regulators/congress/bills ---
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_happy_path(mock_fn, client):
mock_fn.return_value = [
{"title": "Infrastructure Investment and Jobs Act", "bill_id": "HR3684"},
{"title": "Inflation Reduction Act", "bill_id": "HR5376"},
]
resp = await client.get("/api/v1/regulators/congress/bills?query=infrastructure")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["bill_id"] == "HR3684"
mock_fn.assert_called_once_with("infrastructure")
@pytest.mark.asyncio
async def test_congress_bills_missing_query(client):
resp = await client.get("/api/v1/regulators/congress/bills")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/congress/bills?query=nonexistent")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.congress_service.search_congress_bills", new_callable=AsyncMock)
async def test_congress_bills_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Congress API unavailable")
resp = await client.get("/api/v1/regulators/congress/bills?query=tax")
assert resp.status_code == 502

363
tests/test_routes_defi.py Normal file
View File

@@ -0,0 +1,363 @@
"""Tests for routes_defi.py - DeFi API routes.
TDD: these tests are written before implementation.
"""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# ---------------------------------------------------------------------------
# GET /api/v1/defi/tvl/protocols
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"name": "Aave",
"symbol": "AAVE",
"tvl": 10_000_000_000.0,
"chain": "Ethereum",
"chains": ["Ethereum"],
"category": "Lending",
"change_1d": 0.5,
"change_7d": -1.2,
}
]
resp = await client.get("/api/v1/defi/tvl/protocols")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["name"] == "Aave"
assert data["data"][0]["tvl"] == 10_000_000_000.0
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/tvl/protocols")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_top_protocols", new_callable=AsyncMock)
async def test_tvl_protocols_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("DefiLlama unavailable")
resp = await client.get("/api/v1/defi/tvl/protocols")
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# GET /api/v1/defi/tvl/chains
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_happy_path(mock_fn, client):
mock_fn.return_value = [
{"name": "Ethereum", "tvl": 50_000_000_000.0, "tokenSymbol": "ETH"},
{"name": "BSC", "tvl": 5_000_000_000.0, "tokenSymbol": "BNB"},
]
resp = await client.get("/api/v1/defi/tvl/chains")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["name"] == "Ethereum"
assert data["data"][0]["tokenSymbol"] == "ETH"
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/tvl/chains")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_chain_tvls", new_callable=AsyncMock)
async def test_tvl_chains_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("upstream error")
resp = await client.get("/api/v1/defi/tvl/chains")
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# GET /api/v1/defi/tvl/{protocol}
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_happy_path(mock_fn, client):
mock_fn.return_value = 10_000_000_000.0
resp = await client.get("/api/v1/defi/tvl/aave")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["protocol"] == "aave"
assert data["data"]["tvl"] == 10_000_000_000.0
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_not_found_returns_404(mock_fn, client):
mock_fn.return_value = None
resp = await client.get("/api/v1/defi/tvl/nonexistent-protocol")
assert resp.status_code == 404
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("HTTP error")
resp = await client.get("/api/v1/defi/tvl/aave")
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_tvl", new_callable=AsyncMock)
async def test_protocol_tvl_passes_slug_to_service(mock_fn, client):
mock_fn.return_value = 5_000_000_000.0
await client.get("/api/v1/defi/tvl/uniswap-v3")
mock_fn.assert_called_once_with("uniswap-v3")
# ---------------------------------------------------------------------------
# GET /api/v1/defi/yields
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"pool": "0xabcd",
"chain": "Ethereum",
"project": "aave-v3",
"symbol": "USDC",
"tvlUsd": 1_000_000_000.0,
"apy": 3.5,
"apyBase": 3.0,
"apyReward": 0.5,
}
]
resp = await client.get("/api/v1/defi/yields")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["pool"] == "0xabcd"
assert data["data"][0]["apy"] == 3.5
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_chain_filter(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/yields?chain=Ethereum")
assert resp.status_code == 200
mock_fn.assert_called_once_with(chain="Ethereum", project=None)
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_project_filter(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/yields?project=aave-v3")
assert resp.status_code == 200
mock_fn.assert_called_once_with(chain=None, project="aave-v3")
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_with_chain_and_project_filter(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/yields?chain=Polygon&project=curve")
assert resp.status_code == 200
mock_fn.assert_called_once_with(chain="Polygon", project="curve")
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/yields")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_yield_pools", new_callable=AsyncMock)
async def test_yields_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("yields API down")
resp = await client.get("/api/v1/defi/yields")
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# GET /api/v1/defi/stablecoins
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"name": "Tether",
"symbol": "USDT",
"pegType": "peggedUSD",
"circulating": 100_000_000_000.0,
"price": 1.0,
}
]
resp = await client.get("/api/v1/defi/stablecoins")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["symbol"] == "USDT"
assert data["data"][0]["circulating"] == 100_000_000_000.0
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/stablecoins")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_stablecoins", new_callable=AsyncMock)
async def test_stablecoins_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("stables API error")
resp = await client.get("/api/v1/defi/stablecoins")
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# GET /api/v1/defi/volumes/dexs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_happy_path(mock_fn, client):
mock_fn.return_value = {
"totalVolume24h": 5_000_000_000.0,
"totalVolume7d": 30_000_000_000.0,
"protocols": [
{"name": "Uniswap", "volume24h": 2_000_000_000.0},
],
}
resp = await client.get("/api/v1/defi/volumes/dexs")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["totalVolume24h"] == 5_000_000_000.0
assert data["data"]["totalVolume7d"] == 30_000_000_000.0
assert len(data["data"]["protocols"]) == 1
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_returns_502_when_service_returns_none(mock_fn, client):
mock_fn.return_value = None
resp = await client.get("/api/v1/defi/volumes/dexs")
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_dex_volumes", new_callable=AsyncMock)
async def test_dex_volumes_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("volume API error")
resp = await client.get("/api/v1/defi/volumes/dexs")
assert resp.status_code == 502
# ---------------------------------------------------------------------------
# GET /api/v1/defi/fees
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_happy_path(mock_fn, client):
mock_fn.return_value = [
{"name": "Uniswap", "fees24h": 1_000_000.0, "revenue24h": 500_000.0},
{"name": "Aave", "fees24h": 800_000.0, "revenue24h": 800_000.0},
]
resp = await client.get("/api/v1/defi/fees")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["name"] == "Uniswap"
assert data["data"][0]["fees24h"] == 1_000_000.0
assert data["data"][0]["revenue24h"] == 500_000.0
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/defi/fees")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_defi.defi_service.get_protocol_fees", new_callable=AsyncMock)
async def test_fees_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("fees API error")
resp = await client.get("/api/v1/defi/fees")
assert resp.status_code == 502

View File

@@ -0,0 +1,433 @@
"""Tests for expanded economy routes."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- CPI ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-02-01", "value": 312.5, "country": "united_states"}
]
resp = await client.get("/api/v1/macro/cpi")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 312.5
mock_fn.assert_called_once_with(country="united_states")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_custom_country(mock_fn, client):
mock_fn.return_value = [{"date": "2026-02-01", "value": 120.0}]
resp = await client.get("/api/v1/macro/cpi?country=germany")
assert resp.status_code == 200
mock_fn.assert_called_once_with(country="germany")
@pytest.mark.asyncio
async def test_macro_cpi_invalid_country(client):
resp = await client.get("/api/v1/macro/cpi?country=INVALID!!!COUNTRY")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/cpi")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_cpi", new_callable=AsyncMock)
async def test_macro_cpi_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED down")
resp = await client.get("/api/v1/macro/cpi")
assert resp.status_code == 502
# --- GDP ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_default_real(mock_fn, client):
mock_fn.return_value = [{"date": "2026-01-01", "value": 22.5}]
resp = await client.get("/api/v1/macro/gdp")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
mock_fn.assert_called_once_with(gdp_type="real")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_nominal(mock_fn, client):
mock_fn.return_value = [{"date": "2026-01-01", "value": 28.3}]
resp = await client.get("/api/v1/macro/gdp?gdp_type=nominal")
assert resp.status_code == 200
mock_fn.assert_called_once_with(gdp_type="nominal")
@pytest.mark.asyncio
async def test_macro_gdp_invalid_type(client):
resp = await client.get("/api/v1/macro/gdp?gdp_type=invalid")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_forecast(mock_fn, client):
mock_fn.return_value = [{"date": "2027-01-01", "value": 23.1}]
resp = await client.get("/api/v1/macro/gdp?gdp_type=forecast")
assert resp.status_code == 200
mock_fn.assert_called_once_with(gdp_type="forecast")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_gdp", new_callable=AsyncMock)
async def test_macro_gdp_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/gdp")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Unemployment ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-02-01", "value": 3.7, "country": "united_states"}]
resp = await client.get("/api/v1/macro/unemployment")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 3.7
mock_fn.assert_called_once_with(country="united_states")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_custom_country(mock_fn, client):
mock_fn.return_value = [{"date": "2026-02-01", "value": 5.1}]
resp = await client.get("/api/v1/macro/unemployment?country=france")
assert resp.status_code == 200
mock_fn.assert_called_once_with(country="france")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_unemployment", new_callable=AsyncMock)
async def test_macro_unemployment_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/unemployment")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- PCE ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-02-01", "value": 2.8}]
resp = await client.get("/api/v1/macro/pce")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 2.8
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/pce")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_pce", new_callable=AsyncMock)
async def test_macro_pce_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED unavailable")
resp = await client.get("/api/v1/macro/pce")
assert resp.status_code == 502
# --- Money Measures ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_money_measures", new_callable=AsyncMock)
async def test_macro_money_measures_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-02-01", "m1": 18200.0, "m2": 21000.0}]
resp = await client.get("/api/v1/macro/money-measures")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["m2"] == 21000.0
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_money_measures", new_callable=AsyncMock)
async def test_macro_money_measures_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/money-measures")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- CLI ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_composite_leading_indicator", new_callable=AsyncMock)
async def test_macro_cli_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-01-01", "value": 99.2, "country": "united_states"}]
resp = await client.get("/api/v1/macro/cli")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 99.2
mock_fn.assert_called_once_with(country="united_states")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_composite_leading_indicator", new_callable=AsyncMock)
async def test_macro_cli_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/cli")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- House Price Index ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_house_price_index", new_callable=AsyncMock)
async def test_macro_hpi_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-01-01", "value": 350.0, "country": "united_states"}]
resp = await client.get("/api/v1/macro/house-price-index")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 350.0
mock_fn.assert_called_once_with(country="united_states")
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_house_price_index", new_callable=AsyncMock)
async def test_macro_hpi_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/macro/house-price-index")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- FRED Regional ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_happy_path(mock_fn, client):
mock_fn.return_value = [{"region": "CA", "value": 5.2}]
resp = await client.get("/api/v1/economy/fred-regional?series_id=CAUR")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["region"] == "CA"
mock_fn.assert_called_once_with(series_id="CAUR", region=None)
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_with_region(mock_fn, client):
mock_fn.return_value = [{"region": "state", "value": 4.1}]
resp = await client.get("/api/v1/economy/fred-regional?series_id=CAUR&region=state")
assert resp.status_code == 200
mock_fn.assert_called_once_with(series_id="CAUR", region="state")
@pytest.mark.asyncio
async def test_economy_fred_regional_missing_series_id(client):
resp = await client.get("/api/v1/economy/fred-regional")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fred_regional", new_callable=AsyncMock)
async def test_economy_fred_regional_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/fred-regional?series_id=UNKNOWN")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Primary Dealer Positioning ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_primary_dealer_positioning", new_callable=AsyncMock)
async def test_economy_primary_dealer_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-03-12", "treasuries": 250000.0, "mbs": 80000.0}]
resp = await client.get("/api/v1/economy/primary-dealer-positioning")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["treasuries"] == 250000.0
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_primary_dealer_positioning", new_callable=AsyncMock)
async def test_economy_primary_dealer_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/primary-dealer-positioning")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- FRED Search ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.fred_search", new_callable=AsyncMock)
async def test_economy_fred_search_happy_path(mock_fn, client):
mock_fn.return_value = [
{"id": "FEDFUNDS", "title": "Effective Federal Funds Rate", "frequency": "Monthly"}
]
resp = await client.get("/api/v1/economy/fred-search?query=federal+funds")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["id"] == "FEDFUNDS"
mock_fn.assert_called_once_with(query="federal funds")
@pytest.mark.asyncio
async def test_economy_fred_search_missing_query(client):
resp = await client.get("/api/v1/economy/fred-search")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_economy.economy_service.fred_search", new_callable=AsyncMock)
async def test_economy_fred_search_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/fred-search?query=nothingtofind")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Balance of Payments ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_balance_of_payments", new_callable=AsyncMock)
async def test_economy_bop_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-01-01", "current_account": -200.0, "capital_account": 5.0}]
resp = await client.get("/api/v1/economy/balance-of-payments")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["current_account"] == -200.0
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_balance_of_payments", new_callable=AsyncMock)
async def test_economy_bop_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/balance-of-payments")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Central Bank Holdings ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_central_bank_holdings", new_callable=AsyncMock)
async def test_economy_central_bank_holdings_happy_path(mock_fn, client):
mock_fn.return_value = [{"date": "2026-03-13", "treasuries": 4500000.0, "mbs": 2300000.0}]
resp = await client.get("/api/v1/economy/central-bank-holdings")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["treasuries"] == 4500000.0
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_central_bank_holdings", new_callable=AsyncMock)
async def test_economy_central_bank_holdings_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/central-bank-holdings")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- FOMC Documents ---
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-01-28", "type": "Minutes", "url": "https://federalreserve.gov/fomc"}
]
resp = await client.get("/api/v1/economy/fomc-documents")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["type"] == "Minutes"
mock_fn.assert_called_once_with(year=None)
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_with_year(mock_fn, client):
mock_fn.return_value = [{"date": "2024-01-30", "type": "Statement"}]
resp = await client.get("/api/v1/economy/fomc-documents?year=2024")
assert resp.status_code == 200
mock_fn.assert_called_once_with(year=2024)
@pytest.mark.asyncio
async def test_economy_fomc_invalid_year_too_low(client):
resp = await client.get("/api/v1/economy/fomc-documents?year=1999")
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_economy_fomc_invalid_year_too_high(client):
resp = await client.get("/api/v1/economy/fomc-documents?year=2100")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_economy.economy_service.get_fomc_documents", new_callable=AsyncMock)
async def test_economy_fomc_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/fomc-documents")
assert resp.status_code == 200
assert resp.json()["data"] == []

View File

@@ -0,0 +1,328 @@
"""Tests for fixed income routes."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- Treasury Rates ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "week_4": 5.27, "month_3": 5.30, "year_2": 4.85, "year_10": 4.32, "year_30": 4.55}
]
resp = await client.get("/api/v1/fixed-income/treasury-rates")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["year_10"] == 4.32
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/treasury-rates")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_rates", new_callable=AsyncMock)
async def test_treasury_rates_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Federal Reserve API down")
resp = await client.get("/api/v1/fixed-income/treasury-rates")
assert resp.status_code == 502
# --- Yield Curve ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_happy_path(mock_fn, client):
mock_fn.return_value = [
{"maturity": "3M", "rate": 5.30},
{"maturity": "2Y", "rate": 4.85},
{"maturity": "10Y", "rate": 4.32},
]
resp = await client.get("/api/v1/fixed-income/yield-curve")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 3
mock_fn.assert_called_once_with(date=None)
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_with_date(mock_fn, client):
mock_fn.return_value = [{"maturity": "10Y", "rate": 3.80}]
resp = await client.get("/api/v1/fixed-income/yield-curve?date=2024-01-15")
assert resp.status_code == 200
mock_fn.assert_called_once_with(date="2024-01-15")
@pytest.mark.asyncio
async def test_yield_curve_invalid_date_format(client):
resp = await client.get("/api/v1/fixed-income/yield-curve?date=not-a-date")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_yield_curve", new_callable=AsyncMock)
async def test_yield_curve_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/yield-curve")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Treasury Auctions ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
async def test_treasury_auctions_happy_path(mock_fn, client):
mock_fn.return_value = [
{"auction_date": "2026-03-10", "security_type": "Note", "security_term": "10-Year", "high_yield": 4.32, "bid_to_cover_ratio": 2.45}
]
resp = await client.get("/api/v1/fixed-income/treasury-auctions")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["bid_to_cover_ratio"] == 2.45
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
async def test_treasury_auctions_with_security_type(mock_fn, client):
mock_fn.return_value = [{"security_type": "Bill"}]
resp = await client.get("/api/v1/fixed-income/treasury-auctions?security_type=Bill")
assert resp.status_code == 200
mock_fn.assert_called_once_with(security_type="Bill")
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_treasury_auctions", new_callable=AsyncMock)
async def test_treasury_auctions_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/treasury-auctions")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
async def test_treasury_auctions_invalid_security_type(client):
resp = await client.get("/api/v1/fixed-income/treasury-auctions?security_type=DROP;TABLE")
assert resp.status_code == 422
# --- TIPS Yields ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "year_5": 2.10, "year_10": 2.25, "year_30": 2.40}
]
resp = await client.get("/api/v1/fixed-income/tips-yields")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["year_10"] == 2.25
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/tips-yields")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_tips_yields", new_callable=AsyncMock)
async def test_tips_yields_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED unavailable")
resp = await client.get("/api/v1/fixed-income/tips-yields")
assert resp.status_code == 502
# --- EFFR ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_effr", new_callable=AsyncMock)
async def test_effr_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "rate": 5.33, "percentile_1": 5.31, "percentile_25": 5.32, "percentile_75": 5.33}
]
resp = await client.get("/api/v1/fixed-income/effr")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["rate"] == 5.33
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_effr", new_callable=AsyncMock)
async def test_effr_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/effr")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- SOFR ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_sofr", new_callable=AsyncMock)
async def test_sofr_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "rate": 5.31, "average_30d": 5.31, "average_90d": 5.30}
]
resp = await client.get("/api/v1/fixed-income/sofr")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["rate"] == 5.31
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_sofr", new_callable=AsyncMock)
async def test_sofr_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/sofr")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- HQM ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_hqm", new_callable=AsyncMock)
async def test_hqm_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-02-01", "aaa": 5.10, "aa": 5.25, "a": 5.40}
]
resp = await client.get("/api/v1/fixed-income/hqm")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["aaa"] == 5.10
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_hqm", new_callable=AsyncMock)
async def test_hqm_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/hqm")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Commercial Paper ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_commercial_paper", new_callable=AsyncMock)
async def test_commercial_paper_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "maturity": "overnight", "financial": 5.28, "nonfinancial": 5.30}
]
resp = await client.get("/api/v1/fixed-income/commercial-paper")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_commercial_paper", new_callable=AsyncMock)
async def test_commercial_paper_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/commercial-paper")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Spot Rates ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spot_rates", new_callable=AsyncMock)
async def test_spot_rates_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-01", "year_1": 5.50, "year_5": 5.20, "year_10": 5.10}
]
resp = await client.get("/api/v1/fixed-income/spot-rates")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["year_10"] == 5.10
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spot_rates", new_callable=AsyncMock)
async def test_spot_rates_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/spot-rates")
assert resp.status_code == 200
assert resp.json()["data"] == []
# --- Spreads ---
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_default(mock_fn, client):
mock_fn.return_value = [{"date": "2026-03-18", "spread": 1.10}]
resp = await client.get("/api/v1/fixed-income/spreads")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
mock_fn.assert_called_once_with(series="tcm")
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_tcm_effr(mock_fn, client):
mock_fn.return_value = [{"date": "2026-03-18", "spread": 0.02}]
resp = await client.get("/api/v1/fixed-income/spreads?series=tcm_effr")
assert resp.status_code == 200
mock_fn.assert_called_once_with(series="tcm_effr")
@pytest.mark.asyncio
async def test_spreads_invalid_series(client):
resp = await client.get("/api/v1/fixed-income/spreads?series=invalid")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_fixed_income.fixed_income_service.get_spreads", new_callable=AsyncMock)
async def test_spreads_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/fixed-income/spreads?series=treasury_effr")
assert resp.status_code == 200
assert resp.json()["data"] == []

View File

@@ -0,0 +1,493 @@
"""Tests for portfolio optimization routes (TDD - RED phase first)."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- POST /api/v1/portfolio/optimize ---
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_happy_path(mock_fn, client):
mock_fn.return_value = {
"weights": {"AAPL": 0.35, "MSFT": 0.32, "GOOGL": 0.33},
"method": "hrp",
}
resp = await client.post(
"/api/v1/portfolio/optimize",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 365},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["method"] == "hrp"
assert "AAPL" in data["data"]["weights"]
mock_fn.assert_called_once_with(["AAPL", "MSFT", "GOOGL"], days=365)
@pytest.mark.asyncio
async def test_portfolio_optimize_missing_symbols(client):
resp = await client.post("/api/v1/portfolio/optimize", json={"days": 365})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_optimize_empty_symbols(client):
resp = await client.post(
"/api/v1/portfolio/optimize", json={"symbols": [], "days": 365}
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_optimize_too_many_symbols(client):
symbols = [f"SYM{i}" for i in range(51)]
resp = await client.post(
"/api/v1/portfolio/optimize", json={"symbols": symbols, "days": 365}
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Computation failed")
resp = await client.post(
"/api/v1/portfolio/optimize",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No price data available")
resp = await client.post(
"/api/v1/portfolio/optimize",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.optimize_hrp", new_callable=AsyncMock)
async def test_portfolio_optimize_default_days(mock_fn, client):
mock_fn.return_value = {"weights": {"AAPL": 1.0}, "method": "hrp"}
resp = await client.post(
"/api/v1/portfolio/optimize", json={"symbols": ["AAPL"]}
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(["AAPL"], days=365)
# --- POST /api/v1/portfolio/correlation ---
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_happy_path(mock_fn, client):
mock_fn.return_value = {
"symbols": ["AAPL", "MSFT"],
"matrix": [[1.0, 0.85], [0.85, 1.0]],
}
resp = await client.post(
"/api/v1/portfolio/correlation",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["symbols"] == ["AAPL", "MSFT"]
assert data["data"]["matrix"][0][0] == pytest.approx(1.0)
mock_fn.assert_called_once_with(["AAPL", "MSFT"], days=365)
@pytest.mark.asyncio
async def test_portfolio_correlation_missing_symbols(client):
resp = await client.post("/api/v1/portfolio/correlation", json={"days": 365})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_correlation_empty_symbols(client):
resp = await client.post(
"/api/v1/portfolio/correlation", json={"symbols": [], "days": 365}
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Failed")
resp = await client.post(
"/api/v1/portfolio/correlation",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_correlation", new_callable=AsyncMock)
async def test_portfolio_correlation_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No price data available")
resp = await client.post(
"/api/v1/portfolio/correlation",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 400
# --- POST /api/v1/portfolio/risk-parity ---
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_happy_path(mock_fn, client):
mock_fn.return_value = {
"weights": {"AAPL": 0.35, "MSFT": 0.33, "GOOGL": 0.32},
"risk_contributions": {"AAPL": 0.34, "MSFT": 0.33, "GOOGL": 0.33},
"method": "risk_parity",
}
resp = await client.post(
"/api/v1/portfolio/risk-parity",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 365},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["method"] == "risk_parity"
assert "risk_contributions" in data["data"]
mock_fn.assert_called_once_with(["AAPL", "MSFT", "GOOGL"], days=365)
@pytest.mark.asyncio
async def test_portfolio_risk_parity_missing_symbols(client):
resp = await client.post("/api/v1/portfolio/risk-parity", json={"days": 365})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_risk_parity_empty_symbols(client):
resp = await client.post(
"/api/v1/portfolio/risk-parity", json={"symbols": [], "days": 365}
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Failed")
resp = await client.post(
"/api/v1/portfolio/risk-parity",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_value_error_returns_400(mock_fn, client):
mock_fn.side_effect = ValueError("No price data available")
resp = await client.post(
"/api/v1/portfolio/risk-parity",
json={"symbols": ["AAPL", "MSFT"], "days": 365},
)
assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.compute_risk_parity", new_callable=AsyncMock)
async def test_portfolio_risk_parity_default_days(mock_fn, client):
mock_fn.return_value = {
"weights": {"AAPL": 1.0},
"risk_contributions": {"AAPL": 1.0},
"method": "risk_parity",
}
resp = await client.post(
"/api/v1/portfolio/risk-parity", json={"symbols": ["AAPL"]}
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(["AAPL"], days=365)
# ---------------------------------------------------------------------------
# POST /api/v1/portfolio/cluster
# ---------------------------------------------------------------------------
_CLUSTER_RESULT = {
"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
"coordinates": [
{"symbol": "AAPL", "x": 12.5, "y": -3.2, "cluster": 0},
{"symbol": "MSFT", "x": 11.8, "y": -2.9, "cluster": 0},
{"symbol": "GOOGL", "x": 10.1, "y": -1.5, "cluster": 0},
{"symbol": "AMZN", "x": 9.5, "y": -0.8, "cluster": 0},
{"symbol": "JPM", "x": -5.1, "y": 8.3, "cluster": 1},
{"symbol": "BAC", "x": -4.9, "y": 7.9, "cluster": 1},
],
"clusters": {"0": ["AAPL", "MSFT", "GOOGL", "AMZN"], "1": ["JPM", "BAC"]},
"method": "t-SNE + KMeans",
"n_clusters": 2,
"days": 180,
}
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_happy_path(mock_fn, client):
"""POST /cluster returns 200 with valid cluster result."""
mock_fn.return_value = _CLUSTER_RESULT
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], "days": 180},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["method"] == "t-SNE + KMeans"
assert "coordinates" in data["data"]
assert "clusters" in data["data"]
mock_fn.assert_called_once_with(
["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, n_clusters=None
)
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_with_custom_n_clusters(mock_fn, client):
"""n_clusters is forwarded to service when provided."""
mock_fn.return_value = _CLUSTER_RESULT
resp = await client.post(
"/api/v1/portfolio/cluster",
json={
"symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
"days": 180,
"n_clusters": 3,
},
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(
["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, n_clusters=3
)
@pytest.mark.asyncio
async def test_portfolio_cluster_too_few_symbols_returns_422(client):
"""Fewer than 3 symbols triggers Pydantic validation error (422)."""
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT"], "days": 180},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_missing_symbols_returns_422(client):
"""Missing symbols field returns 422."""
resp = await client.post("/api/v1/portfolio/cluster", json={"days": 180})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_too_many_symbols_returns_422(client):
"""More than 50 symbols returns 422."""
symbols = [f"SYM{i}" for i in range(51)]
resp = await client.post(
"/api/v1/portfolio/cluster", json={"symbols": symbols, "days": 180}
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_days_below_minimum_returns_422(client):
"""days < 30 returns 422."""
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 10},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_n_clusters_below_minimum_returns_422(client):
"""n_clusters < 2 returns 422."""
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180, "n_clusters": 1},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_value_error_returns_400(mock_fn, client):
"""ValueError from service returns 400."""
mock_fn.side_effect = ValueError("at least 3 symbols required")
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180},
)
assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_upstream_error_returns_502(mock_fn, client):
"""Unexpected exception from service returns 502."""
mock_fn.side_effect = RuntimeError("upstream failure")
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_default_days(mock_fn, client):
"""Default days=180 is used when not provided."""
mock_fn.return_value = _CLUSTER_RESULT
resp = await client.post(
"/api/v1/portfolio/cluster",
json={"symbols": ["AAPL", "MSFT", "GOOGL"]},
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(
["AAPL", "MSFT", "GOOGL"], days=180, n_clusters=None
)
# ---------------------------------------------------------------------------
# POST /api/v1/portfolio/similar
# ---------------------------------------------------------------------------
_SIMILAR_RESULT = {
"symbol": "AAPL",
"most_similar": [
{"symbol": "MSFT", "correlation": 0.85},
{"symbol": "GOOGL", "correlation": 0.78},
],
"least_similar": [
{"symbol": "JPM", "correlation": 0.32},
{"symbol": "BAC", "correlation": 0.28},
],
}
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_happy_path(mock_fn, client):
"""POST /similar returns 200 with most_similar and least_similar."""
mock_fn.return_value = _SIMILAR_RESULT
resp = await client.post(
"/api/v1/portfolio/similar",
json={
"symbol": "AAPL",
"universe": ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
"days": 180,
"top_n": 2,
},
)
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["symbol"] == "AAPL"
assert "most_similar" in data["data"]
assert "least_similar" in data["data"]
mock_fn.assert_called_once_with(
"AAPL",
["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
days=180,
top_n=2,
)
@pytest.mark.asyncio
async def test_portfolio_similar_missing_symbol_returns_422(client):
"""Missing symbol field returns 422."""
resp = await client.post(
"/api/v1/portfolio/similar",
json={"universe": ["MSFT", "GOOGL"], "days": 180},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_missing_universe_returns_422(client):
"""Missing universe field returns 422."""
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "days": 180},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_universe_too_small_returns_422(client):
"""universe with fewer than 2 entries returns 422."""
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "universe": ["MSFT"], "days": 180},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_top_n_below_minimum_returns_422(client):
"""top_n < 1 returns 422."""
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180, "top_n": 0},
)
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_value_error_returns_400(mock_fn, client):
"""ValueError from service returns 400."""
mock_fn.side_effect = ValueError("AAPL not found in price data")
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180},
)
assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_upstream_error_returns_502(mock_fn, client):
"""Unexpected exception from service returns 502."""
mock_fn.side_effect = RuntimeError("upstream failure")
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180},
)
assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_default_top_n(mock_fn, client):
"""Default top_n=5 is passed to service when not specified."""
mock_fn.return_value = _SIMILAR_RESULT
resp = await client.post(
"/api/v1/portfolio/similar",
json={"symbol": "AAPL", "universe": ["MSFT", "GOOGL", "AMZN"]},
)
assert resp.status_code == 200
mock_fn.assert_called_once_with("AAPL", ["MSFT", "GOOGL", "AMZN"], days=180, top_n=5)

View File

@@ -0,0 +1,228 @@
"""Tests for regulatory data routes (CFTC, SEC)."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- COT Report ---
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"date": "2026-03-11",
"symbol": "ES",
"commercial_long": 250000,
"commercial_short": 300000,
"noncommercial_long": 180000,
"noncommercial_short": 120000,
}
]
resp = await client.get("/api/v1/regulators/cot?symbol=ES")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["commercial_long"] == 250000
mock_fn.assert_called_once_with("ES")
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/cot?symbol=UNKNOWN")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cot", new_callable=AsyncMock)
async def test_cot_report_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("CFTC unavailable")
resp = await client.get("/api/v1/regulators/cot?symbol=ES")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_cot_report_missing_symbol(client):
resp = await client.get("/api/v1/regulators/cot")
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_cot_report_invalid_symbol(client):
resp = await client.get("/api/v1/regulators/cot?symbol=DROP;TABLE")
assert resp.status_code == 400
# --- COT Search ---
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_happy_path(mock_fn, client):
mock_fn.return_value = [
{"code": "13874P", "name": "E-MINI S&P 500"},
{"code": "13874A", "name": "S&P 500 CONSOLIDATED"},
]
resp = await client.get("/api/v1/regulators/cot/search?query=S%26P+500")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["name"] == "E-MINI S&P 500"
@pytest.mark.asyncio
async def test_cot_search_missing_query(client):
resp = await client.get("/api/v1/regulators/cot/search")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/cot/search?query=nonexistentfutures")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.cot_search", new_callable=AsyncMock)
async def test_cot_search_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("CFTC search failed")
resp = await client.get("/api/v1/regulators/cot/search?query=gold")
assert resp.status_code == 502
# --- SEC Litigation ---
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_happy_path(mock_fn, client):
mock_fn.return_value = [
{
"date": "2026-03-15",
"title": "SEC Charges Former CEO with Fraud",
"url": "https://sec.gov/litigation/lr/2026/lr-99999.htm",
"summary": "The Commission charged...",
}
]
resp = await client.get("/api/v1/regulators/sec/litigation")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert "CEO" in data["data"][0]["title"]
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/sec/litigation")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_sec_litigation", new_callable=AsyncMock)
async def test_sec_litigation_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("SEC RSS feed unavailable")
resp = await client.get("/api/v1/regulators/sec/litigation")
assert resp.status_code == 502
# --- SEC Institution Search ---
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_happy_path(mock_fn, client):
mock_fn.return_value = [
{"name": "Vanguard Group Inc", "cik": "0000102909"},
{"name": "BlackRock Inc", "cik": "0001364742"},
]
resp = await client.get("/api/v1/regulators/sec/institutions?query=vanguard")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["name"] == "Vanguard Group Inc"
mock_fn.assert_called_once_with("vanguard")
@pytest.mark.asyncio
async def test_sec_institutions_missing_query(client):
resp = await client.get("/api/v1/regulators/sec/institutions")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/sec/institutions?query=notarealfirm")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.search_institutions", new_callable=AsyncMock)
async def test_sec_institutions_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("SEC API failed")
resp = await client.get("/api/v1/regulators/sec/institutions?query=blackrock")
assert resp.status_code == 502
# --- SEC CIK Map ---
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_happy_path(mock_fn, client):
mock_fn.return_value = [{"symbol": "AAPL", "cik": "0000320193"}]
resp = await client.get("/api/v1/regulators/sec/cik-map/AAPL")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["cik"] == "0000320193"
mock_fn.assert_called_once_with("AAPL")
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_not_found(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/regulators/sec/cik-map/XXXX")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_regulators.regulators_service.get_cik_map", new_callable=AsyncMock)
async def test_sec_cik_map_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("SEC lookup failed")
resp = await client.get("/api/v1/regulators/sec/cik-map/AAPL")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_sec_cik_map_invalid_symbol(client):
resp = await client.get("/api/v1/regulators/sec/cik-map/INVALID!!!")
assert resp.status_code == 400

View File

@@ -14,13 +14,17 @@ async def client():
@pytest.mark.asyncio
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
async def test_stock_sentiment(mock_sentiment, mock_av, client):
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_stock_sentiment(mock_av, mock_sentiment, mock_reddit, mock_upgrades, mock_recs, client):
# Route was refactored to return composite_score/composite_label/details/source_scores
mock_sentiment.return_value = {
"symbol": "AAPL",
"news_sentiment": {"bullish_percent": 0.7, "bearish_percent": 0.3},
"recent_news": [],
"recent_news": [{"headline": "Apple strong"}],
"analyst_recommendations": [],
"recent_upgrades_downgrades": [],
}
@@ -31,12 +35,22 @@ async def test_stock_sentiment(mock_sentiment, mock_av, client):
"overall_sentiment": {"avg_score": 0.4, "label": "Bullish"},
"articles": [],
}
mock_reddit.return_value = {"found": False, "symbol": "AAPL"}
mock_upgrades.return_value = []
mock_recs.return_value = []
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["symbol"] == "AAPL"
assert data["news_sentiment"]["bullish_percent"] == 0.7
assert data["alpha_vantage_sentiment"]["overall_sentiment"]["label"] == "Bullish"
# New composite response shape
assert "composite_score" in data
assert "composite_label" in data
assert "source_scores" in data
assert "details" in data
# AV news data accessible via details
assert data["details"]["news_sentiment"]["overall_sentiment"]["label"] == "Bullish"
# Finnhub news accessible via details
assert len(data["details"]["finnhub_news"]) == 1
@pytest.mark.asyncio

View File

@@ -0,0 +1,358 @@
"""Tests for new sentiment routes: social sentiment, reddit, composite sentiment, trending."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- Social Sentiment (Finnhub) ---
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_happy_path(mock_fn, client):
mock_fn.return_value = {
"configured": True,
"symbol": "AAPL",
"reddit_summary": {"total_mentions": 150, "avg_score": 0.55, "data_points": 5},
"twitter_summary": {"total_mentions": 300, "avg_score": 0.40, "data_points": 8},
"reddit": [{"mention": 30, "score": 0.5}],
"twitter": [{"mention": 40, "score": 0.4}],
}
resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["symbol"] == "AAPL"
assert data["data"]["reddit_summary"]["total_mentions"] == 150
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_not_configured(mock_fn, client):
mock_fn.return_value = {"configured": False, "message": "Set INVEST_API_FINNHUB_API_KEY"}
resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["data"]["configured"] is False
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_premium_required(mock_fn, client):
mock_fn.return_value = {
"configured": True,
"symbol": "AAPL",
"premium_required": True,
"reddit": [],
"twitter": [],
}
resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["data"]["premium_required"] is True
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_social_sentiment", new_callable=AsyncMock)
async def test_stock_social_sentiment_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("Finnhub error")
resp = await client.get("/api/v1/stock/AAPL/social-sentiment")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_stock_social_sentiment_invalid_symbol(client):
resp = await client.get("/api/v1/stock/INVALID!!!/social-sentiment")
assert resp.status_code == 400
# --- Reddit Sentiment ---
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_found(mock_fn, client):
mock_fn.return_value = {
"symbol": "AAPL",
"found": True,
"rank": 3,
"mentions_24h": 150,
"mentions_24h_ago": 100,
"mentions_change_pct": 50.0,
"upvotes": 500,
"rank_24h_ago": 5,
}
resp = await client.get("/api/v1/stock/AAPL/reddit-sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"]["found"] is True
assert data["data"]["rank"] == 3
assert data["data"]["mentions_24h"] == 150
assert data["data"]["mentions_change_pct"] == 50.0
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_not_found(mock_fn, client):
mock_fn.return_value = {
"symbol": "OBSCURE",
"found": False,
"message": "OBSCURE not in Reddit top trending (not enough mentions)",
}
resp = await client.get("/api/v1/stock/OBSCURE/reddit-sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["data"]["found"] is False
assert "not in Reddit" in data["data"]["message"]
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
async def test_stock_reddit_sentiment_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("ApeWisdom down")
resp = await client.get("/api/v1/stock/AAPL/reddit-sentiment")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_stock_reddit_sentiment_invalid_symbol(client):
resp = await client.get("/api/v1/stock/BAD!!!/reddit-sentiment")
assert resp.status_code == 400
# --- Reddit Trending ---
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_happy_path(mock_fn, client):
mock_fn.return_value = [
{"rank": 1, "symbol": "TSLA", "name": "Tesla", "mentions_24h": 500, "upvotes": 1200, "rank_24h_ago": 2, "mentions_24h_ago": 400},
{"rank": 2, "symbol": "AAPL", "name": "Apple", "mentions_24h": 300, "upvotes": 800, "rank_24h_ago": 1, "mentions_24h_ago": 350},
{"rank": 3, "symbol": "GME", "name": "GameStop", "mentions_24h": 200, "upvotes": 600, "rank_24h_ago": 3, "mentions_24h_ago": 180},
]
resp = await client.get("/api/v1/discover/reddit-trending")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 3
assert data["data"][0]["symbol"] == "TSLA"
assert data["data"][0]["rank"] == 1
assert data["data"][1]["symbol"] == "AAPL"
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/discover/reddit-trending")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_sentiment.reddit_service.get_reddit_trending", new_callable=AsyncMock)
async def test_reddit_trending_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("ApeWisdom unavailable")
resp = await client.get("/api/v1/discover/reddit-trending")
assert resp.status_code == 502
# --- Composite /stock/{symbol}/sentiment (aggregation logic) ---
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_all_sources(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
mock_av.return_value = {
"configured": True,
"symbol": "AAPL",
"overall_sentiment": {"avg_score": 0.2, "label": "Bullish"},
"articles": [],
}
mock_fh.return_value = {
"symbol": "AAPL",
"news_sentiment": {},
"recent_news": [{"headline": "Apple rises", "source": "Reuters"}],
"analyst_recommendations": [],
"recent_upgrades_downgrades": [],
}
mock_reddit.return_value = {
"symbol": "AAPL",
"found": True,
"rank": 2,
"mentions_24h": 200,
"mentions_24h_ago": 150,
"mentions_change_pct": 33.3,
"upvotes": 800,
}
mock_upgrades.return_value = [
{"action": "up", "company": "Goldman"},
{"action": "down", "company": "Morgan Stanley"},
{"action": "init", "company": "JPMorgan"},
]
mock_recs.return_value = [
{"strongBuy": 10, "buy": 15, "hold": 5, "sell": 2, "strongSell": 1}
]
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
result = data["data"]
assert result["symbol"] == "AAPL"
assert result["composite_score"] is not None
assert result["composite_label"] in ("Strong Bullish", "Bullish", "Neutral", "Bearish", "Strong Bearish")
assert "news" in result["source_scores"]
assert "analysts" in result["source_scores"]
assert "upgrades" in result["source_scores"]
assert "reddit" in result["source_scores"]
assert "details" in result
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_no_data_returns_unknown(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
mock_av.return_value = {}
mock_fh.return_value = {}
mock_reddit.return_value = {"found": False}
mock_upgrades.return_value = []
mock_recs.return_value = []
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["composite_score"] is None
assert data["composite_label"] == "Unknown"
assert data["source_scores"] == {}
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_strong_bullish_label(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
# All signals strongly bullish
mock_av.return_value = {"overall_sentiment": {"avg_score": 0.35}}
mock_fh.return_value = {}
mock_reddit.return_value = {"found": True, "mentions_24h": 500, "mentions_change_pct": 100.0}
mock_upgrades.return_value = [{"action": "up"}, {"action": "up"}, {"action": "up"}]
mock_recs.return_value = [{"strongBuy": 20, "buy": 10, "hold": 1, "sell": 0, "strongSell": 0}]
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["composite_score"] >= 0.5
assert data["composite_label"] == "Strong Bullish"
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_bearish_label(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
# All signals bearish
mock_av.return_value = {"overall_sentiment": {"avg_score": -0.3}}
mock_fh.return_value = {}
mock_reddit.return_value = {"found": True, "mentions_24h": 200, "mentions_change_pct": -70.0}
mock_upgrades.return_value = [{"action": "down"}, {"action": "down"}, {"action": "down"}]
mock_recs.return_value = [{"strongBuy": 0, "buy": 2, "hold": 5, "sell": 10, "strongSell": 5}]
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()["data"]
assert data["composite_label"] in ("Bearish", "Strong Bearish")
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_one_source_failing_is_graceful(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
# Simulate an exception from one source — gather uses return_exceptions=True
mock_av.side_effect = RuntimeError("AV down")
mock_fh.return_value = {}
mock_reddit.return_value = {"found": False}
mock_upgrades.return_value = []
mock_recs.return_value = [{"strongBuy": 5, "buy": 5, "hold": 3, "sell": 1, "strongSell": 0}]
resp = await client.get("/api/v1/stock/AAPL/sentiment")
# Should still succeed, gracefully skipping the failed source
assert resp.status_code == 200
data = resp.json()["data"]
assert data["symbol"] == "AAPL"
@pytest.mark.asyncio
async def test_composite_sentiment_invalid_symbol(client):
resp = await client.get("/api/v1/stock/INVALID!!!/sentiment")
assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_reddit_low_mentions_excluded(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
# Reddit mentions < 10 threshold should exclude reddit from scoring
mock_av.return_value = {}
mock_fh.return_value = {}
mock_reddit.return_value = {"found": True, "mentions_24h": 5, "mentions_change_pct": 50.0}
mock_upgrades.return_value = []
mock_recs.return_value = []
resp = await client.get("/api/v1/stock/AAPL/sentiment")
assert resp.status_code == 200
data = resp.json()["data"]
assert "reddit" not in data["source_scores"]
@pytest.mark.asyncio
@patch("routes_sentiment.finnhub_service.get_recommendation_trends", new_callable=AsyncMock)
@patch("routes_sentiment.openbb_service.get_upgrades_downgrades", new_callable=AsyncMock)
@patch("routes_sentiment.reddit_service.get_reddit_sentiment", new_callable=AsyncMock)
@patch("routes_sentiment.finnhub_service.get_sentiment_summary", new_callable=AsyncMock)
@patch("routes_sentiment.alphavantage_service.get_news_sentiment", new_callable=AsyncMock)
async def test_composite_sentiment_details_structure(mock_av, mock_fh, mock_reddit, mock_upgrades, mock_recs, client):
mock_av.return_value = {"overall_sentiment": {"avg_score": 0.1}}
mock_fh.return_value = {"recent_news": [{"headline": "Test news"}]}
mock_reddit.return_value = {"found": False}
mock_upgrades.return_value = [{"action": "up"}, {"action": "up"}]
mock_recs.return_value = []
resp = await client.get("/api/v1/stock/MSFT/sentiment")
assert resp.status_code == 200
details = resp.json()["data"]["details"]
assert "news_sentiment" in details
assert "analyst_recommendations" in details
assert "recent_upgrades" in details
assert "reddit" in details
assert "finnhub_news" in details

176
tests/test_routes_shorts.py Normal file
View File

@@ -0,0 +1,176 @@
"""Tests for shorts and dark pool routes."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- Short Volume ---
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "short_volume": 5000000, "short_exempt_volume": 10000, "total_volume": 20000000, "short_volume_percent": 0.25}
]
resp = await client.get("/api/v1/stock/AAPL/shorts/volume")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 1
assert data["data"][0]["short_volume"] == 5000000
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/stock/GME/shorts/volume")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"] == []
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_volume", new_callable=AsyncMock)
async def test_short_volume_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("stockgrid unavailable")
resp = await client.get("/api/v1/stock/AAPL/shorts/volume")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_short_volume_invalid_symbol(client):
resp = await client.get("/api/v1/stock/AAPL;DROP/shorts/volume")
assert resp.status_code == 400
# --- Fails To Deliver ---
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-01", "cusip": "037833100", "failure_quantity": 50000, "symbol": "AAPL", "price": 175.0}
]
resp = await client.get("/api/v1/stock/AAPL/shorts/ftd")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["symbol"] == "AAPL"
assert data["data"][0]["failure_quantity"] == 50000
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/stock/TSLA/shorts/ftd")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_fails_to_deliver", new_callable=AsyncMock)
async def test_ftd_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("SEC connection failed")
resp = await client.get("/api/v1/stock/AAPL/shorts/ftd")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_ftd_invalid_symbol(client):
resp = await client.get("/api/v1/stock/BAD!!!/shorts/ftd")
assert resp.status_code == 400
# --- Short Interest ---
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_happy_path(mock_fn, client):
mock_fn.return_value = [
{"settlement_date": "2026-02-28", "symbol": "GME", "short_interest": 20000000, "days_to_cover": 3.5}
]
resp = await client.get("/api/v1/stock/GME/shorts/interest")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["short_interest"] == 20000000
assert data["data"][0]["days_to_cover"] == 3.5
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/stock/NVDA/shorts/interest")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_short_interest", new_callable=AsyncMock)
async def test_short_interest_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FINRA unavailable")
resp = await client.get("/api/v1/stock/AAPL/shorts/interest")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_short_interest_invalid_symbol(client):
resp = await client.get("/api/v1/stock/INVALID!!!/shorts/interest")
assert resp.status_code == 400
# --- Dark Pool OTC ---
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-18", "symbol": "AAPL", "shares": 3000000, "percentage": 12.5}
]
resp = await client.get("/api/v1/darkpool/AAPL/otc")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["percentage"] == 12.5
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/darkpool/TSLA/otc")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_shorts.shorts_service.get_darkpool_otc", new_callable=AsyncMock)
async def test_darkpool_otc_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FINRA connection timeout")
resp = await client.get("/api/v1/darkpool/AAPL/otc")
assert resp.status_code == 502
@pytest.mark.asyncio
async def test_darkpool_otc_invalid_symbol(client):
resp = await client.get("/api/v1/darkpool/BAD!!!/otc")
assert resp.status_code == 400

View File

@@ -0,0 +1,189 @@
"""Tests for economy survey routes."""
from unittest.mock import patch, AsyncMock
import pytest
from httpx import AsyncClient, ASGITransport
from main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c
# --- Michigan Consumer Sentiment ---
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
async def test_survey_michigan_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-01", "consumer_sentiment": 76.5, "inflation_expectation_1yr": 3.1}
]
resp = await client.get("/api/v1/economy/surveys/michigan")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["consumer_sentiment"] == 76.5
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
async def test_survey_michigan_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/surveys/michigan")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_michigan", new_callable=AsyncMock)
async def test_survey_michigan_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED unavailable")
resp = await client.get("/api/v1/economy/surveys/michigan")
assert resp.status_code == 502
# --- SLOOS ---
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-01-01", "c_i_tightening_pct": 25.0, "consumer_tightening_pct": 10.0}
]
resp = await client.get("/api/v1/economy/surveys/sloos")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["c_i_tightening_pct"] == 25.0
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/surveys/sloos")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_sloos", new_callable=AsyncMock)
async def test_survey_sloos_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED down")
resp = await client.get("/api/v1/economy/surveys/sloos")
assert resp.status_code == 502
# --- Nonfarm Payrolls ---
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-07", "value": 275000, "industry": "total_nonfarm"}
]
resp = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["value"] == 275000
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_nonfarm_payrolls", new_callable=AsyncMock)
async def test_survey_nfp_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("BLS unavailable")
resp = await client.get("/api/v1/economy/surveys/nonfarm-payrolls")
assert resp.status_code == 502
# --- Empire State Manufacturing ---
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_happy_path(mock_fn, client):
mock_fn.return_value = [
{"date": "2026-03-01", "general_business_conditions": -7.58}
]
resp = await client.get("/api/v1/economy/surveys/empire-state")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert data["data"][0]["general_business_conditions"] == -7.58
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/surveys/empire-state")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.get_empire_state", new_callable=AsyncMock)
async def test_survey_empire_state_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("FRED connection error")
resp = await client.get("/api/v1/economy/surveys/empire-state")
assert resp.status_code == 502
# --- BLS Search ---
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_happy_path(mock_fn, client):
mock_fn.return_value = [
{"series_id": "CES0000000001", "series_title": "All employees, thousands, total nonfarm"},
{"series_id": "CES1000000001", "series_title": "All employees, thousands, mining and logging"},
]
resp = await client.get("/api/v1/economy/surveys/bls-search?query=nonfarm+payrolls")
assert resp.status_code == 200
data = resp.json()
assert data["success"] is True
assert len(data["data"]) == 2
assert data["data"][0]["series_id"] == "CES0000000001"
mock_fn.assert_called_once_with(query="nonfarm payrolls")
@pytest.mark.asyncio
async def test_survey_bls_search_missing_query(client):
resp = await client.get("/api/v1/economy/surveys/bls-search")
assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_empty(mock_fn, client):
mock_fn.return_value = []
resp = await client.get("/api/v1/economy/surveys/bls-search?query=nothingtofind")
assert resp.status_code == 200
assert resp.json()["data"] == []
@pytest.mark.asyncio
@patch("routes_surveys.surveys_service.bls_search", new_callable=AsyncMock)
async def test_survey_bls_search_service_error_returns_502(mock_fn, client):
mock_fn.side_effect = RuntimeError("BLS API down")
resp = await client.get("/api/v1/economy/surveys/bls-search?query=wages")
assert resp.status_code == 502