feat: add backtesting engine with 4 strategies (TDD)
Strategies: - POST /backtest/sma-crossover - SMA crossover (short/long window) - POST /backtest/rsi - RSI oversold/overbought signals - POST /backtest/buy-and-hold - passive benchmark - POST /backtest/momentum - multi-symbol momentum rotation Returns: total_return, annualized_return, sharpe_ratio, max_drawdown, win_rate, total_trades, equity_curve (last 20 points) Implementation: pure pandas/numpy, no external backtesting libs. Shared _compute_metrics helper across all strategies. 79 new tests (46 service unit + 33 route integration). All 391 tests passing.
This commit is contained in:
372
backtest_service.py
Normal file
372
backtest_service.py
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
"""Backtesting engine using pure pandas/numpy - no external backtesting libraries."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from obb_utils import fetch_historical
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_EQUITY_CURVE_MAX_POINTS = 20
|
||||||
|
_MIN_BARS_FOR_SINGLE_POINT = 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal signal computation helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_closes(result: Any) -> pd.Series:
|
||||||
|
"""Pull close prices from an OBBject result into a float Series."""
|
||||||
|
bars = result.results
|
||||||
|
closes = [getattr(bar, "close", None) for bar in bars]
|
||||||
|
return pd.Series(closes, dtype=float).dropna().reset_index(drop=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_sma_signals(
|
||||||
|
prices: pd.Series, short_window: int, long_window: int
|
||||||
|
) -> pd.Series:
|
||||||
|
"""Return position series (1=long, 0=flat) from SMA crossover strategy.
|
||||||
|
|
||||||
|
Buy when short SMA crosses above long SMA; sell when it crosses below.
|
||||||
|
"""
|
||||||
|
short_ma = prices.rolling(short_window).mean()
|
||||||
|
long_ma = prices.rolling(long_window).mean()
|
||||||
|
|
||||||
|
# 1 where short > long, else 0; NaN before long_window filled with 0
|
||||||
|
signal = (short_ma > long_ma).astype(int)
|
||||||
|
signal.iloc[: long_window - 1] = 0
|
||||||
|
return signal
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_rsi(prices: pd.Series, period: int) -> pd.Series:
|
||||||
|
"""Compute Wilder RSI for a price series."""
|
||||||
|
delta = prices.diff()
|
||||||
|
gain = delta.clip(lower=0)
|
||||||
|
loss = (-delta).clip(lower=0)
|
||||||
|
|
||||||
|
avg_gain = gain.ewm(alpha=1 / period, adjust=False).mean()
|
||||||
|
avg_loss = loss.ewm(alpha=1 / period, adjust=False).mean()
|
||||||
|
|
||||||
|
# When avg_loss == 0 and avg_gain > 0, RSI = 100; avoid division by zero.
|
||||||
|
rsi = pd.Series(np.where(
|
||||||
|
avg_loss == 0,
|
||||||
|
np.where(avg_gain == 0, 50.0, 100.0),
|
||||||
|
100 - (100 / (1 + avg_gain / avg_loss)),
|
||||||
|
), index=prices.index, dtype=float)
|
||||||
|
# Preserve NaN for the initial diff period
|
||||||
|
rsi[avg_gain.isna()] = np.nan
|
||||||
|
return rsi
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_rsi_signals(
|
||||||
|
prices: pd.Series, period: int, oversold: float, overbought: float
|
||||||
|
) -> pd.Series:
|
||||||
|
"""Return position series (1=long, 0=flat) from RSI strategy.
|
||||||
|
|
||||||
|
Buy when RSI < oversold; sell when RSI > overbought.
|
||||||
|
"""
|
||||||
|
rsi = _compute_rsi(prices, period)
|
||||||
|
position = pd.Series(0, index=prices.index, dtype=int)
|
||||||
|
in_trade = False
|
||||||
|
|
||||||
|
for i in range(len(prices)):
|
||||||
|
rsi_val = rsi.iloc[i]
|
||||||
|
if pd.isna(rsi_val):
|
||||||
|
continue
|
||||||
|
if not in_trade and rsi_val < oversold:
|
||||||
|
in_trade = True
|
||||||
|
elif in_trade and rsi_val > overbought:
|
||||||
|
in_trade = False
|
||||||
|
if in_trade:
|
||||||
|
position.iloc[i] = 1
|
||||||
|
|
||||||
|
return position
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared metrics computation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_metrics(equity: pd.Series, trades: int) -> dict[str, Any]:
|
||||||
|
"""Compute standard backtest performance metrics from an equity curve.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
equity:
|
||||||
|
Daily portfolio value series starting from initial_capital.
|
||||||
|
trades:
|
||||||
|
Number of completed round-trip trades.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict with keys: total_return, annualized_return, sharpe_ratio,
|
||||||
|
max_drawdown, win_rate, total_trades, equity_curve.
|
||||||
|
"""
|
||||||
|
n = len(equity)
|
||||||
|
initial = float(equity.iloc[0])
|
||||||
|
final = float(equity.iloc[-1])
|
||||||
|
|
||||||
|
total_return = (final - initial) / initial if initial != 0 else 0.0
|
||||||
|
|
||||||
|
trading_days = max(n - 1, 1)
|
||||||
|
annualized_return = (1 + total_return) ** (252 / trading_days) - 1
|
||||||
|
|
||||||
|
# Sharpe ratio (annualized, risk-free rate = 0)
|
||||||
|
sharpe_ratio: float | None = None
|
||||||
|
if n > 1:
|
||||||
|
daily_returns = equity.pct_change().dropna()
|
||||||
|
std = float(daily_returns.std())
|
||||||
|
if std > 0:
|
||||||
|
sharpe_ratio = float(daily_returns.mean() / std * np.sqrt(252))
|
||||||
|
|
||||||
|
# Maximum drawdown
|
||||||
|
rolling_max = equity.cummax()
|
||||||
|
drawdown = (equity - rolling_max) / rolling_max
|
||||||
|
max_drawdown = float(drawdown.min())
|
||||||
|
|
||||||
|
# Win rate: undefined when no trades
|
||||||
|
win_rate: float | None = None
|
||||||
|
if trades > 0:
|
||||||
|
# Approximate: compare each trade entry/exit pair captured in equity
|
||||||
|
win_rate = None # will be overridden by callers that track trades
|
||||||
|
|
||||||
|
# Equity curve - last N points as plain Python floats
|
||||||
|
last_n = equity.iloc[-_EQUITY_CURVE_MAX_POINTS:]
|
||||||
|
equity_curve = [round(float(v), 4) for v in last_n]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_return": round(total_return, 6),
|
||||||
|
"annualized_return": round(annualized_return, 6),
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6) if sharpe_ratio is not None else None,
|
||||||
|
"max_drawdown": round(max_drawdown, 6),
|
||||||
|
"win_rate": win_rate,
|
||||||
|
"total_trades": trades,
|
||||||
|
"equity_curve": equity_curve,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _simulate_positions(
|
||||||
|
prices: pd.Series,
|
||||||
|
positions: pd.Series,
|
||||||
|
initial_capital: float,
|
||||||
|
) -> tuple[pd.Series, int, int]:
|
||||||
|
"""Simulate portfolio equity given a position series and prices.
|
||||||
|
|
||||||
|
Returns (equity_curve, total_trades, winning_trades).
|
||||||
|
A trade is a complete buy->sell round-trip.
|
||||||
|
"""
|
||||||
|
# Daily returns when in position
|
||||||
|
price_returns = prices.pct_change().fillna(0.0)
|
||||||
|
strategy_returns = positions.shift(1).fillna(0).astype(float) * price_returns
|
||||||
|
|
||||||
|
equity = initial_capital * (1 + strategy_returns).cumprod()
|
||||||
|
equity.iloc[0] = initial_capital
|
||||||
|
|
||||||
|
# Count round trips
|
||||||
|
trade_changes = positions.diff().abs()
|
||||||
|
entries = int((trade_changes == 1).sum())
|
||||||
|
exits = int((trade_changes == -1).sum())
|
||||||
|
total_trades = min(entries, exits) # only completed round trips
|
||||||
|
|
||||||
|
# Count wins: each completed trade where exit value > entry value
|
||||||
|
winning_trades = 0
|
||||||
|
in_trade = False
|
||||||
|
entry_price = 0.0
|
||||||
|
for i in range(len(positions)):
|
||||||
|
pos = int(positions.iloc[i])
|
||||||
|
price = float(prices.iloc[i])
|
||||||
|
if not in_trade and pos == 1:
|
||||||
|
in_trade = True
|
||||||
|
entry_price = price
|
||||||
|
elif in_trade and pos == 0:
|
||||||
|
in_trade = False
|
||||||
|
if price > entry_price:
|
||||||
|
winning_trades += 1
|
||||||
|
|
||||||
|
return equity, total_trades, winning_trades
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public strategy functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def backtest_sma_crossover(
|
||||||
|
symbol: str,
|
||||||
|
short_window: int,
|
||||||
|
long_window: int,
|
||||||
|
days: int,
|
||||||
|
initial_capital: float,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run SMA crossover backtest for a single symbol."""
|
||||||
|
hist = await fetch_historical(symbol, days)
|
||||||
|
if hist is None:
|
||||||
|
raise ValueError(f"No historical data available for {symbol}")
|
||||||
|
|
||||||
|
prices = _extract_closes(hist)
|
||||||
|
if len(prices) <= long_window:
|
||||||
|
raise ValueError(
|
||||||
|
f"Insufficient data: need >{long_window} bars, got {len(prices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
positions = _compute_sma_signals(prices, short_window, long_window)
|
||||||
|
equity, total_trades, winning_trades = _simulate_positions(
|
||||||
|
prices, positions, initial_capital
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _compute_metrics(equity, total_trades)
|
||||||
|
if total_trades > 0:
|
||||||
|
result["win_rate"] = round(winning_trades / total_trades, 6)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def backtest_rsi(
|
||||||
|
symbol: str,
|
||||||
|
period: int,
|
||||||
|
oversold: float,
|
||||||
|
overbought: float,
|
||||||
|
days: int,
|
||||||
|
initial_capital: float,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run RSI-based backtest for a single symbol."""
|
||||||
|
hist = await fetch_historical(symbol, days)
|
||||||
|
if hist is None:
|
||||||
|
raise ValueError(f"No historical data available for {symbol}")
|
||||||
|
|
||||||
|
prices = _extract_closes(hist)
|
||||||
|
if len(prices) <= period:
|
||||||
|
raise ValueError(
|
||||||
|
f"Insufficient data: need >{period} bars, got {len(prices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
positions = _compute_rsi_signals(prices, period, oversold, overbought)
|
||||||
|
equity, total_trades, winning_trades = _simulate_positions(
|
||||||
|
prices, positions, initial_capital
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _compute_metrics(equity, total_trades)
|
||||||
|
if total_trades > 0:
|
||||||
|
result["win_rate"] = round(winning_trades / total_trades, 6)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def backtest_buy_and_hold(
|
||||||
|
symbol: str,
|
||||||
|
days: int,
|
||||||
|
initial_capital: float,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run a simple buy-and-hold backtest as a benchmark."""
|
||||||
|
hist = await fetch_historical(symbol, days)
|
||||||
|
if hist is None:
|
||||||
|
raise ValueError(f"No historical data available for {symbol}")
|
||||||
|
|
||||||
|
prices = _extract_closes(hist)
|
||||||
|
if len(prices) < _MIN_BARS_FOR_SINGLE_POINT:
|
||||||
|
raise ValueError(
|
||||||
|
f"Insufficient data: need at least 2 bars, got {len(prices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Always fully invested - position is 1 from day 0
|
||||||
|
positions = pd.Series(1, index=prices.index, dtype=int)
|
||||||
|
equity, _, _ = _simulate_positions(prices, positions, initial_capital)
|
||||||
|
|
||||||
|
result = _compute_metrics(equity, trades=1)
|
||||||
|
# Buy-and-hold: 1 trade, win_rate is whether final > initial
|
||||||
|
result["win_rate"] = 1.0 if result["total_return"] > 0 else 0.0
|
||||||
|
result["total_trades"] = 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def backtest_momentum(
|
||||||
|
symbols: list[str],
|
||||||
|
lookback: int,
|
||||||
|
top_n: int,
|
||||||
|
rebalance_days: int,
|
||||||
|
days: int,
|
||||||
|
initial_capital: float,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run momentum strategy: every rebalance_days pick top_n symbols by lookback return."""
|
||||||
|
# Fetch all price series
|
||||||
|
price_map: dict[str, pd.Series] = {}
|
||||||
|
for sym in symbols:
|
||||||
|
hist = await fetch_historical(sym, days)
|
||||||
|
if hist is not None:
|
||||||
|
closes = _extract_closes(hist)
|
||||||
|
if len(closes) > lookback:
|
||||||
|
price_map[sym] = closes
|
||||||
|
|
||||||
|
if not price_map:
|
||||||
|
raise ValueError("No price data available for any of the requested symbols")
|
||||||
|
|
||||||
|
# Align all price series to the same length (min across symbols)
|
||||||
|
min_len = min(len(v) for v in price_map.values())
|
||||||
|
aligned = {sym: s.iloc[:min_len].reset_index(drop=True) for sym, s in price_map.items()}
|
||||||
|
|
||||||
|
n_bars = min_len
|
||||||
|
portfolio_value = initial_capital
|
||||||
|
equity_values: list[float] = [initial_capital]
|
||||||
|
allocation_history: list[dict[str, Any]] = []
|
||||||
|
total_trades = 0
|
||||||
|
|
||||||
|
current_symbols: list[str] = []
|
||||||
|
current_weights: list[float] = []
|
||||||
|
entry_prices: dict[str, float] = {}
|
||||||
|
winning_trades = 0
|
||||||
|
|
||||||
|
for bar in range(1, n_bars):
|
||||||
|
# Rebalance check
|
||||||
|
if bar % rebalance_days == 0 and bar >= lookback:
|
||||||
|
# Rank symbols by lookback-period return
|
||||||
|
returns: dict[str, float] = {}
|
||||||
|
for sym, prices in aligned.items():
|
||||||
|
if bar >= lookback:
|
||||||
|
ret = (prices.iloc[bar] / prices.iloc[bar - lookback]) - 1
|
||||||
|
returns[sym] = ret
|
||||||
|
|
||||||
|
sorted_syms = sorted(returns, key=returns.get, reverse=True) # type: ignore[arg-type]
|
||||||
|
selected = sorted_syms[:top_n]
|
||||||
|
weight = 1.0 / len(selected) if selected else 0.0
|
||||||
|
|
||||||
|
# Count closed positions as trades
|
||||||
|
for sym in current_symbols:
|
||||||
|
if sym in aligned:
|
||||||
|
exit_price = float(aligned[sym].iloc[bar])
|
||||||
|
entry_price = entry_prices.get(sym, exit_price)
|
||||||
|
total_trades += 1
|
||||||
|
if exit_price > entry_price:
|
||||||
|
winning_trades += 1
|
||||||
|
|
||||||
|
current_symbols = selected
|
||||||
|
current_weights = [weight] * len(selected)
|
||||||
|
entry_prices = {sym: float(aligned[sym].iloc[bar]) for sym in selected}
|
||||||
|
|
||||||
|
allocation_history.append({
|
||||||
|
"bar": bar,
|
||||||
|
"symbols": selected,
|
||||||
|
"weights": current_weights,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Compute portfolio daily return
|
||||||
|
if current_symbols:
|
||||||
|
daily_ret = 0.0
|
||||||
|
for sym, w in zip(current_symbols, current_weights):
|
||||||
|
prev_bar = bar - 1
|
||||||
|
prev_price = float(aligned[sym].iloc[prev_bar])
|
||||||
|
curr_price = float(aligned[sym].iloc[bar])
|
||||||
|
if prev_price != 0:
|
||||||
|
daily_ret += w * (curr_price / prev_price - 1)
|
||||||
|
portfolio_value = portfolio_value * (1 + daily_ret)
|
||||||
|
|
||||||
|
equity_values.append(portfolio_value)
|
||||||
|
|
||||||
|
equity = pd.Series(equity_values, dtype=float)
|
||||||
|
result = _compute_metrics(equity, total_trades)
|
||||||
|
if total_trades > 0:
|
||||||
|
result["win_rate"] = round(winning_trades / total_trades, 6)
|
||||||
|
result["allocation_history"] = allocation_history
|
||||||
|
return result
|
||||||
2
main.py
2
main.py
@@ -38,6 +38,7 @@ from routes_shorts import router as shorts_router # noqa: E402
|
|||||||
from routes_surveys import router as surveys_router # noqa: E402
|
from routes_surveys import router as surveys_router # noqa: E402
|
||||||
from routes_technical import router as technical_router # noqa: E402
|
from routes_technical import router as technical_router # noqa: E402
|
||||||
from routes_portfolio import router as portfolio_router # noqa: E402
|
from routes_portfolio import router as portfolio_router # noqa: E402
|
||||||
|
from routes_backtest import router as backtest_router # noqa: E402
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=settings.log_level.upper(),
|
level=settings.log_level.upper(),
|
||||||
@@ -83,6 +84,7 @@ app.include_router(economy_router)
|
|||||||
app.include_router(surveys_router)
|
app.include_router(surveys_router)
|
||||||
app.include_router(regulators_router)
|
app.include_router(regulators_router)
|
||||||
app.include_router(portfolio_router)
|
app.include_router(portfolio_router)
|
||||||
|
app.include_router(backtest_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health", response_model=dict[str, str])
|
@app.get("/health", response_model=dict[str, str])
|
||||||
|
|||||||
141
routes_backtest.py
Normal file
141
routes_backtest.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Routes for backtesting strategies."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
import backtest_service
|
||||||
|
from models import ApiResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/backtest", tags=["backtest"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request models
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BacktestRequest(BaseModel):
|
||||||
|
symbol: str = Field(..., min_length=1, max_length=20)
|
||||||
|
days: int = Field(default=365, ge=30, le=3650)
|
||||||
|
initial_capital: float = Field(default=10000.0, gt=0, le=1_000_000_000)
|
||||||
|
|
||||||
|
|
||||||
|
class SMARequest(BacktestRequest):
|
||||||
|
short_window: int = Field(default=20, ge=5, le=100)
|
||||||
|
long_window: int = Field(default=50, ge=10, le=400)
|
||||||
|
|
||||||
|
|
||||||
|
class RSIRequest(BacktestRequest):
|
||||||
|
period: int = Field(default=14, ge=2, le=50)
|
||||||
|
oversold: float = Field(default=30.0, ge=1, le=49)
|
||||||
|
overbought: float = Field(default=70.0, ge=51, le=99)
|
||||||
|
|
||||||
|
|
||||||
|
class BuyAndHoldRequest(BacktestRequest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MomentumRequest(BaseModel):
|
||||||
|
symbols: list[str] = Field(..., min_length=2, max_length=20)
|
||||||
|
lookback: int = Field(default=60, ge=5, le=252)
|
||||||
|
top_n: int = Field(default=2, ge=1)
|
||||||
|
rebalance_days: int = Field(default=30, ge=5, le=252)
|
||||||
|
days: int = Field(default=365, ge=60, le=3650)
|
||||||
|
initial_capital: float = Field(default=10000.0, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route handlers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sma-crossover", response_model=ApiResponse)
|
||||||
|
async def sma_crossover(req: SMARequest) -> ApiResponse:
|
||||||
|
"""SMA crossover strategy: buy when short SMA crosses above long SMA."""
|
||||||
|
try:
|
||||||
|
result = await backtest_service.backtest_sma_crossover(
|
||||||
|
req.symbol,
|
||||||
|
short_window=req.short_window,
|
||||||
|
long_window=req.long_window,
|
||||||
|
days=req.days,
|
||||||
|
initial_capital=req.initial_capital,
|
||||||
|
)
|
||||||
|
return ApiResponse(data=result)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning("SMA crossover backtest validation error: %s", exc)
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("SMA crossover backtest failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502, detail="Data provider error. Check server logs."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/rsi", response_model=ApiResponse)
|
||||||
|
async def rsi_strategy(req: RSIRequest) -> ApiResponse:
|
||||||
|
"""RSI strategy: buy when RSI < oversold, sell when RSI > overbought."""
|
||||||
|
try:
|
||||||
|
result = await backtest_service.backtest_rsi(
|
||||||
|
req.symbol,
|
||||||
|
period=req.period,
|
||||||
|
oversold=req.oversold,
|
||||||
|
overbought=req.overbought,
|
||||||
|
days=req.days,
|
||||||
|
initial_capital=req.initial_capital,
|
||||||
|
)
|
||||||
|
return ApiResponse(data=result)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning("RSI backtest validation error: %s", exc)
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("RSI backtest failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502, detail="Data provider error. Check server logs."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/buy-and-hold", response_model=ApiResponse)
|
||||||
|
async def buy_and_hold(req: BuyAndHoldRequest) -> ApiResponse:
|
||||||
|
"""Buy-and-hold benchmark: buy on day 1, hold through end of period."""
|
||||||
|
try:
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
req.symbol,
|
||||||
|
days=req.days,
|
||||||
|
initial_capital=req.initial_capital,
|
||||||
|
)
|
||||||
|
return ApiResponse(data=result)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning("Buy-and-hold backtest validation error: %s", exc)
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Buy-and-hold backtest failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502, detail="Data provider error. Check server logs."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/momentum", response_model=ApiResponse)
|
||||||
|
async def momentum_strategy(req: MomentumRequest) -> ApiResponse:
|
||||||
|
"""Momentum strategy: every rebalance_days pick top_n symbols by lookback return."""
|
||||||
|
try:
|
||||||
|
result = await backtest_service.backtest_momentum(
|
||||||
|
symbols=req.symbols,
|
||||||
|
lookback=req.lookback,
|
||||||
|
top_n=req.top_n,
|
||||||
|
rebalance_days=req.rebalance_days,
|
||||||
|
days=req.days,
|
||||||
|
initial_capital=req.initial_capital,
|
||||||
|
)
|
||||||
|
return ApiResponse(data=result)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.warning("Momentum backtest validation error: %s", exc)
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Momentum backtest failed")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502, detail="Data provider error. Check server logs."
|
||||||
|
) from exc
|
||||||
627
tests/test_backtest_service.py
Normal file
627
tests/test_backtest_service.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
"""Unit tests for backtest_service - written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import backtest_service
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_equity(values: list[float]) -> pd.Series:
|
||||||
|
"""Build a simple equity-curve Series from a list of values."""
|
||||||
|
return pd.Series(values, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
|
def _rising_prices(n: int = 100, start: float = 100.0, step: float = 1.0) -> pd.Series:
|
||||||
|
"""Linearly rising price series."""
|
||||||
|
return pd.Series([start + i * step for i in range(n)], dtype=float)
|
||||||
|
|
||||||
|
|
||||||
|
def _flat_prices(n: int = 100, price: float = 100.0) -> pd.Series:
|
||||||
|
"""Flat price series - no movement."""
|
||||||
|
return pd.Series([price] * n, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
|
def _oscillating_prices(n: int = 200, period: int = 40) -> pd.Series:
|
||||||
|
"""Sinusoidal price series to generate crossover signals."""
|
||||||
|
t = np.arange(n)
|
||||||
|
prices = 100 + 20 * np.sin(2 * np.pi * t / period)
|
||||||
|
return pd.Series(prices, dtype=float)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _compute_metrics tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeMetrics:
|
||||||
|
def test_total_return_positive(self):
|
||||||
|
equity = _make_equity([10000, 11000, 12000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=1)
|
||||||
|
assert result["total_return"] == pytest.approx(0.2, abs=1e-6)
|
||||||
|
|
||||||
|
def test_total_return_negative(self):
|
||||||
|
equity = _make_equity([10000, 9000, 8000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=1)
|
||||||
|
assert result["total_return"] == pytest.approx(-0.2, abs=1e-6)
|
||||||
|
|
||||||
|
def test_total_return_zero_on_flat(self):
|
||||||
|
equity = _make_equity([10000, 10000, 10000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=0)
|
||||||
|
assert result["total_return"] == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
def test_annualized_return_shape(self):
|
||||||
|
# 252 daily bars => 1 trading year; 10000 -> 11000 = +10% annualized
|
||||||
|
values = [10000 * (1.0 + 0.1 / 252) ** i for i in range(253)]
|
||||||
|
equity = _make_equity(values)
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=5)
|
||||||
|
# Should be close to 10% annualized
|
||||||
|
assert result["annualized_return"] == pytest.approx(0.1, abs=0.01)
|
||||||
|
|
||||||
|
def test_sharpe_ratio_positive_drift(self):
|
||||||
|
# Steadily rising equity with small daily increments -> positive Sharpe
|
||||||
|
values = [10000 + i * 10 for i in range(252)]
|
||||||
|
equity = _make_equity(values)
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=5)
|
||||||
|
assert result["sharpe_ratio"] > 0
|
||||||
|
|
||||||
|
def test_sharpe_ratio_none_on_single_point(self):
|
||||||
|
equity = _make_equity([10000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=0)
|
||||||
|
assert result["sharpe_ratio"] is None
|
||||||
|
|
||||||
|
def test_sharpe_ratio_none_on_zero_std(self):
|
||||||
|
# Perfectly flat equity => std = 0, Sharpe undefined
|
||||||
|
equity = _make_equity([10000] * 50)
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=0)
|
||||||
|
assert result["sharpe_ratio"] is None
|
||||||
|
|
||||||
|
def test_max_drawdown_known_value(self):
|
||||||
|
# Peak 12000, trough 8000 => drawdown = (8000-12000)/12000 = -1/3
|
||||||
|
equity = _make_equity([10000, 12000, 8000, 9000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=2)
|
||||||
|
assert result["max_drawdown"] == pytest.approx(-1 / 3, abs=1e-6)
|
||||||
|
|
||||||
|
def test_max_drawdown_zero_on_monotone_rise(self):
|
||||||
|
equity = _make_equity([10000, 11000, 12000, 13000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=1)
|
||||||
|
assert result["max_drawdown"] == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
def test_total_trades_propagated(self):
|
||||||
|
equity = _make_equity([10000, 11000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=7)
|
||||||
|
assert result["total_trades"] == 7
|
||||||
|
|
||||||
|
def test_win_rate_zero_trades(self):
|
||||||
|
equity = _make_equity([10000, 10000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=0)
|
||||||
|
assert result["win_rate"] is None
|
||||||
|
|
||||||
|
def test_equity_curve_last_20_points(self):
|
||||||
|
values = list(range(100, 160)) # 60 points
|
||||||
|
equity = _make_equity(values)
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=10)
|
||||||
|
assert len(result["equity_curve"]) == 20
|
||||||
|
assert result["equity_curve"][-1] == pytest.approx(159.0, abs=1e-6)
|
||||||
|
|
||||||
|
def test_equity_curve_shorter_than_20(self):
|
||||||
|
values = [10000, 11000, 12000]
|
||||||
|
equity = _make_equity(values)
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=1)
|
||||||
|
assert len(result["equity_curve"]) == 3
|
||||||
|
|
||||||
|
def test_result_keys_present(self):
|
||||||
|
equity = _make_equity([10000, 11000])
|
||||||
|
result = backtest_service._compute_metrics(equity, trades=1)
|
||||||
|
expected_keys = {
|
||||||
|
"total_return",
|
||||||
|
"annualized_return",
|
||||||
|
"sharpe_ratio",
|
||||||
|
"max_drawdown",
|
||||||
|
"win_rate",
|
||||||
|
"total_trades",
|
||||||
|
"equity_curve",
|
||||||
|
}
|
||||||
|
assert expected_keys.issubset(result.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _compute_sma_signals tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeSmaSignals:
|
||||||
|
def test_returns_series_with_position_column(self):
|
||||||
|
prices = _oscillating_prices(200, period=40)
|
||||||
|
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
|
||||||
|
assert isinstance(positions, pd.Series)
|
||||||
|
assert len(positions) == len(prices)
|
||||||
|
|
||||||
|
def test_positions_are_zero_or_one(self):
|
||||||
|
prices = _oscillating_prices(200, period=40)
|
||||||
|
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
|
||||||
|
unique_vals = set(positions.dropna().unique())
|
||||||
|
assert unique_vals.issubset({0, 1})
|
||||||
|
|
||||||
|
def test_no_position_before_long_window(self):
|
||||||
|
prices = _oscillating_prices(200, period=40)
|
||||||
|
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
|
||||||
|
# Before long_window-1 data points, positions should be 0
|
||||||
|
assert (positions.iloc[: 19] == 0).all()
|
||||||
|
|
||||||
|
def test_generates_at_least_one_signal_on_oscillating(self):
|
||||||
|
prices = _oscillating_prices(300, period=60)
|
||||||
|
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
|
||||||
|
# Should flip between 0 and 1 at least once on oscillating data
|
||||||
|
changes = positions.diff().abs().sum()
|
||||||
|
assert changes > 0
|
||||||
|
|
||||||
|
def test_flat_prices_produce_no_signals(self):
|
||||||
|
prices = _flat_prices(100)
|
||||||
|
positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20)
|
||||||
|
# After warm-up both SMAs equal price; short never strictly above long
|
||||||
|
assert (positions == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _compute_rsi tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeRsi:
|
||||||
|
def test_rsi_length(self):
|
||||||
|
prices = _rising_prices(50)
|
||||||
|
rsi = backtest_service._compute_rsi(prices, period=14)
|
||||||
|
assert len(rsi) == len(prices)
|
||||||
|
|
||||||
|
def test_rsi_range(self):
|
||||||
|
prices = _oscillating_prices(100, period=20)
|
||||||
|
rsi = backtest_service._compute_rsi(prices, period=14)
|
||||||
|
valid = rsi.dropna()
|
||||||
|
assert (valid >= 0).all()
|
||||||
|
assert (valid <= 100).all()
|
||||||
|
|
||||||
|
def test_rsi_rising_prices_high(self):
|
||||||
|
# Monotonically rising prices => RSI should be high (>= 70)
|
||||||
|
prices = _rising_prices(80, step=1.0)
|
||||||
|
rsi = backtest_service._compute_rsi(prices, period=14)
|
||||||
|
# After warm-up period, RSI should be very high
|
||||||
|
assert rsi.iloc[-1] >= 70
|
||||||
|
|
||||||
|
def test_rsi_falling_prices_low(self):
|
||||||
|
# Monotonically falling prices => RSI should be low (<= 30)
|
||||||
|
prices = pd.Series([100 - i * 0.8 for i in range(80)], dtype=float)
|
||||||
|
rsi = backtest_service._compute_rsi(prices, period=14)
|
||||||
|
assert rsi.iloc[-1] <= 30
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _compute_rsi_signals tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeRsiSignals:
|
||||||
|
def test_returns_series(self):
|
||||||
|
prices = _oscillating_prices(200, period=40)
|
||||||
|
positions = backtest_service._compute_rsi_signals(
|
||||||
|
prices, period=14, oversold=30, overbought=70
|
||||||
|
)
|
||||||
|
assert isinstance(positions, pd.Series)
|
||||||
|
assert len(positions) == len(prices)
|
||||||
|
|
||||||
|
def test_positions_are_zero_or_one(self):
|
||||||
|
prices = _oscillating_prices(200, period=40)
|
||||||
|
positions = backtest_service._compute_rsi_signals(
|
||||||
|
prices, period=14, oversold=30, overbought=70
|
||||||
|
)
|
||||||
|
unique_vals = set(positions.dropna().unique())
|
||||||
|
assert unique_vals.issubset({0, 1})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# backtest_sma_crossover tests (async integration of service layer)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestSmaCrossover:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hist(self, monkeypatch):
|
||||||
|
"""Patch fetch_historical to return a synthetic OBBject-like result."""
|
||||||
|
prices = _oscillating_prices(300, period=60).tolist()
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_all_required_keys(self, mock_hist):
|
||||||
|
result = await backtest_service.backtest_sma_crossover(
|
||||||
|
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
required = {
|
||||||
|
"total_return",
|
||||||
|
"annualized_return",
|
||||||
|
"sharpe_ratio",
|
||||||
|
"max_drawdown",
|
||||||
|
"win_rate",
|
||||||
|
"total_trades",
|
||||||
|
"equity_curve",
|
||||||
|
}
|
||||||
|
assert required.issubset(result.keys())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_equity_curve_max_20_points(self, mock_hist):
|
||||||
|
result = await backtest_service.backtest_sma_crossover(
|
||||||
|
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert len(result["equity_curve"]) <= 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_value_error_on_no_data(self, monkeypatch):
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="No historical data"):
|
||||||
|
await backtest_service.backtest_sma_crossover(
|
||||||
|
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initial_capital_reflected_in_equity(self, mock_hist):
|
||||||
|
result = await backtest_service.backtest_sma_crossover(
|
||||||
|
"AAPL", short_window=5, long_window=20, days=365, initial_capital=50000
|
||||||
|
)
|
||||||
|
# equity_curve values should be in range related to 50000 initial capital
|
||||||
|
assert result["equity_curve"][0] > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# backtest_rsi tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestRsi:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hist(self, monkeypatch):
|
||||||
|
prices = _oscillating_prices(300, period=60).tolist()
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_all_required_keys(self, mock_hist):
|
||||||
|
result = await backtest_service.backtest_rsi(
|
||||||
|
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
required = {
|
||||||
|
"total_return",
|
||||||
|
"annualized_return",
|
||||||
|
"sharpe_ratio",
|
||||||
|
"max_drawdown",
|
||||||
|
"win_rate",
|
||||||
|
"total_trades",
|
||||||
|
"equity_curve",
|
||||||
|
}
|
||||||
|
assert required.issubset(result.keys())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_equity_curve_max_20_points(self, mock_hist):
|
||||||
|
result = await backtest_service.backtest_rsi(
|
||||||
|
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert len(result["equity_curve"]) <= 20
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_value_error_on_no_data(self, monkeypatch):
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="No historical data"):
|
||||||
|
await backtest_service.backtest_rsi(
|
||||||
|
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# backtest_buy_and_hold tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestBuyAndHold:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hist_rising(self, monkeypatch):
|
||||||
|
prices = _rising_prices(252, start=100.0, step=1.0).tolist()
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_all_required_keys(self, mock_hist_rising):
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
required = {
|
||||||
|
"total_return",
|
||||||
|
"annualized_return",
|
||||||
|
"sharpe_ratio",
|
||||||
|
"max_drawdown",
|
||||||
|
"win_rate",
|
||||||
|
"total_trades",
|
||||||
|
"equity_curve",
|
||||||
|
}
|
||||||
|
assert required.issubset(result.keys())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_total_trades_always_one(self, mock_hist_rising):
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert result["total_trades"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rising_prices_positive_return(self, mock_hist_rising):
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert result["total_return"] > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_known_return_value(self, monkeypatch):
|
||||||
|
# 100 -> 200: 100% total return
|
||||||
|
prices = [100.0, 200.0]
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert result["total_return"] == pytest.approx(1.0, abs=1e-6)
|
||||||
|
assert result["equity_curve"][-1] == pytest.approx(20000.0, abs=1e-6)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_value_error_on_no_data(self, monkeypatch):
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="No historical data"):
|
||||||
|
await backtest_service.backtest_buy_and_hold("AAPL", days=365, initial_capital=10000)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flat_prices_zero_return(self, monkeypatch):
|
||||||
|
prices = _flat_prices(50).tolist()
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
result = await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
assert result["total_return"] == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# backtest_momentum tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestMomentum:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_multi_hist(self, monkeypatch):
|
||||||
|
"""Three symbols with different return profiles."""
|
||||||
|
aapl_prices = _rising_prices(200, start=100.0, step=2.0).tolist()
|
||||||
|
msft_prices = _rising_prices(200, start=100.0, step=0.5).tolist()
|
||||||
|
googl_prices = _flat_prices(200, price=150.0).tolist()
|
||||||
|
|
||||||
|
price_map = {
|
||||||
|
"AAPL": aapl_prices,
|
||||||
|
"MSFT": msft_prices,
|
||||||
|
"GOOGL": googl_prices,
|
||||||
|
}
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
def __init__(self, prices):
|
||||||
|
self.results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult(price_map[symbol])
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_all_required_keys(self, mock_multi_hist):
|
||||||
|
result = await backtest_service.backtest_momentum(
|
||||||
|
symbols=["AAPL", "MSFT", "GOOGL"],
|
||||||
|
lookback=20,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000,
|
||||||
|
)
|
||||||
|
required = {
|
||||||
|
"total_return",
|
||||||
|
"annualized_return",
|
||||||
|
"sharpe_ratio",
|
||||||
|
"max_drawdown",
|
||||||
|
"win_rate",
|
||||||
|
"total_trades",
|
||||||
|
"equity_curve",
|
||||||
|
"allocation_history",
|
||||||
|
}
|
||||||
|
assert required.issubset(result.keys())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allocation_history_is_list(self, mock_multi_hist):
|
||||||
|
result = await backtest_service.backtest_momentum(
|
||||||
|
symbols=["AAPL", "MSFT", "GOOGL"],
|
||||||
|
lookback=20,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000,
|
||||||
|
)
|
||||||
|
assert isinstance(result["allocation_history"], list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_top_n_respected_in_allocations(self, mock_multi_hist):
|
||||||
|
result = await backtest_service.backtest_momentum(
|
||||||
|
symbols=["AAPL", "MSFT", "GOOGL"],
|
||||||
|
lookback=20,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000,
|
||||||
|
)
|
||||||
|
for entry in result["allocation_history"]:
|
||||||
|
assert len(entry["symbols"]) <= 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raises_value_error_on_no_data(self, monkeypatch):
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="No price data"):
|
||||||
|
await backtest_service.backtest_momentum(
|
||||||
|
symbols=["AAPL", "MSFT"],
|
||||||
|
lookback=20,
|
||||||
|
top_n=1,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_equity_curve_max_20_points(self, mock_multi_hist):
|
||||||
|
result = await backtest_service.backtest_momentum(
|
||||||
|
symbols=["AAPL", "MSFT", "GOOGL"],
|
||||||
|
lookback=20,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000,
|
||||||
|
)
|
||||||
|
assert len(result["equity_curve"]) <= 20
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge case: insufficient data
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_insufficient_bars_raises(self, monkeypatch):
|
||||||
|
"""Fewer bars than long_window should raise ValueError."""
|
||||||
|
prices = [100.0, 101.0, 102.0] # Only 3 bars
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="Insufficient data"):
|
||||||
|
await backtest_service.backtest_sma_crossover(
|
||||||
|
"AAPL", short_window=5, long_window=20, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rsi_insufficient_bars_raises(self, monkeypatch):
|
||||||
|
prices = [100.0, 101.0]
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="Insufficient data"):
|
||||||
|
await backtest_service.backtest_rsi(
|
||||||
|
"AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buy_and_hold_single_bar_raises(self, monkeypatch):
|
||||||
|
prices = [100.0]
|
||||||
|
|
||||||
|
class FakeBar:
|
||||||
|
def __init__(self, close):
|
||||||
|
self.close = close
|
||||||
|
|
||||||
|
class FakeResult:
|
||||||
|
results = [FakeBar(p) for p in prices]
|
||||||
|
|
||||||
|
async def fake_fetch(symbol, days, **kwargs):
|
||||||
|
return FakeResult()
|
||||||
|
|
||||||
|
monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch)
|
||||||
|
with pytest.raises(ValueError, match="Insufficient data"):
|
||||||
|
await backtest_service.backtest_buy_and_hold(
|
||||||
|
"AAPL", days=365, initial_capital=10000
|
||||||
|
)
|
||||||
443
tests/test_routes_backtest.py
Normal file
443
tests/test_routes_backtest.py
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
"""Integration tests for backtest routes - written FIRST (TDD RED phase)."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client():
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
# Shared mock response used across strategy tests
|
||||||
|
MOCK_BACKTEST_RESULT = {
|
||||||
|
"total_return": 0.15,
|
||||||
|
"annualized_return": 0.14,
|
||||||
|
"sharpe_ratio": 1.2,
|
||||||
|
"max_drawdown": -0.08,
|
||||||
|
"win_rate": 0.6,
|
||||||
|
"total_trades": 10,
|
||||||
|
"equity_curve": [10000 + i * 75 for i in range(20)],
|
||||||
|
}
|
||||||
|
|
||||||
|
MOCK_MOMENTUM_RESULT = {
|
||||||
|
**MOCK_BACKTEST_RESULT,
|
||||||
|
"allocation_history": [
|
||||||
|
{"date": "2024-01-01", "symbols": ["AAPL", "MSFT"], "weights": [0.5, 0.5]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/backtest/sma-crossover
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||||
|
async def test_sma_crossover_happy_path(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={
|
||||||
|
"symbol": "AAPL",
|
||||||
|
"short_window": 20,
|
||||||
|
"long_window": 50,
|
||||||
|
"days": 365,
|
||||||
|
"initial_capital": 10000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["data"]["total_return"] == pytest.approx(0.15)
|
||||||
|
assert data["data"]["total_trades"] == 10
|
||||||
|
assert len(data["data"]["equity_curve"]) == 20
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
"AAPL",
|
||||||
|
short_window=20,
|
||||||
|
long_window=50,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||||
|
async def test_sma_crossover_default_values(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "MSFT"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
"MSFT",
|
||||||
|
short_window=20,
|
||||||
|
long_window=50,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_missing_symbol(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"short_window": 20, "long_window": 50},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_short_window_too_small(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "short_window": 2, "long_window": 50},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_long_window_too_large(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "short_window": 20, "long_window": 500},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_days_too_few(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "days": 5},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_days_too_many(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "days": 9999},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sma_crossover_capital_zero(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "initial_capital": 0},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||||
|
async def test_sma_crossover_service_error_returns_502(mock_fn, client):
|
||||||
|
mock_fn.side_effect = RuntimeError("Data fetch failed")
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock)
|
||||||
|
async def test_sma_crossover_value_error_returns_400(mock_fn, client):
|
||||||
|
mock_fn.side_effect = ValueError("No historical data")
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/sma-crossover",
|
||||||
|
json={"symbol": "AAPL", "short_window": 20, "long_window": 50},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/backtest/rsi
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
|
||||||
|
async def test_rsi_happy_path(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/rsi",
|
||||||
|
json={
|
||||||
|
"symbol": "AAPL",
|
||||||
|
"period": 14,
|
||||||
|
"oversold": 30,
|
||||||
|
"overbought": 70,
|
||||||
|
"days": 365,
|
||||||
|
"initial_capital": 10000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["data"]["sharpe_ratio"] == pytest.approx(1.2)
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
"AAPL",
|
||||||
|
period=14,
|
||||||
|
oversold=30.0,
|
||||||
|
overbought=70.0,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
|
||||||
|
async def test_rsi_default_values(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
"AAPL",
|
||||||
|
period=14,
|
||||||
|
oversold=30.0,
|
||||||
|
overbought=70.0,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rsi_missing_symbol(client):
|
||||||
|
resp = await client.post("/api/v1/backtest/rsi", json={"period": 14})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rsi_period_too_small(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/rsi",
|
||||||
|
json={"symbol": "AAPL", "period": 1},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rsi_oversold_too_high(client):
|
||||||
|
# oversold must be < 50
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/rsi",
|
||||||
|
json={"symbol": "AAPL", "oversold": 55, "overbought": 70},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rsi_overbought_too_low(client):
|
||||||
|
# overbought must be > 50
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/rsi",
|
||||||
|
json={"symbol": "AAPL", "oversold": 30, "overbought": 45},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
|
||||||
|
async def test_rsi_service_error_returns_502(mock_fn, client):
|
||||||
|
mock_fn.side_effect = RuntimeError("upstream error")
|
||||||
|
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock)
|
||||||
|
async def test_rsi_value_error_returns_400(mock_fn, client):
|
||||||
|
mock_fn.side_effect = ValueError("Insufficient data")
|
||||||
|
resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/backtest/buy-and-hold
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
|
||||||
|
async def test_buy_and_hold_happy_path(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/buy-and-hold",
|
||||||
|
json={"symbol": "AAPL", "days": 365, "initial_capital": 10000},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert data["data"]["total_return"] == pytest.approx(0.15)
|
||||||
|
mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
|
||||||
|
async def test_buy_and_hold_default_values(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_BACKTEST_RESULT
|
||||||
|
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buy_and_hold_missing_symbol(client):
|
||||||
|
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"days": 365})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buy_and_hold_days_too_few(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/buy-and-hold",
|
||||||
|
json={"symbol": "AAPL", "days": 10},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
|
||||||
|
async def test_buy_and_hold_service_error_returns_502(mock_fn, client):
|
||||||
|
mock_fn.side_effect = RuntimeError("provider down")
|
||||||
|
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock)
|
||||||
|
async def test_buy_and_hold_value_error_returns_400(mock_fn, client):
|
||||||
|
mock_fn.side_effect = ValueError("No historical data")
|
||||||
|
resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"})
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/backtest/momentum
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
|
||||||
|
async def test_momentum_happy_path(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_MOMENTUM_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={
|
||||||
|
"symbols": ["AAPL", "MSFT", "GOOGL"],
|
||||||
|
"lookback": 60,
|
||||||
|
"top_n": 2,
|
||||||
|
"rebalance_days": 30,
|
||||||
|
"days": 365,
|
||||||
|
"initial_capital": 10000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "allocation_history" in data["data"]
|
||||||
|
assert data["data"]["total_trades"] == 10
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
symbols=["AAPL", "MSFT", "GOOGL"],
|
||||||
|
lookback=60,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
|
||||||
|
async def test_momentum_default_values(mock_fn, client):
|
||||||
|
mock_fn.return_value = MOCK_MOMENTUM_RESULT
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL", "MSFT"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_fn.assert_called_once_with(
|
||||||
|
symbols=["AAPL", "MSFT"],
|
||||||
|
lookback=60,
|
||||||
|
top_n=2,
|
||||||
|
rebalance_days=30,
|
||||||
|
days=365,
|
||||||
|
initial_capital=10000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_momentum_missing_symbols(client):
|
||||||
|
resp = await client.post("/api/v1/backtest/momentum", json={"lookback": 60})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_momentum_too_few_symbols(client):
|
||||||
|
# min_length=2 on symbols list
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_momentum_too_many_symbols(client):
|
||||||
|
symbols = [f"SYM{i}" for i in range(25)]
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": symbols},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_momentum_lookback_too_small(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL", "MSFT"], "lookback": 2},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_momentum_days_too_few(client):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL", "MSFT"], "days": 10},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
|
||||||
|
async def test_momentum_service_error_returns_502(mock_fn, client):
|
||||||
|
mock_fn.side_effect = RuntimeError("provider down")
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL", "MSFT"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock)
|
||||||
|
async def test_momentum_value_error_returns_400(mock_fn, client):
|
||||||
|
mock_fn.side_effect = ValueError("No price data available")
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/backtest/momentum",
|
||||||
|
json={"symbols": ["AAPL", "MSFT"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
Reference in New Issue
Block a user