From 003c1d6ffcacbfea7687728b666029d911819423 Mon Sep 17 00:00:00 2001 From: Yaojia Wang Date: Mon, 9 Mar 2026 10:56:21 +0100 Subject: [PATCH] refactor: fix code review issues across routes and services - Extract shared route_utils.py (validate_symbol, safe decorator) removing duplication from 6 route files - Extract shared obb_utils.py (to_list, extract_single, safe_last) removing duplication from calendar_service and market_service - Fix _to_list dict mutation during iteration (use comprehension) - Fix double vars() call and live __dict__ mutation risk - Fix route ordering: /etf/search and /crypto/search now registered before /{symbol} path params to prevent shadowing - Add date format validation (YYYY-MM-DD pattern) on calendar routes - Use timezone-aware datetime.now(tz=timezone.utc) in all services - Add explicit type annotation for asyncio.gather results --- calendar_service.py | 40 +++++---------- market_service.py | 56 ++++++++------------- obb_utils.py | 51 +++++++++++++++++++ quantitative_service.py | 53 ++++++-------------- route_utils.py | 39 +++++++++++++++ routes.py | 77 +++++++++-------------------- routes_calendar.py | 95 ++++++++++++++++-------------------- routes_macro.py | 33 ++----------- routes_market.py | 105 +++++++++++++++------------------------- routes_quantitative.py | 53 +++++--------------- routes_sentiment.py | 56 ++++++--------------- routes_technical.py | 41 ++-------------- 12 files changed, 271 insertions(+), 428 deletions(-) create mode 100644 obb_utils.py create mode 100644 route_utils.py diff --git a/calendar_service.py b/calendar_service.py index 26aee26..6d08f6e 100644 --- a/calendar_service.py +++ b/calendar_service.py @@ -6,6 +6,8 @@ from typing import Any from openbb import obb +from obb_utils import to_list + logger = logging.getLogger(__name__) @@ -20,7 +22,7 @@ async def get_earnings_calendar( if end_date: kwargs["end_date"] = end_date result = await asyncio.to_thread(obb.equity.calendar.earnings, **kwargs) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Earnings calendar failed", exc_info=True) return [] @@ -37,7 +39,7 @@ async def get_dividend_calendar( if end_date: kwargs["end_date"] = end_date result = await asyncio.to_thread(obb.equity.calendar.dividend, **kwargs) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Dividend calendar failed", exc_info=True) return [] @@ -54,7 +56,7 @@ async def get_ipo_calendar( if end_date: kwargs["end_date"] = end_date result = await asyncio.to_thread(obb.equity.calendar.ipo, **kwargs) - return _to_list(result) + return to_list(result) except Exception: logger.warning("IPO calendar failed", exc_info=True) return [] @@ -71,7 +73,7 @@ async def get_splits_calendar( if end_date: kwargs["end_date"] = end_date result = await asyncio.to_thread(obb.equity.calendar.splits, **kwargs) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Splits calendar failed", exc_info=True) return [] @@ -83,7 +85,7 @@ async def get_analyst_estimates(symbol: str) -> dict[str, Any]: result = await asyncio.to_thread( obb.equity.estimates.consensus, symbol, provider="yfinance" ) - items = _to_list(result) + items = to_list(result) return {"symbol": symbol, "estimates": items} except Exception: logger.warning("Analyst estimates failed for %s", symbol, exc_info=True) @@ -96,7 +98,7 @@ async def get_share_statistics(symbol: str) -> dict[str, Any]: result = await asyncio.to_thread( obb.equity.ownership.share_statistics, symbol, provider="yfinance" ) - items = _to_list(result) + items = to_list(result) return items[0] if items else {} except Exception: logger.warning("Share statistics failed for %s", symbol, exc_info=True) @@ -109,7 +111,7 @@ async def get_insider_trading(symbol: str) -> list[dict[str, Any]]: result = await asyncio.to_thread( obb.equity.ownership.insider_trading, symbol, provider="sec" ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("SEC insider trading failed for %s", symbol, exc_info=True) return [] @@ -121,7 +123,7 @@ async def get_institutional_holders(symbol: str) -> list[dict[str, Any]]: result = await asyncio.to_thread( obb.equity.ownership.form_13f, symbol, provider="sec" ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("13F data failed for %s", symbol, exc_info=True) return [] @@ -133,27 +135,7 @@ async def screen_stocks() -> list[dict[str, Any]]: result = await asyncio.to_thread( obb.equity.screener, provider="yfinance" ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Stock screener failed", exc_info=True) return [] - - -def _to_list(result: Any) -> list[dict[str, Any]]: - """Convert OBBject result to list of dicts.""" - if result is None or result.results is None: - return [] - items = result.results - if not isinstance(items, list): - items = [items] - out = [] - for item in items: - if hasattr(item, "model_dump"): - d = item.model_dump() - else: - d = vars(item) if vars(item) else {} - for k, v in d.items(): - if hasattr(v, "isoformat"): - d[k] = v.isoformat() - out.append(d) - return out diff --git a/market_service.py b/market_service.py index 53b0d5d..f58b7af 100644 --- a/market_service.py +++ b/market_service.py @@ -2,11 +2,13 @@ import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, timezone, timedelta from typing import Any from openbb import obb +from obb_utils import to_list + logger = logging.getLogger(__name__) PROVIDER = "yfinance" @@ -19,7 +21,7 @@ async def get_etf_info(symbol: str) -> dict[str, Any]: """Get ETF profile/info.""" try: result = await asyncio.to_thread(obb.etf.info, symbol, provider=PROVIDER) - items = _to_list(result) + items = to_list(result) return items[0] if items else {} except Exception: logger.warning("ETF info failed for %s", symbol, exc_info=True) @@ -28,12 +30,12 @@ async def get_etf_info(symbol: str) -> dict[str, Any]: async def get_etf_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]: """Get ETF price history.""" - start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") try: result = await asyncio.to_thread( obb.etf.historical, symbol, start_date=start, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("ETF historical failed for %s", symbol, exc_info=True) return [] @@ -43,7 +45,7 @@ async def search_etf(query: str) -> list[dict[str, Any]]: """Search for ETFs by name or keyword.""" try: result = await asyncio.to_thread(obb.etf.search, query) - return _to_list(result) + return to_list(result) except Exception: logger.warning("ETF search failed for %s", query, exc_info=True) return [] @@ -56,7 +58,7 @@ async def get_available_indices() -> list[dict[str, Any]]: """List available market indices.""" try: result = await asyncio.to_thread(obb.index.available, provider=PROVIDER) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Available indices failed", exc_info=True) return [] @@ -64,12 +66,12 @@ async def get_available_indices() -> list[dict[str, Any]]: async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]: """Get index price history.""" - start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") try: result = await asyncio.to_thread( obb.index.price.historical, symbol, start_date=start, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Index historical failed for %s", symbol, exc_info=True) return [] @@ -80,12 +82,12 @@ async def get_index_historical(symbol: str, days: int = 365) -> list[dict[str, A async def get_crypto_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]: """Get cryptocurrency price history.""" - start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") try: result = await asyncio.to_thread( obb.crypto.price.historical, symbol, start_date=start, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Crypto historical failed for %s", symbol, exc_info=True) return [] @@ -95,7 +97,7 @@ async def search_crypto(query: str) -> list[dict[str, Any]]: """Search for cryptocurrencies.""" try: result = await asyncio.to_thread(obb.crypto.search, query) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Crypto search failed for %s", query, exc_info=True) return [] @@ -108,12 +110,12 @@ async def get_currency_historical( symbol: str, days: int = 365 ) -> list[dict[str, Any]]: """Get forex price history (e.g., EURUSD).""" - start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") try: result = await asyncio.to_thread( obb.currency.price.historical, symbol, start_date=start, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Currency historical failed for %s", symbol, exc_info=True) return [] @@ -128,7 +130,7 @@ async def get_options_chains(symbol: str) -> list[dict[str, Any]]: result = await asyncio.to_thread( obb.derivatives.options.chains, symbol, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Options chains failed for %s", symbol, exc_info=True) return [] @@ -138,12 +140,12 @@ async def get_futures_historical( symbol: str, days: int = 365 ) -> list[dict[str, Any]]: """Get futures price history.""" - start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") try: result = await asyncio.to_thread( obb.derivatives.futures.historical, symbol, start_date=start, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Futures historical failed for %s", symbol, exc_info=True) return [] @@ -155,27 +157,7 @@ async def get_futures_curve(symbol: str) -> list[dict[str, Any]]: result = await asyncio.to_thread( obb.derivatives.futures.curve, symbol, provider=PROVIDER ) - return _to_list(result) + return to_list(result) except Exception: logger.warning("Futures curve failed for %s", symbol, exc_info=True) return [] - - -def _to_list(result: Any) -> list[dict[str, Any]]: - """Convert OBBject result to list of dicts.""" - if result is None or result.results is None: - return [] - items = result.results - if not isinstance(items, list): - items = [items] - out = [] - for item in items: - if hasattr(item, "model_dump"): - d = item.model_dump() - else: - d = vars(item) if vars(item) else {} - for k, v in d.items(): - if hasattr(v, "isoformat"): - d[k] = v.isoformat() - out.append(d) - return out diff --git a/obb_utils.py b/obb_utils.py new file mode 100644 index 0000000..fc1d91c --- /dev/null +++ b/obb_utils.py @@ -0,0 +1,51 @@ +"""Shared OpenBB result conversion utilities.""" + +from typing import Any + + +def to_list(result: Any) -> list[dict[str, Any]]: + """Convert OBBject result to list of dicts with serialized dates.""" + if result is None or result.results is None: + return [] + items = result.results + if not isinstance(items, list): + items = [items] + out = [] + for item in items: + if hasattr(item, "model_dump"): + d = item.model_dump() + else: + raw = vars(item) + d = dict(raw) if raw else {} + d = { + k: v.isoformat() if hasattr(v, "isoformat") else v + for k, v in d.items() + } + out.append(d) + return out + + +def extract_single(result: Any) -> dict[str, Any]: + """Extract data from an OBBject result (single model or list).""" + if result is None: + return {} + items = getattr(result, "results", None) + if items is None: + return {} + if hasattr(items, "model_dump"): + return items.model_dump() + if isinstance(items, list) and items: + last = items[-1] + return last.model_dump() if hasattr(last, "model_dump") else {} + return {} + + +def safe_last(result: Any) -> dict[str, Any] | None: + """Get the last item from a list result, or None.""" + if result is None: + return None + items = getattr(result, "results", None) + if items is None or not isinstance(items, list) or not items: + return None + last = items[-1] + return last.model_dump() if hasattr(last, "model_dump") else None diff --git a/quantitative_service.py b/quantitative_service.py index 2d2f302..dfbd8ed 100644 --- a/quantitative_service.py +++ b/quantitative_service.py @@ -2,11 +2,13 @@ import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, timezone, timedelta from typing import Any from openbb import obb +from obb_utils import extract_single, safe_last + logger = logging.getLogger(__name__) PROVIDER = "yfinance" @@ -20,7 +22,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any """Calculate Sharpe ratio, summary stats, and volatility for a symbol.""" # Need at least 252 trading days for Sharpe window fetch_days = max(days, PERF_DAYS) - start = (datetime.now() - timedelta(days=fetch_days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d") try: hist = await asyncio.to_thread( @@ -29,7 +31,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any if not hist or not hist.results: return {"symbol": symbol, "error": "No historical data"} - sharpe_result, summary_result, stdev_result = await asyncio.gather( + results: tuple[Any, ...] = await asyncio.gather( asyncio.to_thread( obb.quantitative.performance.sharpe_ratio, data=hist.results, target=TARGET, @@ -42,10 +44,11 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any ), return_exceptions=True, ) + sharpe_result, summary_result, stdev_result = results - sharpe = _safe_last(sharpe_result) if not isinstance(sharpe_result, BaseException) else None - summary = _extract_single(summary_result) if not isinstance(summary_result, BaseException) else {} - stdev = _safe_last(stdev_result) if not isinstance(stdev_result, BaseException) else None + sharpe = safe_last(sharpe_result) if not isinstance(sharpe_result, BaseException) else None + summary = extract_single(summary_result) if not isinstance(summary_result, BaseException) else {} + stdev = safe_last(stdev_result) if not isinstance(stdev_result, BaseException) else None return { "symbol": symbol, @@ -61,7 +64,7 @@ async def get_performance_metrics(symbol: str, days: int = 365) -> dict[str, Any async def get_capm(symbol: str) -> dict[str, Any]: """Calculate CAPM metrics: beta, alpha, systematic/idiosyncratic risk.""" - start = (datetime.now() - timedelta(days=PERF_DAYS)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=PERF_DAYS)).strftime("%Y-%m-%d") try: hist = await asyncio.to_thread( @@ -73,7 +76,7 @@ async def get_capm(symbol: str) -> dict[str, Any]: capm = await asyncio.to_thread( obb.quantitative.capm, data=hist.results, target=TARGET ) - return {"symbol": symbol, **_extract_single(capm)} + return {"symbol": symbol, **extract_single(capm)} except Exception: logger.warning("CAPM failed for %s", symbol, exc_info=True) return {"symbol": symbol, "error": "Failed to compute CAPM"} @@ -82,7 +85,7 @@ async def get_capm(symbol: str) -> dict[str, Any]: async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]: """Run normality tests (Jarque-Bera, Shapiro-Wilk, etc.) on returns.""" fetch_days = max(days, PERF_DAYS) - start = (datetime.now() - timedelta(days=fetch_days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d") try: hist = await asyncio.to_thread( @@ -94,7 +97,7 @@ async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]: norm = await asyncio.to_thread( obb.quantitative.normality, data=hist.results, target=TARGET ) - return {"symbol": symbol, **_extract_single(norm)} + return {"symbol": symbol, **extract_single(norm)} except Exception: logger.warning("Normality test failed for %s", symbol, exc_info=True) return {"symbol": symbol, "error": "Failed to compute normality tests"} @@ -103,7 +106,7 @@ async def get_normality_test(symbol: str, days: int = 365) -> dict[str, Any]: async def get_unitroot_test(symbol: str, days: int = 365) -> dict[str, Any]: """Run unit root tests (ADF, KPSS) for stationarity.""" fetch_days = max(days, PERF_DAYS) - start = (datetime.now() - timedelta(days=fetch_days)).strftime("%Y-%m-%d") + start = (datetime.now(tz=timezone.utc) - timedelta(days=fetch_days)).strftime("%Y-%m-%d") try: hist = await asyncio.to_thread( @@ -115,33 +118,7 @@ async def get_unitroot_test(symbol: str, days: int = 365) -> dict[str, Any]: ur = await asyncio.to_thread( obb.quantitative.unitroot_test, data=hist.results, target=TARGET ) - return {"symbol": symbol, **_extract_single(ur)} + return {"symbol": symbol, **extract_single(ur)} except Exception: logger.warning("Unit root test failed for %s", symbol, exc_info=True) return {"symbol": symbol, "error": "Failed to compute unit root test"} - - -def _extract_single(result: Any) -> dict[str, Any]: - """Extract data from an OBBject result (single model or list).""" - if result is None: - return {} - items = getattr(result, "results", None) - if items is None: - return {} - if hasattr(items, "model_dump"): - return items.model_dump() - if isinstance(items, list) and items: - last = items[-1] - return last.model_dump() if hasattr(last, "model_dump") else {} - return {} - - -def _safe_last(result: Any) -> dict[str, Any] | None: - """Get the last item from a list result, or None.""" - if result is None: - return None - items = getattr(result, "results", None) - if items is None or not isinstance(items, list) or not items: - return None - last = items[-1] - return last.model_dump() if hasattr(last, "model_dump") else None diff --git a/route_utils.py b/route_utils.py new file mode 100644 index 0000000..966451d --- /dev/null +++ b/route_utils.py @@ -0,0 +1,39 @@ +"""Shared route utilities: symbol validation and error handling decorator.""" + +import functools +import logging +from collections.abc import Awaitable, Callable +from typing import ParamSpec, TypeVar + +from fastapi import HTTPException + +from models import SYMBOL_PATTERN + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +def validate_symbol(symbol: str) -> str: + """Validate and normalize a stock symbol.""" + if not SYMBOL_PATTERN.match(symbol): + raise HTTPException(status_code=400, detail="Invalid symbol format") + return symbol.upper() + + +def safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + """Decorator to catch upstream errors and return 502.""" + @functools.wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + try: + return await fn(*args, **kwargs) + except HTTPException: + raise + except Exception: + logger.exception("Upstream data error") + raise HTTPException( + status_code=502, + detail="Data provider error. Check server logs.", + ) + return wrapper # type: ignore[return-value] diff --git a/routes.py b/routes.py index 652946a..2424837 100644 --- a/routes.py +++ b/routes.py @@ -1,9 +1,4 @@ -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar - -from fastapi import APIRouter, HTTPException, Path, Query +from fastapi import APIRouter, Path, Query from mappers import ( discover_items_from_list, @@ -12,7 +7,6 @@ from mappers import ( quote_from_dict, ) from models import ( - SYMBOL_PATTERN, ApiResponse, FinancialsResponse, HistoricalBar, @@ -21,87 +15,60 @@ from models import ( PortfolioResponse, SummaryResponse, ) +from route_utils import safe, validate_symbol import openbb_service import analysis_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - """Decorator to catch OpenBB errors and return 502.""" - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - # --- Stock Data --- @router.get("/stock/{symbol}/quote", response_model=ApiResponse) -@_safe +@safe async def stock_quote(symbol: str = Path(..., min_length=1, max_length=20)): """Get current quote for a stock.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_quote(symbol) return ApiResponse(data=quote_from_dict(symbol, data).model_dump()) @router.get("/stock/{symbol}/profile", response_model=ApiResponse) -@_safe +@safe async def stock_profile(symbol: str = Path(..., min_length=1, max_length=20)): """Get company profile.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_profile(symbol) return ApiResponse(data=profile_from_dict(symbol, data).model_dump()) @router.get("/stock/{symbol}/metrics", response_model=ApiResponse) -@_safe +@safe async def stock_metrics(symbol: str = Path(..., min_length=1, max_length=20)): """Get key financial metrics (PE, PB, ROE, etc.).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_metrics(symbol) return ApiResponse(data=metrics_from_dict(symbol, data).model_dump()) @router.get("/stock/{symbol}/financials", response_model=ApiResponse) -@_safe +@safe async def stock_financials(symbol: str = Path(..., min_length=1, max_length=20)): """Get income statement, balance sheet, and cash flow.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_financials(symbol) return ApiResponse(data=FinancialsResponse(**data).model_dump()) @router.get("/stock/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def stock_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get historical price data.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_historical(symbol, days=days) bars = [ HistoricalBar( @@ -118,10 +85,10 @@ async def stock_historical( @router.get("/stock/{symbol}/news", response_model=ApiResponse) -@_safe +@safe async def stock_news(symbol: str = Path(..., min_length=1, max_length=20)): """Get recent company news.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_news(symbol) news = [ NewsItem( @@ -136,10 +103,10 @@ async def stock_news(symbol: str = Path(..., min_length=1, max_length=20)): @router.get("/stock/{symbol}/summary", response_model=ApiResponse) -@_safe +@safe async def stock_summary(symbol: str = Path(..., min_length=1, max_length=20)): """Get aggregated stock data: quote + profile + metrics + financials.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await openbb_service.get_summary(symbol) summary = SummaryResponse( quote=quote_from_dict(symbol, data.get("quote", {})), @@ -156,7 +123,7 @@ async def stock_summary(symbol: str = Path(..., min_length=1, max_length=20)): @router.post("/portfolio/analyze", response_model=ApiResponse) -@_safe +@safe async def portfolio_analyze(request: PortfolioRequest): """Analyze portfolio holdings with rule-based engine.""" result: PortfolioResponse = await analysis_service.analyze_portfolio( @@ -169,7 +136,7 @@ async def portfolio_analyze(request: PortfolioRequest): @router.get("/discover/gainers", response_model=ApiResponse) -@_safe +@safe async def discover_gainers(): """Get top gainers (US market).""" data = await openbb_service.get_gainers() @@ -177,7 +144,7 @@ async def discover_gainers(): @router.get("/discover/losers", response_model=ApiResponse) -@_safe +@safe async def discover_losers(): """Get top losers (US market).""" data = await openbb_service.get_losers() @@ -185,7 +152,7 @@ async def discover_losers(): @router.get("/discover/active", response_model=ApiResponse) -@_safe +@safe async def discover_active(): """Get most active stocks (US market).""" data = await openbb_service.get_active() @@ -193,7 +160,7 @@ async def discover_active(): @router.get("/discover/undervalued", response_model=ApiResponse) -@_safe +@safe async def discover_undervalued(): """Get undervalued large cap stocks.""" data = await openbb_service.get_undervalued() @@ -201,7 +168,7 @@ async def discover_undervalued(): @router.get("/discover/growth", response_model=ApiResponse) -@_safe +@safe async def discover_growth(): """Get growth tech stocks.""" data = await openbb_service.get_growth() diff --git a/routes_calendar.py b/routes_calendar.py index cf7310a..116db33 100644 --- a/routes_calendar.py +++ b/routes_calendar.py @@ -1,53 +1,28 @@ """Routes for calendar events, screening, ownership, and estimates.""" -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar +from fastapi import APIRouter, Path, Query -from fastapi import APIRouter, HTTPException, Path, Query - -from models import SYMBOL_PATTERN, ApiResponse +from models import ApiResponse +from route_utils import safe, validate_symbol import calendar_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] +DATE_PATTERN = r"^\d{4}-\d{2}-\d{2}$" # --- Calendar Events --- @router.get("/calendar/earnings", response_model=ApiResponse) -@_safe +@safe async def earnings_calendar( - start_date: str | None = Query(default=None, description="YYYY-MM-DD"), - end_date: str | None = Query(default=None, description="YYYY-MM-DD"), + start_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), + end_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), ): """Get upcoming earnings announcements.""" data = await calendar_service.get_earnings_calendar(start_date, end_date) @@ -55,10 +30,14 @@ async def earnings_calendar( @router.get("/calendar/dividends", response_model=ApiResponse) -@_safe +@safe async def dividend_calendar( - start_date: str | None = Query(default=None, description="YYYY-MM-DD"), - end_date: str | None = Query(default=None, description="YYYY-MM-DD"), + start_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), + end_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), ): """Get upcoming dividend dates.""" data = await calendar_service.get_dividend_calendar(start_date, end_date) @@ -66,10 +45,14 @@ async def dividend_calendar( @router.get("/calendar/ipo", response_model=ApiResponse) -@_safe +@safe async def ipo_calendar( - start_date: str | None = Query(default=None, description="YYYY-MM-DD"), - end_date: str | None = Query(default=None, description="YYYY-MM-DD"), + start_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), + end_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), ): """Get upcoming IPOs.""" data = await calendar_service.get_ipo_calendar(start_date, end_date) @@ -77,10 +60,14 @@ async def ipo_calendar( @router.get("/calendar/splits", response_model=ApiResponse) -@_safe +@safe async def splits_calendar( - start_date: str | None = Query(default=None, description="YYYY-MM-DD"), - end_date: str | None = Query(default=None, description="YYYY-MM-DD"), + start_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), + end_date: str | None = Query( + default=None, pattern=DATE_PATTERN, description="YYYY-MM-DD" + ), ): """Get upcoming stock splits.""" data = await calendar_service.get_splits_calendar(start_date, end_date) @@ -91,19 +78,19 @@ async def splits_calendar( @router.get("/stock/{symbol}/estimates", response_model=ApiResponse) -@_safe +@safe async def stock_estimates(symbol: str = Path(..., min_length=1, max_length=20)): """Get analyst consensus estimates.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await calendar_service.get_analyst_estimates(symbol) return ApiResponse(data=data) @router.get("/stock/{symbol}/share-statistics", response_model=ApiResponse) -@_safe +@safe async def stock_share_stats(symbol: str = Path(..., min_length=1, max_length=20)): """Get share statistics: float, outstanding, short interest.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await calendar_service.get_share_statistics(symbol) return ApiResponse(data=data) @@ -112,19 +99,19 @@ async def stock_share_stats(symbol: str = Path(..., min_length=1, max_length=20) @router.get("/stock/{symbol}/sec-insider", response_model=ApiResponse) -@_safe +@safe async def stock_sec_insider(symbol: str = Path(..., min_length=1, max_length=20)): """Get insider trading data from SEC (Form 4).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await calendar_service.get_insider_trading(symbol) return ApiResponse(data=data) @router.get("/stock/{symbol}/institutional", response_model=ApiResponse) -@_safe +@safe async def stock_institutional(symbol: str = Path(..., min_length=1, max_length=20)): """Get institutional holders from SEC 13F filings.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await calendar_service.get_institutional_holders(symbol) return ApiResponse(data=data) @@ -133,7 +120,7 @@ async def stock_institutional(symbol: str = Path(..., min_length=1, max_length=2 @router.get("/screener", response_model=ApiResponse) -@_safe +@safe async def stock_screener(): """Screen stocks using available filters.""" data = await calendar_service.screen_stocks() diff --git a/routes_macro.py b/routes_macro.py index d67de84..6b1a483 100644 --- a/routes_macro.py +++ b/routes_macro.py @@ -1,41 +1,16 @@ """Routes for macroeconomic data (FRED-powered).""" -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar - -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, Query from models import ApiResponse +from route_utils import safe import macro_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - @router.get("/macro/overview", response_model=ApiResponse) -@_safe +@safe async def macro_overview(): """Get key macro indicators: Fed rate, treasury yields, CPI, unemployment, GDP, VIX.""" data = await macro_service.get_macro_overview() @@ -43,7 +18,7 @@ async def macro_overview(): @router.get("/macro/series/{series_id}", response_model=ApiResponse) -@_safe +@safe async def macro_series( series_id: str, limit: int = Query(default=30, ge=1, le=1000), diff --git a/routes_market.py b/routes_market.py index b0c6a94..9356ca2 100644 --- a/routes_market.py +++ b/routes_market.py @@ -1,82 +1,52 @@ """Routes for ETF, index, crypto, currency, and derivatives data.""" -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar +from fastapi import APIRouter, Path, Query -from fastapi import APIRouter, HTTPException, Path, Query - -from models import SYMBOL_PATTERN, ApiResponse +from models import ApiResponse +from route_utils import safe, validate_symbol import market_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - # --- ETF --- +# NOTE: /etf/search MUST be registered before /etf/{symbol} to avoid shadowing. + + +@router.get("/etf/search", response_model=ApiResponse) +@safe +async def etf_search(query: str = Query(..., min_length=1, max_length=100)): + """Search for ETFs by name or keyword.""" + data = await market_service.search_etf(query) + return ApiResponse(data=data) @router.get("/etf/{symbol}/info", response_model=ApiResponse) -@_safe +@safe async def etf_info(symbol: str = Path(..., min_length=1, max_length=20)): """Get ETF profile and info.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_etf_info(symbol) return ApiResponse(data=data) @router.get("/etf/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def etf_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get ETF price history.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_etf_historical(symbol, days=days) return ApiResponse(data=data) -@router.get("/etf/search", response_model=ApiResponse) -@_safe -async def etf_search(query: str = Query(..., min_length=1, max_length=100)): - """Search for ETFs by name or keyword.""" - data = await market_service.search_etf(query) - return ApiResponse(data=data) - - # --- Index --- @router.get("/index/available", response_model=ApiResponse) -@_safe +@safe async def index_available(): """List available market indices.""" data = await market_service.get_available_indices() @@ -84,51 +54,52 @@ async def index_available(): @router.get("/index/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def index_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get index price history (e.g., ^GSPC, ^DJI, ^IXIC).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_index_historical(symbol, days=days) return ApiResponse(data=data) # --- Crypto --- +# NOTE: /crypto/search MUST be registered before /crypto/{symbol} to avoid shadowing. + + +@router.get("/crypto/search", response_model=ApiResponse) +@safe +async def crypto_search(query: str = Query(..., min_length=1, max_length=100)): + """Search for cryptocurrencies.""" + data = await market_service.search_crypto(query) + return ApiResponse(data=data) @router.get("/crypto/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def crypto_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get cryptocurrency price history (e.g., BTC-USD).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_crypto_historical(symbol, days=days) return ApiResponse(data=data) -@router.get("/crypto/search", response_model=ApiResponse) -@_safe -async def crypto_search(query: str = Query(..., min_length=1, max_length=100)): - """Search for cryptocurrencies.""" - data = await market_service.search_crypto(query) - return ApiResponse(data=data) - - # --- Currency --- @router.get("/currency/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def currency_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get forex price history (e.g., EURUSD, USDSEK).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_currency_historical(symbol, days=days) return ApiResponse(data=data) @@ -137,30 +108,30 @@ async def currency_historical( @router.get("/options/{symbol}/chains", response_model=ApiResponse) -@_safe +@safe async def options_chains(symbol: str = Path(..., min_length=1, max_length=20)): """Get options chain data.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_options_chains(symbol) return ApiResponse(data=data) @router.get("/futures/{symbol}/historical", response_model=ApiResponse) -@_safe +@safe async def futures_historical( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=1, le=3650), ): """Get futures price history.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_futures_historical(symbol, days=days) return ApiResponse(data=data) @router.get("/futures/{symbol}/curve", response_model=ApiResponse) -@_safe +@safe async def futures_curve(symbol: str = Path(..., min_length=1, max_length=20)): """Get futures term structure/curve.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await market_service.get_futures_curve(symbol) return ApiResponse(data=data) diff --git a/routes_quantitative.py b/routes_quantitative.py index 1b5f2ac..cac3b30 100644 --- a/routes_quantitative.py +++ b/routes_quantitative.py @@ -1,85 +1,54 @@ """Routes for quantitative analysis: risk metrics, CAPM, normality, unit root.""" -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar +from fastapi import APIRouter, Path, Query -from fastapi import APIRouter, HTTPException, Path, Query - -from models import SYMBOL_PATTERN, ApiResponse +from models import ApiResponse +from route_utils import safe, validate_symbol import quantitative_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - @router.get("/stock/{symbol}/performance", response_model=ApiResponse) -@_safe +@safe async def stock_performance( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=30, le=3650), ): """Performance metrics: Sharpe, Sortino, max drawdown, volatility.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await quantitative_service.get_performance_metrics(symbol, days=days) return ApiResponse(data=data) @router.get("/stock/{symbol}/capm", response_model=ApiResponse) -@_safe +@safe async def stock_capm(symbol: str = Path(..., min_length=1, max_length=20)): """CAPM: beta, alpha, systematic and idiosyncratic risk.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await quantitative_service.get_capm(symbol) return ApiResponse(data=data) @router.get("/stock/{symbol}/normality", response_model=ApiResponse) -@_safe +@safe async def stock_normality( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=30, le=3650), ): """Normality tests: Jarque-Bera, Shapiro-Wilk on returns.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await quantitative_service.get_normality_test(symbol, days=days) return ApiResponse(data=data) @router.get("/stock/{symbol}/unitroot", response_model=ApiResponse) -@_safe +@safe async def stock_unitroot( symbol: str = Path(..., min_length=1, max_length=20), days: int = Query(default=365, ge=30, le=3650), ): """Unit root tests: ADF, KPSS for stationarity.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await quantitative_service.get_unitroot_test(symbol, days=days) return ApiResponse(data=data) diff --git a/routes_sentiment.py b/routes_sentiment.py index af195b2..8caa71b 100644 --- a/routes_sentiment.py +++ b/routes_sentiment.py @@ -1,55 +1,29 @@ """Routes for sentiment, insider trades, and analyst data (Finnhub + Alpha Vantage).""" import asyncio -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar -from fastapi import APIRouter, HTTPException, Path, Query +from fastapi import APIRouter, Path, Query -from models import SYMBOL_PATTERN, ApiResponse +from models import ApiResponse +from route_utils import safe, validate_symbol import alphavantage_service import finnhub_service +import logging + logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - # --- Sentiment & News --- @router.get("/stock/{symbol}/sentiment", response_model=ApiResponse) -@_safe +@safe async def stock_sentiment(symbol: str = Path(..., min_length=1, max_length=20)): """Get aggregated sentiment: Alpha Vantage news sentiment + Finnhub analyst data.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) finnhub_data, av_data = await asyncio.gather( finnhub_service.get_sentiment_summary(symbol), alphavantage_service.get_news_sentiment(symbol, limit=20), @@ -67,22 +41,22 @@ async def stock_sentiment(symbol: str = Path(..., min_length=1, max_length=20)): @router.get("/stock/{symbol}/news-sentiment", response_model=ApiResponse) -@_safe +@safe async def stock_news_sentiment( symbol: str = Path(..., min_length=1, max_length=20), limit: int = Query(default=30, ge=1, le=200), ): """Get news articles with per-ticker sentiment scores (Alpha Vantage).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await alphavantage_service.get_news_sentiment(symbol, limit=limit) return ApiResponse(data=data) @router.get("/stock/{symbol}/insider-trades", response_model=ApiResponse) -@_safe +@safe async def stock_insider_trades(symbol: str = Path(..., min_length=1, max_length=20)): """Get insider transactions (CEO/CFO buys and sells).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) raw = await finnhub_service.get_insider_transactions(symbol) trades = [ { @@ -100,10 +74,10 @@ async def stock_insider_trades(symbol: str = Path(..., min_length=1, max_length= @router.get("/stock/{symbol}/recommendations", response_model=ApiResponse) -@_safe +@safe async def stock_recommendations(symbol: str = Path(..., min_length=1, max_length=20)): """Get analyst recommendation trends (monthly buy/hold/sell counts).""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) raw = await finnhub_service.get_recommendation_trends(symbol) recs = [ { @@ -120,10 +94,10 @@ async def stock_recommendations(symbol: str = Path(..., min_length=1, max_length @router.get("/stock/{symbol}/upgrades", response_model=ApiResponse) -@_safe +@safe async def stock_upgrades(symbol: str = Path(..., min_length=1, max_length=20)): """Get recent analyst upgrades and downgrades.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) raw = await finnhub_service.get_upgrade_downgrade(symbol) upgrades = [ { diff --git a/routes_technical.py b/routes_technical.py index e7a31c6..04a073c 100644 --- a/routes_technical.py +++ b/routes_technical.py @@ -1,49 +1,18 @@ """Routes for technical analysis indicators.""" -import functools -import logging -from collections.abc import Awaitable, Callable -from typing import ParamSpec, TypeVar +from fastapi import APIRouter, Path -from fastapi import APIRouter, HTTPException, Path - -from models import SYMBOL_PATTERN, ApiResponse +from models import ApiResponse +from route_utils import safe, validate_symbol import technical_service -logger = logging.getLogger(__name__) - router = APIRouter(prefix="/api/v1") -P = ParamSpec("P") -R = TypeVar("R") - - -def _validate_symbol(symbol: str) -> str: - if not SYMBOL_PATTERN.match(symbol): - raise HTTPException(status_code=400, detail="Invalid symbol format") - return symbol.upper() - - -def _safe(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(fn) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - try: - return await fn(*args, **kwargs) - except HTTPException: - raise - except Exception: - logger.exception("Upstream data error") - raise HTTPException( - status_code=502, - detail="Data provider error. Check server logs.", - ) - return wrapper # type: ignore[return-value] - @router.get("/stock/{symbol}/technical", response_model=ApiResponse) -@_safe +@safe async def stock_technical(symbol: str = Path(..., min_length=1, max_length=20)): """Get technical indicators: RSI, MACD, SMA, EMA, Bollinger Bands + signal interpretation.""" - symbol = _validate_symbol(symbol) + symbol = validate_symbol(symbol) data = await technical_service.get_technical_indicators(symbol) return ApiResponse(data=data)