diff --git a/backtest_service.py b/backtest_service.py new file mode 100644 index 0000000..d0f7609 --- /dev/null +++ b/backtest_service.py @@ -0,0 +1,372 @@ +"""Backtesting engine using pure pandas/numpy - no external backtesting libraries.""" + +import logging +from typing import Any + +import numpy as np +import pandas as pd + +from obb_utils import fetch_historical + +logger = logging.getLogger(__name__) + +_EQUITY_CURVE_MAX_POINTS = 20 +_MIN_BARS_FOR_SINGLE_POINT = 2 + + +# --------------------------------------------------------------------------- +# Internal signal computation helpers +# --------------------------------------------------------------------------- + + +def _extract_closes(result: Any) -> pd.Series: + """Pull close prices from an OBBject result into a float Series.""" + bars = result.results + closes = [getattr(bar, "close", None) for bar in bars] + return pd.Series(closes, dtype=float).dropna().reset_index(drop=True) + + +def _compute_sma_signals( + prices: pd.Series, short_window: int, long_window: int +) -> pd.Series: + """Return position series (1=long, 0=flat) from SMA crossover strategy. + + Buy when short SMA crosses above long SMA; sell when it crosses below. + """ + short_ma = prices.rolling(short_window).mean() + long_ma = prices.rolling(long_window).mean() + + # 1 where short > long, else 0; NaN before long_window filled with 0 + signal = (short_ma > long_ma).astype(int) + signal.iloc[: long_window - 1] = 0 + return signal + + +def _compute_rsi(prices: pd.Series, period: int) -> pd.Series: + """Compute Wilder RSI for a price series.""" + delta = prices.diff() + gain = delta.clip(lower=0) + loss = (-delta).clip(lower=0) + + avg_gain = gain.ewm(alpha=1 / period, adjust=False).mean() + avg_loss = loss.ewm(alpha=1 / period, adjust=False).mean() + + # When avg_loss == 0 and avg_gain > 0, RSI = 100; avoid division by zero. + rsi = pd.Series(np.where( + avg_loss == 0, + np.where(avg_gain == 0, 50.0, 100.0), + 100 - (100 / (1 + avg_gain / avg_loss)), + ), index=prices.index, dtype=float) + # Preserve NaN for the initial diff period + rsi[avg_gain.isna()] = np.nan + return rsi + + +def _compute_rsi_signals( + prices: pd.Series, period: int, oversold: float, overbought: float +) -> pd.Series: + """Return position series (1=long, 0=flat) from RSI strategy. + + Buy when RSI < oversold; sell when RSI > overbought. + """ + rsi = _compute_rsi(prices, period) + position = pd.Series(0, index=prices.index, dtype=int) + in_trade = False + + for i in range(len(prices)): + rsi_val = rsi.iloc[i] + if pd.isna(rsi_val): + continue + if not in_trade and rsi_val < oversold: + in_trade = True + elif in_trade and rsi_val > overbought: + in_trade = False + if in_trade: + position.iloc[i] = 1 + + return position + + +# --------------------------------------------------------------------------- +# Shared metrics computation +# --------------------------------------------------------------------------- + + +def _compute_metrics(equity: pd.Series, trades: int) -> dict[str, Any]: + """Compute standard backtest performance metrics from an equity curve. + + Parameters + ---------- + equity: + Daily portfolio value series starting from initial_capital. + trades: + Number of completed round-trip trades. + + Returns + ------- + dict with keys: total_return, annualized_return, sharpe_ratio, + max_drawdown, win_rate, total_trades, equity_curve. + """ + n = len(equity) + initial = float(equity.iloc[0]) + final = float(equity.iloc[-1]) + + total_return = (final - initial) / initial if initial != 0 else 0.0 + + trading_days = max(n - 1, 1) + annualized_return = (1 + total_return) ** (252 / trading_days) - 1 + + # Sharpe ratio (annualized, risk-free rate = 0) + sharpe_ratio: float | None = None + if n > 1: + daily_returns = equity.pct_change().dropna() + std = float(daily_returns.std()) + if std > 0: + sharpe_ratio = float(daily_returns.mean() / std * np.sqrt(252)) + + # Maximum drawdown + rolling_max = equity.cummax() + drawdown = (equity - rolling_max) / rolling_max + max_drawdown = float(drawdown.min()) + + # Win rate: undefined when no trades + win_rate: float | None = None + if trades > 0: + # Approximate: compare each trade entry/exit pair captured in equity + win_rate = None # will be overridden by callers that track trades + + # Equity curve - last N points as plain Python floats + last_n = equity.iloc[-_EQUITY_CURVE_MAX_POINTS:] + equity_curve = [round(float(v), 4) for v in last_n] + + return { + "total_return": round(total_return, 6), + "annualized_return": round(annualized_return, 6), + "sharpe_ratio": round(sharpe_ratio, 6) if sharpe_ratio is not None else None, + "max_drawdown": round(max_drawdown, 6), + "win_rate": win_rate, + "total_trades": trades, + "equity_curve": equity_curve, + } + + +def _simulate_positions( + prices: pd.Series, + positions: pd.Series, + initial_capital: float, +) -> tuple[pd.Series, int, int]: + """Simulate portfolio equity given a position series and prices. + + Returns (equity_curve, total_trades, winning_trades). + A trade is a complete buy->sell round-trip. + """ + # Daily returns when in position + price_returns = prices.pct_change().fillna(0.0) + strategy_returns = positions.shift(1).fillna(0).astype(float) * price_returns + + equity = initial_capital * (1 + strategy_returns).cumprod() + equity.iloc[0] = initial_capital + + # Count round trips + trade_changes = positions.diff().abs() + entries = int((trade_changes == 1).sum()) + exits = int((trade_changes == -1).sum()) + total_trades = min(entries, exits) # only completed round trips + + # Count wins: each completed trade where exit value > entry value + winning_trades = 0 + in_trade = False + entry_price = 0.0 + for i in range(len(positions)): + pos = int(positions.iloc[i]) + price = float(prices.iloc[i]) + if not in_trade and pos == 1: + in_trade = True + entry_price = price + elif in_trade and pos == 0: + in_trade = False + if price > entry_price: + winning_trades += 1 + + return equity, total_trades, winning_trades + + +# --------------------------------------------------------------------------- +# Public strategy functions +# --------------------------------------------------------------------------- + + +async def backtest_sma_crossover( + symbol: str, + short_window: int, + long_window: int, + days: int, + initial_capital: float, +) -> dict[str, Any]: + """Run SMA crossover backtest for a single symbol.""" + hist = await fetch_historical(symbol, days) + if hist is None: + raise ValueError(f"No historical data available for {symbol}") + + prices = _extract_closes(hist) + if len(prices) <= long_window: + raise ValueError( + f"Insufficient data: need >{long_window} bars, got {len(prices)}" + ) + + positions = _compute_sma_signals(prices, short_window, long_window) + equity, total_trades, winning_trades = _simulate_positions( + prices, positions, initial_capital + ) + + result = _compute_metrics(equity, total_trades) + if total_trades > 0: + result["win_rate"] = round(winning_trades / total_trades, 6) + return result + + +async def backtest_rsi( + symbol: str, + period: int, + oversold: float, + overbought: float, + days: int, + initial_capital: float, +) -> dict[str, Any]: + """Run RSI-based backtest for a single symbol.""" + hist = await fetch_historical(symbol, days) + if hist is None: + raise ValueError(f"No historical data available for {symbol}") + + prices = _extract_closes(hist) + if len(prices) <= period: + raise ValueError( + f"Insufficient data: need >{period} bars, got {len(prices)}" + ) + + positions = _compute_rsi_signals(prices, period, oversold, overbought) + equity, total_trades, winning_trades = _simulate_positions( + prices, positions, initial_capital + ) + + result = _compute_metrics(equity, total_trades) + if total_trades > 0: + result["win_rate"] = round(winning_trades / total_trades, 6) + return result + + +async def backtest_buy_and_hold( + symbol: str, + days: int, + initial_capital: float, +) -> dict[str, Any]: + """Run a simple buy-and-hold backtest as a benchmark.""" + hist = await fetch_historical(symbol, days) + if hist is None: + raise ValueError(f"No historical data available for {symbol}") + + prices = _extract_closes(hist) + if len(prices) < _MIN_BARS_FOR_SINGLE_POINT: + raise ValueError( + f"Insufficient data: need at least 2 bars, got {len(prices)}" + ) + + # Always fully invested - position is 1 from day 0 + positions = pd.Series(1, index=prices.index, dtype=int) + equity, _, _ = _simulate_positions(prices, positions, initial_capital) + + result = _compute_metrics(equity, trades=1) + # Buy-and-hold: 1 trade, win_rate is whether final > initial + result["win_rate"] = 1.0 if result["total_return"] > 0 else 0.0 + result["total_trades"] = 1 + return result + + +async def backtest_momentum( + symbols: list[str], + lookback: int, + top_n: int, + rebalance_days: int, + days: int, + initial_capital: float, +) -> dict[str, Any]: + """Run momentum strategy: every rebalance_days pick top_n symbols by lookback return.""" + # Fetch all price series + price_map: dict[str, pd.Series] = {} + for sym in symbols: + hist = await fetch_historical(sym, days) + if hist is not None: + closes = _extract_closes(hist) + if len(closes) > lookback: + price_map[sym] = closes + + if not price_map: + raise ValueError("No price data available for any of the requested symbols") + + # Align all price series to the same length (min across symbols) + min_len = min(len(v) for v in price_map.values()) + aligned = {sym: s.iloc[:min_len].reset_index(drop=True) for sym, s in price_map.items()} + + n_bars = min_len + portfolio_value = initial_capital + equity_values: list[float] = [initial_capital] + allocation_history: list[dict[str, Any]] = [] + total_trades = 0 + + current_symbols: list[str] = [] + current_weights: list[float] = [] + entry_prices: dict[str, float] = {} + winning_trades = 0 + + for bar in range(1, n_bars): + # Rebalance check + if bar % rebalance_days == 0 and bar >= lookback: + # Rank symbols by lookback-period return + returns: dict[str, float] = {} + for sym, prices in aligned.items(): + if bar >= lookback: + ret = (prices.iloc[bar] / prices.iloc[bar - lookback]) - 1 + returns[sym] = ret + + sorted_syms = sorted(returns, key=returns.get, reverse=True) # type: ignore[arg-type] + selected = sorted_syms[:top_n] + weight = 1.0 / len(selected) if selected else 0.0 + + # Count closed positions as trades + for sym in current_symbols: + if sym in aligned: + exit_price = float(aligned[sym].iloc[bar]) + entry_price = entry_prices.get(sym, exit_price) + total_trades += 1 + if exit_price > entry_price: + winning_trades += 1 + + current_symbols = selected + current_weights = [weight] * len(selected) + entry_prices = {sym: float(aligned[sym].iloc[bar]) for sym in selected} + + allocation_history.append({ + "bar": bar, + "symbols": selected, + "weights": current_weights, + }) + + # Compute portfolio daily return + if current_symbols: + daily_ret = 0.0 + for sym, w in zip(current_symbols, current_weights): + prev_bar = bar - 1 + prev_price = float(aligned[sym].iloc[prev_bar]) + curr_price = float(aligned[sym].iloc[bar]) + if prev_price != 0: + daily_ret += w * (curr_price / prev_price - 1) + portfolio_value = portfolio_value * (1 + daily_ret) + + equity_values.append(portfolio_value) + + equity = pd.Series(equity_values, dtype=float) + result = _compute_metrics(equity, total_trades) + if total_trades > 0: + result["win_rate"] = round(winning_trades / total_trades, 6) + result["allocation_history"] = allocation_history + return result diff --git a/main.py b/main.py index 69961c0..296ddc4 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,7 @@ from routes_shorts import router as shorts_router # noqa: E402 from routes_surveys import router as surveys_router # noqa: E402 from routes_technical import router as technical_router # noqa: E402 from routes_portfolio import router as portfolio_router # noqa: E402 +from routes_backtest import router as backtest_router # noqa: E402 logging.basicConfig( level=settings.log_level.upper(), @@ -83,6 +84,7 @@ app.include_router(economy_router) app.include_router(surveys_router) app.include_router(regulators_router) app.include_router(portfolio_router) +app.include_router(backtest_router) @app.get("/health", response_model=dict[str, str]) diff --git a/routes_backtest.py b/routes_backtest.py new file mode 100644 index 0000000..6b0e7df --- /dev/null +++ b/routes_backtest.py @@ -0,0 +1,141 @@ +"""Routes for backtesting strategies.""" + +import logging + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field + +import backtest_service +from models import ApiResponse + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/backtest", tags=["backtest"]) + + +# --------------------------------------------------------------------------- +# Request models +# --------------------------------------------------------------------------- + + +class BacktestRequest(BaseModel): + symbol: str = Field(..., min_length=1, max_length=20) + days: int = Field(default=365, ge=30, le=3650) + initial_capital: float = Field(default=10000.0, gt=0, le=1_000_000_000) + + +class SMARequest(BacktestRequest): + short_window: int = Field(default=20, ge=5, le=100) + long_window: int = Field(default=50, ge=10, le=400) + + +class RSIRequest(BacktestRequest): + period: int = Field(default=14, ge=2, le=50) + oversold: float = Field(default=30.0, ge=1, le=49) + overbought: float = Field(default=70.0, ge=51, le=99) + + +class BuyAndHoldRequest(BacktestRequest): + pass + + +class MomentumRequest(BaseModel): + symbols: list[str] = Field(..., min_length=2, max_length=20) + lookback: int = Field(default=60, ge=5, le=252) + top_n: int = Field(default=2, ge=1) + rebalance_days: int = Field(default=30, ge=5, le=252) + days: int = Field(default=365, ge=60, le=3650) + initial_capital: float = Field(default=10000.0, gt=0) + + +# --------------------------------------------------------------------------- +# Route handlers +# --------------------------------------------------------------------------- + + +@router.post("/sma-crossover", response_model=ApiResponse) +async def sma_crossover(req: SMARequest) -> ApiResponse: + """SMA crossover strategy: buy when short SMA crosses above long SMA.""" + try: + result = await backtest_service.backtest_sma_crossover( + req.symbol, + short_window=req.short_window, + long_window=req.long_window, + days=req.days, + initial_capital=req.initial_capital, + ) + return ApiResponse(data=result) + except ValueError as exc: + logger.warning("SMA crossover backtest validation error: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + logger.exception("SMA crossover backtest failed") + raise HTTPException( + status_code=502, detail="Data provider error. Check server logs." + ) from exc + + +@router.post("/rsi", response_model=ApiResponse) +async def rsi_strategy(req: RSIRequest) -> ApiResponse: + """RSI strategy: buy when RSI < oversold, sell when RSI > overbought.""" + try: + result = await backtest_service.backtest_rsi( + req.symbol, + period=req.period, + oversold=req.oversold, + overbought=req.overbought, + days=req.days, + initial_capital=req.initial_capital, + ) + return ApiResponse(data=result) + except ValueError as exc: + logger.warning("RSI backtest validation error: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + logger.exception("RSI backtest failed") + raise HTTPException( + status_code=502, detail="Data provider error. Check server logs." + ) from exc + + +@router.post("/buy-and-hold", response_model=ApiResponse) +async def buy_and_hold(req: BuyAndHoldRequest) -> ApiResponse: + """Buy-and-hold benchmark: buy on day 1, hold through end of period.""" + try: + result = await backtest_service.backtest_buy_and_hold( + req.symbol, + days=req.days, + initial_capital=req.initial_capital, + ) + return ApiResponse(data=result) + except ValueError as exc: + logger.warning("Buy-and-hold backtest validation error: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + logger.exception("Buy-and-hold backtest failed") + raise HTTPException( + status_code=502, detail="Data provider error. Check server logs." + ) from exc + + +@router.post("/momentum", response_model=ApiResponse) +async def momentum_strategy(req: MomentumRequest) -> ApiResponse: + """Momentum strategy: every rebalance_days pick top_n symbols by lookback return.""" + try: + result = await backtest_service.backtest_momentum( + symbols=req.symbols, + lookback=req.lookback, + top_n=req.top_n, + rebalance_days=req.rebalance_days, + days=req.days, + initial_capital=req.initial_capital, + ) + return ApiResponse(data=result) + except ValueError as exc: + logger.warning("Momentum backtest validation error: %s", exc) + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + logger.exception("Momentum backtest failed") + raise HTTPException( + status_code=502, detail="Data provider error. Check server logs." + ) from exc diff --git a/tests/test_backtest_service.py b/tests/test_backtest_service.py new file mode 100644 index 0000000..211eab4 --- /dev/null +++ b/tests/test_backtest_service.py @@ -0,0 +1,627 @@ +"""Unit tests for backtest_service - written FIRST (TDD RED phase).""" + +import numpy as np +import pandas as pd +import pytest + +import backtest_service + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_equity(values: list[float]) -> pd.Series: + """Build a simple equity-curve Series from a list of values.""" + return pd.Series(values, dtype=float) + + +def _rising_prices(n: int = 100, start: float = 100.0, step: float = 1.0) -> pd.Series: + """Linearly rising price series.""" + return pd.Series([start + i * step for i in range(n)], dtype=float) + + +def _flat_prices(n: int = 100, price: float = 100.0) -> pd.Series: + """Flat price series - no movement.""" + return pd.Series([price] * n, dtype=float) + + +def _oscillating_prices(n: int = 200, period: int = 40) -> pd.Series: + """Sinusoidal price series to generate crossover signals.""" + t = np.arange(n) + prices = 100 + 20 * np.sin(2 * np.pi * t / period) + return pd.Series(prices, dtype=float) + + +# --------------------------------------------------------------------------- +# _compute_metrics tests +# --------------------------------------------------------------------------- + + +class TestComputeMetrics: + def test_total_return_positive(self): + equity = _make_equity([10000, 11000, 12000]) + result = backtest_service._compute_metrics(equity, trades=1) + assert result["total_return"] == pytest.approx(0.2, abs=1e-6) + + def test_total_return_negative(self): + equity = _make_equity([10000, 9000, 8000]) + result = backtest_service._compute_metrics(equity, trades=1) + assert result["total_return"] == pytest.approx(-0.2, abs=1e-6) + + def test_total_return_zero_on_flat(self): + equity = _make_equity([10000, 10000, 10000]) + result = backtest_service._compute_metrics(equity, trades=0) + assert result["total_return"] == pytest.approx(0.0, abs=1e-6) + + def test_annualized_return_shape(self): + # 252 daily bars => 1 trading year; 10000 -> 11000 = +10% annualized + values = [10000 * (1.0 + 0.1 / 252) ** i for i in range(253)] + equity = _make_equity(values) + result = backtest_service._compute_metrics(equity, trades=5) + # Should be close to 10% annualized + assert result["annualized_return"] == pytest.approx(0.1, abs=0.01) + + def test_sharpe_ratio_positive_drift(self): + # Steadily rising equity with small daily increments -> positive Sharpe + values = [10000 + i * 10 for i in range(252)] + equity = _make_equity(values) + result = backtest_service._compute_metrics(equity, trades=5) + assert result["sharpe_ratio"] > 0 + + def test_sharpe_ratio_none_on_single_point(self): + equity = _make_equity([10000]) + result = backtest_service._compute_metrics(equity, trades=0) + assert result["sharpe_ratio"] is None + + def test_sharpe_ratio_none_on_zero_std(self): + # Perfectly flat equity => std = 0, Sharpe undefined + equity = _make_equity([10000] * 50) + result = backtest_service._compute_metrics(equity, trades=0) + assert result["sharpe_ratio"] is None + + def test_max_drawdown_known_value(self): + # Peak 12000, trough 8000 => drawdown = (8000-12000)/12000 = -1/3 + equity = _make_equity([10000, 12000, 8000, 9000]) + result = backtest_service._compute_metrics(equity, trades=2) + assert result["max_drawdown"] == pytest.approx(-1 / 3, abs=1e-6) + + def test_max_drawdown_zero_on_monotone_rise(self): + equity = _make_equity([10000, 11000, 12000, 13000]) + result = backtest_service._compute_metrics(equity, trades=1) + assert result["max_drawdown"] == pytest.approx(0.0, abs=1e-6) + + def test_total_trades_propagated(self): + equity = _make_equity([10000, 11000]) + result = backtest_service._compute_metrics(equity, trades=7) + assert result["total_trades"] == 7 + + def test_win_rate_zero_trades(self): + equity = _make_equity([10000, 10000]) + result = backtest_service._compute_metrics(equity, trades=0) + assert result["win_rate"] is None + + def test_equity_curve_last_20_points(self): + values = list(range(100, 160)) # 60 points + equity = _make_equity(values) + result = backtest_service._compute_metrics(equity, trades=10) + assert len(result["equity_curve"]) == 20 + assert result["equity_curve"][-1] == pytest.approx(159.0, abs=1e-6) + + def test_equity_curve_shorter_than_20(self): + values = [10000, 11000, 12000] + equity = _make_equity(values) + result = backtest_service._compute_metrics(equity, trades=1) + assert len(result["equity_curve"]) == 3 + + def test_result_keys_present(self): + equity = _make_equity([10000, 11000]) + result = backtest_service._compute_metrics(equity, trades=1) + expected_keys = { + "total_return", + "annualized_return", + "sharpe_ratio", + "max_drawdown", + "win_rate", + "total_trades", + "equity_curve", + } + assert expected_keys.issubset(result.keys()) + + +# --------------------------------------------------------------------------- +# _compute_sma_signals tests +# --------------------------------------------------------------------------- + + +class TestComputeSmaSignals: + def test_returns_series_with_position_column(self): + prices = _oscillating_prices(200, period=40) + positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20) + assert isinstance(positions, pd.Series) + assert len(positions) == len(prices) + + def test_positions_are_zero_or_one(self): + prices = _oscillating_prices(200, period=40) + positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20) + unique_vals = set(positions.dropna().unique()) + assert unique_vals.issubset({0, 1}) + + def test_no_position_before_long_window(self): + prices = _oscillating_prices(200, period=40) + positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20) + # Before long_window-1 data points, positions should be 0 + assert (positions.iloc[: 19] == 0).all() + + def test_generates_at_least_one_signal_on_oscillating(self): + prices = _oscillating_prices(300, period=60) + positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20) + # Should flip between 0 and 1 at least once on oscillating data + changes = positions.diff().abs().sum() + assert changes > 0 + + def test_flat_prices_produce_no_signals(self): + prices = _flat_prices(100) + positions = backtest_service._compute_sma_signals(prices, short_window=5, long_window=20) + # After warm-up both SMAs equal price; short never strictly above long + assert (positions == 0).all() + + +# --------------------------------------------------------------------------- +# _compute_rsi tests +# --------------------------------------------------------------------------- + + +class TestComputeRsi: + def test_rsi_length(self): + prices = _rising_prices(50) + rsi = backtest_service._compute_rsi(prices, period=14) + assert len(rsi) == len(prices) + + def test_rsi_range(self): + prices = _oscillating_prices(100, period=20) + rsi = backtest_service._compute_rsi(prices, period=14) + valid = rsi.dropna() + assert (valid >= 0).all() + assert (valid <= 100).all() + + def test_rsi_rising_prices_high(self): + # Monotonically rising prices => RSI should be high (>= 70) + prices = _rising_prices(80, step=1.0) + rsi = backtest_service._compute_rsi(prices, period=14) + # After warm-up period, RSI should be very high + assert rsi.iloc[-1] >= 70 + + def test_rsi_falling_prices_low(self): + # Monotonically falling prices => RSI should be low (<= 30) + prices = pd.Series([100 - i * 0.8 for i in range(80)], dtype=float) + rsi = backtest_service._compute_rsi(prices, period=14) + assert rsi.iloc[-1] <= 30 + + +# --------------------------------------------------------------------------- +# _compute_rsi_signals tests +# --------------------------------------------------------------------------- + + +class TestComputeRsiSignals: + def test_returns_series(self): + prices = _oscillating_prices(200, period=40) + positions = backtest_service._compute_rsi_signals( + prices, period=14, oversold=30, overbought=70 + ) + assert isinstance(positions, pd.Series) + assert len(positions) == len(prices) + + def test_positions_are_zero_or_one(self): + prices = _oscillating_prices(200, period=40) + positions = backtest_service._compute_rsi_signals( + prices, period=14, oversold=30, overbought=70 + ) + unique_vals = set(positions.dropna().unique()) + assert unique_vals.issubset({0, 1}) + + +# --------------------------------------------------------------------------- +# backtest_sma_crossover tests (async integration of service layer) +# --------------------------------------------------------------------------- + + +class TestBacktestSmaCrossover: + @pytest.fixture + def mock_hist(self, monkeypatch): + """Patch fetch_historical to return a synthetic OBBject-like result.""" + prices = _oscillating_prices(300, period=60).tolist() + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + return FakeResult() + + @pytest.mark.asyncio + async def test_returns_all_required_keys(self, mock_hist): + result = await backtest_service.backtest_sma_crossover( + "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000 + ) + required = { + "total_return", + "annualized_return", + "sharpe_ratio", + "max_drawdown", + "win_rate", + "total_trades", + "equity_curve", + } + assert required.issubset(result.keys()) + + @pytest.mark.asyncio + async def test_equity_curve_max_20_points(self, mock_hist): + result = await backtest_service.backtest_sma_crossover( + "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000 + ) + assert len(result["equity_curve"]) <= 20 + + @pytest.mark.asyncio + async def test_raises_value_error_on_no_data(self, monkeypatch): + async def fake_fetch(symbol, days, **kwargs): + return None + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="No historical data"): + await backtest_service.backtest_sma_crossover( + "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000 + ) + + @pytest.mark.asyncio + async def test_initial_capital_reflected_in_equity(self, mock_hist): + result = await backtest_service.backtest_sma_crossover( + "AAPL", short_window=5, long_window=20, days=365, initial_capital=50000 + ) + # equity_curve values should be in range related to 50000 initial capital + assert result["equity_curve"][0] > 0 + + +# --------------------------------------------------------------------------- +# backtest_rsi tests +# --------------------------------------------------------------------------- + + +class TestBacktestRsi: + @pytest.fixture + def mock_hist(self, monkeypatch): + prices = _oscillating_prices(300, period=60).tolist() + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + + @pytest.mark.asyncio + async def test_returns_all_required_keys(self, mock_hist): + result = await backtest_service.backtest_rsi( + "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000 + ) + required = { + "total_return", + "annualized_return", + "sharpe_ratio", + "max_drawdown", + "win_rate", + "total_trades", + "equity_curve", + } + assert required.issubset(result.keys()) + + @pytest.mark.asyncio + async def test_equity_curve_max_20_points(self, mock_hist): + result = await backtest_service.backtest_rsi( + "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000 + ) + assert len(result["equity_curve"]) <= 20 + + @pytest.mark.asyncio + async def test_raises_value_error_on_no_data(self, monkeypatch): + async def fake_fetch(symbol, days, **kwargs): + return None + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="No historical data"): + await backtest_service.backtest_rsi( + "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000 + ) + + +# --------------------------------------------------------------------------- +# backtest_buy_and_hold tests +# --------------------------------------------------------------------------- + + +class TestBacktestBuyAndHold: + @pytest.fixture + def mock_hist_rising(self, monkeypatch): + prices = _rising_prices(252, start=100.0, step=1.0).tolist() + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + + @pytest.mark.asyncio + async def test_returns_all_required_keys(self, mock_hist_rising): + result = await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) + required = { + "total_return", + "annualized_return", + "sharpe_ratio", + "max_drawdown", + "win_rate", + "total_trades", + "equity_curve", + } + assert required.issubset(result.keys()) + + @pytest.mark.asyncio + async def test_total_trades_always_one(self, mock_hist_rising): + result = await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) + assert result["total_trades"] == 1 + + @pytest.mark.asyncio + async def test_rising_prices_positive_return(self, mock_hist_rising): + result = await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) + assert result["total_return"] > 0 + + @pytest.mark.asyncio + async def test_known_return_value(self, monkeypatch): + # 100 -> 200: 100% total return + prices = [100.0, 200.0] + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + result = await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) + assert result["total_return"] == pytest.approx(1.0, abs=1e-6) + assert result["equity_curve"][-1] == pytest.approx(20000.0, abs=1e-6) + + @pytest.mark.asyncio + async def test_raises_value_error_on_no_data(self, monkeypatch): + async def fake_fetch(symbol, days, **kwargs): + return None + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="No historical data"): + await backtest_service.backtest_buy_and_hold("AAPL", days=365, initial_capital=10000) + + @pytest.mark.asyncio + async def test_flat_prices_zero_return(self, monkeypatch): + prices = _flat_prices(50).tolist() + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + result = await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) + assert result["total_return"] == pytest.approx(0.0, abs=1e-6) + + +# --------------------------------------------------------------------------- +# backtest_momentum tests +# --------------------------------------------------------------------------- + + +class TestBacktestMomentum: + @pytest.fixture + def mock_multi_hist(self, monkeypatch): + """Three symbols with different return profiles.""" + aapl_prices = _rising_prices(200, start=100.0, step=2.0).tolist() + msft_prices = _rising_prices(200, start=100.0, step=0.5).tolist() + googl_prices = _flat_prices(200, price=150.0).tolist() + + price_map = { + "AAPL": aapl_prices, + "MSFT": msft_prices, + "GOOGL": googl_prices, + } + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + def __init__(self, prices): + self.results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult(price_map[symbol]) + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + + @pytest.mark.asyncio + async def test_returns_all_required_keys(self, mock_multi_hist): + result = await backtest_service.backtest_momentum( + symbols=["AAPL", "MSFT", "GOOGL"], + lookback=20, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000, + ) + required = { + "total_return", + "annualized_return", + "sharpe_ratio", + "max_drawdown", + "win_rate", + "total_trades", + "equity_curve", + "allocation_history", + } + assert required.issubset(result.keys()) + + @pytest.mark.asyncio + async def test_allocation_history_is_list(self, mock_multi_hist): + result = await backtest_service.backtest_momentum( + symbols=["AAPL", "MSFT", "GOOGL"], + lookback=20, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000, + ) + assert isinstance(result["allocation_history"], list) + + @pytest.mark.asyncio + async def test_top_n_respected_in_allocations(self, mock_multi_hist): + result = await backtest_service.backtest_momentum( + symbols=["AAPL", "MSFT", "GOOGL"], + lookback=20, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000, + ) + for entry in result["allocation_history"]: + assert len(entry["symbols"]) <= 2 + + @pytest.mark.asyncio + async def test_raises_value_error_on_no_data(self, monkeypatch): + async def fake_fetch(symbol, days, **kwargs): + return None + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="No price data"): + await backtest_service.backtest_momentum( + symbols=["AAPL", "MSFT"], + lookback=20, + top_n=1, + rebalance_days=30, + days=365, + initial_capital=10000, + ) + + @pytest.mark.asyncio + async def test_equity_curve_max_20_points(self, mock_multi_hist): + result = await backtest_service.backtest_momentum( + symbols=["AAPL", "MSFT", "GOOGL"], + lookback=20, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000, + ) + assert len(result["equity_curve"]) <= 20 + + +# --------------------------------------------------------------------------- +# Edge case: insufficient data +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + @pytest.mark.asyncio + async def test_sma_crossover_insufficient_bars_raises(self, monkeypatch): + """Fewer bars than long_window should raise ValueError.""" + prices = [100.0, 101.0, 102.0] # Only 3 bars + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="Insufficient data"): + await backtest_service.backtest_sma_crossover( + "AAPL", short_window=5, long_window=20, days=365, initial_capital=10000 + ) + + @pytest.mark.asyncio + async def test_rsi_insufficient_bars_raises(self, monkeypatch): + prices = [100.0, 101.0] + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="Insufficient data"): + await backtest_service.backtest_rsi( + "AAPL", period=14, oversold=30, overbought=70, days=365, initial_capital=10000 + ) + + @pytest.mark.asyncio + async def test_buy_and_hold_single_bar_raises(self, monkeypatch): + prices = [100.0] + + class FakeBar: + def __init__(self, close): + self.close = close + + class FakeResult: + results = [FakeBar(p) for p in prices] + + async def fake_fetch(symbol, days, **kwargs): + return FakeResult() + + monkeypatch.setattr(backtest_service, "fetch_historical", fake_fetch) + with pytest.raises(ValueError, match="Insufficient data"): + await backtest_service.backtest_buy_and_hold( + "AAPL", days=365, initial_capital=10000 + ) diff --git a/tests/test_routes_backtest.py b/tests/test_routes_backtest.py new file mode 100644 index 0000000..36e1547 --- /dev/null +++ b/tests/test_routes_backtest.py @@ -0,0 +1,443 @@ +"""Integration tests for backtest routes - written FIRST (TDD RED phase).""" + +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from main import app + + +@pytest.fixture +async def client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +# Shared mock response used across strategy tests +MOCK_BACKTEST_RESULT = { + "total_return": 0.15, + "annualized_return": 0.14, + "sharpe_ratio": 1.2, + "max_drawdown": -0.08, + "win_rate": 0.6, + "total_trades": 10, + "equity_curve": [10000 + i * 75 for i in range(20)], +} + +MOCK_MOMENTUM_RESULT = { + **MOCK_BACKTEST_RESULT, + "allocation_history": [ + {"date": "2024-01-01", "symbols": ["AAPL", "MSFT"], "weights": [0.5, 0.5]}, + ], +} + + +# --------------------------------------------------------------------------- +# POST /api/v1/backtest/sma-crossover +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock) +async def test_sma_crossover_happy_path(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={ + "symbol": "AAPL", + "short_window": 20, + "long_window": 50, + "days": 365, + "initial_capital": 10000, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["data"]["total_return"] == pytest.approx(0.15) + assert data["data"]["total_trades"] == 10 + assert len(data["data"]["equity_curve"]) == 20 + mock_fn.assert_called_once_with( + "AAPL", + short_window=20, + long_window=50, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock) +async def test_sma_crossover_default_values(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "MSFT"}, + ) + assert resp.status_code == 200 + mock_fn.assert_called_once_with( + "MSFT", + short_window=20, + long_window=50, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +async def test_sma_crossover_missing_symbol(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"short_window": 20, "long_window": 50}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_sma_crossover_short_window_too_small(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "short_window": 2, "long_window": 50}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_sma_crossover_long_window_too_large(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "short_window": 20, "long_window": 500}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_sma_crossover_days_too_few(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "days": 5}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_sma_crossover_days_too_many(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "days": 9999}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_sma_crossover_capital_zero(client): + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "initial_capital": 0}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock) +async def test_sma_crossover_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("Data fetch failed") + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "short_window": 20, "long_window": 50}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_sma_crossover", new_callable=AsyncMock) +async def test_sma_crossover_value_error_returns_400(mock_fn, client): + mock_fn.side_effect = ValueError("No historical data") + resp = await client.post( + "/api/v1/backtest/sma-crossover", + json={"symbol": "AAPL", "short_window": 20, "long_window": 50}, + ) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# POST /api/v1/backtest/rsi +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock) +async def test_rsi_happy_path(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post( + "/api/v1/backtest/rsi", + json={ + "symbol": "AAPL", + "period": 14, + "oversold": 30, + "overbought": 70, + "days": 365, + "initial_capital": 10000, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["data"]["sharpe_ratio"] == pytest.approx(1.2) + mock_fn.assert_called_once_with( + "AAPL", + period=14, + oversold=30.0, + overbought=70.0, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock) +async def test_rsi_default_values(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"}) + assert resp.status_code == 200 + mock_fn.assert_called_once_with( + "AAPL", + period=14, + oversold=30.0, + overbought=70.0, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +async def test_rsi_missing_symbol(client): + resp = await client.post("/api/v1/backtest/rsi", json={"period": 14}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rsi_period_too_small(client): + resp = await client.post( + "/api/v1/backtest/rsi", + json={"symbol": "AAPL", "period": 1}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rsi_oversold_too_high(client): + # oversold must be < 50 + resp = await client.post( + "/api/v1/backtest/rsi", + json={"symbol": "AAPL", "oversold": 55, "overbought": 70}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_rsi_overbought_too_low(client): + # overbought must be > 50 + resp = await client.post( + "/api/v1/backtest/rsi", + json={"symbol": "AAPL", "oversold": 30, "overbought": 45}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock) +async def test_rsi_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("upstream error") + resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"}) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_rsi", new_callable=AsyncMock) +async def test_rsi_value_error_returns_400(mock_fn, client): + mock_fn.side_effect = ValueError("Insufficient data") + resp = await client.post("/api/v1/backtest/rsi", json={"symbol": "AAPL"}) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# POST /api/v1/backtest/buy-and-hold +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock) +async def test_buy_and_hold_happy_path(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post( + "/api/v1/backtest/buy-and-hold", + json={"symbol": "AAPL", "days": 365, "initial_capital": 10000}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["data"]["total_return"] == pytest.approx(0.15) + mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0) + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock) +async def test_buy_and_hold_default_values(mock_fn, client): + mock_fn.return_value = MOCK_BACKTEST_RESULT + resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"}) + assert resp.status_code == 200 + mock_fn.assert_called_once_with("AAPL", days=365, initial_capital=10000.0) + + +@pytest.mark.asyncio +async def test_buy_and_hold_missing_symbol(client): + resp = await client.post("/api/v1/backtest/buy-and-hold", json={"days": 365}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_buy_and_hold_days_too_few(client): + resp = await client.post( + "/api/v1/backtest/buy-and-hold", + json={"symbol": "AAPL", "days": 10}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock) +async def test_buy_and_hold_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("provider down") + resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"}) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_buy_and_hold", new_callable=AsyncMock) +async def test_buy_and_hold_value_error_returns_400(mock_fn, client): + mock_fn.side_effect = ValueError("No historical data") + resp = await client.post("/api/v1/backtest/buy-and-hold", json={"symbol": "AAPL"}) + assert resp.status_code == 400 + + +# --------------------------------------------------------------------------- +# POST /api/v1/backtest/momentum +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock) +async def test_momentum_happy_path(mock_fn, client): + mock_fn.return_value = MOCK_MOMENTUM_RESULT + resp = await client.post( + "/api/v1/backtest/momentum", + json={ + "symbols": ["AAPL", "MSFT", "GOOGL"], + "lookback": 60, + "top_n": 2, + "rebalance_days": 30, + "days": 365, + "initial_capital": 10000, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert "allocation_history" in data["data"] + assert data["data"]["total_trades"] == 10 + mock_fn.assert_called_once_with( + symbols=["AAPL", "MSFT", "GOOGL"], + lookback=60, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock) +async def test_momentum_default_values(mock_fn, client): + mock_fn.return_value = MOCK_MOMENTUM_RESULT + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL", "MSFT"]}, + ) + assert resp.status_code == 200 + mock_fn.assert_called_once_with( + symbols=["AAPL", "MSFT"], + lookback=60, + top_n=2, + rebalance_days=30, + days=365, + initial_capital=10000.0, + ) + + +@pytest.mark.asyncio +async def test_momentum_missing_symbols(client): + resp = await client.post("/api/v1/backtest/momentum", json={"lookback": 60}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_momentum_too_few_symbols(client): + # min_length=2 on symbols list + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL"]}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_momentum_too_many_symbols(client): + symbols = [f"SYM{i}" for i in range(25)] + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": symbols}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_momentum_lookback_too_small(client): + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL", "MSFT"], "lookback": 2}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_momentum_days_too_few(client): + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL", "MSFT"], "days": 10}, + ) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock) +async def test_momentum_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("provider down") + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL", "MSFT"]}, + ) + assert resp.status_code == 502 + + +@pytest.mark.asyncio +@patch("routes_backtest.backtest_service.backtest_momentum", new_callable=AsyncMock) +async def test_momentum_value_error_returns_400(mock_fn, client): + mock_fn.side_effect = ValueError("No price data available") + resp = await client.post( + "/api/v1/backtest/momentum", + json={"symbols": ["AAPL", "MSFT"]}, + ) + assert resp.status_code == 400