refactor: address python review findings

- Move FRED credential registration to FastAPI lifespan (was fragile
  import-order-dependent side-effect)
- Add noqa E402 annotations for imports after curl_cffi patch
- Fix all return type hints: bare dict -> dict[str, Any]
- Move yfinance import to module level (was inline in functions)
- Fix datetime.now() -> datetime.now(tz=timezone.utc) in openbb_service
- Add try/except error handling to Group B service functions
- Fix dict mutation in relative_rotation (immutable pattern)
- Extract _classify_rrg_quadrant helper function
- Fix type builtin shadow in routes_economy (type -> gdp_type)
- Fix falsy int guard (if year: -> if year is not None:)
- Remove user input echo from error messages
This commit is contained in:
Yaojia Wang
2026-03-19 17:40:47 +01:00
parent e2cf6e2488
commit 89bdc6c552
6 changed files with 118 additions and 80 deletions

View File

@@ -174,7 +174,7 @@ async def get_fomc_documents(year: int | None = None) -> list[dict[str, Any]]:
"""Get FOMC meeting documents (minutes, projections, etc.).""" """Get FOMC meeting documents (minutes, projections, etc.)."""
try: try:
kwargs: dict[str, Any] = {"provider": "federal_reserve"} kwargs: dict[str, Any] = {"provider": "federal_reserve"}
if year: if year is not None:
kwargs["year"] = year kwargs["year"] = year
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.economy.fomc_documents, **kwargs obb.economy.fomc_documents, **kwargs

46
main.py
View File

@@ -1,4 +1,5 @@
import logging import logging
from contextlib import asynccontextmanager
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -12,41 +13,52 @@ import curl_cffi.requests as _cffi_requests
_orig_session_init = _cffi_requests.Session.__init__ _orig_session_init = _cffi_requests.Session.__init__
def _patched_session_init(self, *args, **kwargs): def _patched_session_init(self, *args, **kwargs):
if kwargs.get("impersonate") == "chrome": if kwargs.get("impersonate") == "chrome":
kwargs["impersonate"] = "safari" kwargs["impersonate"] = "safari"
_orig_session_init(self, *args, **kwargs) _orig_session_init(self, *args, **kwargs)
_cffi_requests.Session.__init__ = _patched_session_init _cffi_requests.Session.__init__ = _patched_session_init
from openbb import obb from openbb import obb # noqa: E402 - must be after curl_cffi patch
from config import settings
# Register optional provider credentials with OpenBB from config import settings # noqa: E402
if settings.fred_api_key: from routes import router # noqa: E402
obb.user.credentials.fred_api_key = settings.fred_api_key from routes_calendar import router as calendar_router # noqa: E402
from routes import router from routes_economy import router as economy_router # noqa: E402
from routes_sentiment import router as sentiment_router from routes_fixed_income import router as fixed_income_router # noqa: E402
from routes_macro import router as macro_router from routes_macro import router as macro_router # noqa: E402
from routes_technical import router as technical_router from routes_market import router as market_router # noqa: E402
from routes_quantitative import router as quantitative_router from routes_quantitative import router as quantitative_router # noqa: E402
from routes_calendar import router as calendar_router from routes_regulators import router as regulators_router # noqa: E402
from routes_market import router as market_router from routes_sentiment import router as sentiment_router # noqa: E402
from routes_shorts import router as shorts_router from routes_shorts import router as shorts_router # noqa: E402
from routes_fixed_income import router as fixed_income_router from routes_surveys import router as surveys_router # noqa: E402
from routes_economy import router as economy_router from routes_technical import router as technical_router # noqa: E402
from routes_surveys import router as surveys_router
from routes_regulators import router as regulators_router
logging.basicConfig( logging.basicConfig(
level=settings.log_level.upper(), level=settings.log_level.upper(),
format="%(asctime)s %(levelname)s %(name)s: %(message)s", format="%(asctime)s %(levelname)s %(name)s: %(message)s",
) )
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Register provider credentials once at startup."""
if settings.fred_api_key:
obb.user.credentials.fred_api_key = settings.fred_api_key
logger.info("FRED API key registered")
yield
app = FastAPI( app = FastAPI(
title="OpenBB Investment Analysis API", title="OpenBB Investment Analysis API",
version="0.1.0", version="0.1.0",
description="REST API for stock data and rule-based investment analysis, powered by OpenBB SDK.", description="REST API for stock data and rule-based investment analysis, powered by OpenBB SDK.",
lifespan=lifespan,
) )
app.add_middleware( app.add_middleware(

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
import yfinance as yf
from openbb import obb from openbb import obb
from obb_utils import to_list, first_or_empty from obb_utils import to_list, first_or_empty
@@ -12,15 +13,15 @@ logger = logging.getLogger(__name__)
PROVIDER = "yfinance" PROVIDER = "yfinance"
async def get_quote(symbol: str) -> dict: async def get_quote(symbol: str) -> dict[str, Any]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.price.quote, symbol, provider=PROVIDER obb.equity.price.quote, symbol, provider=PROVIDER
) )
return first_or_empty(result) return first_or_empty(result)
async def get_historical(symbol: str, days: int = 365) -> list[dict]: async def get_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d") start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.price.historical, obb.equity.price.historical,
symbol, symbol,
@@ -36,42 +37,42 @@ async def get_historical(symbol: str, days: int = 365) -> list[dict]:
] ]
async def get_profile(symbol: str) -> dict: async def get_profile(symbol: str) -> dict[str, Any]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.profile, symbol, provider=PROVIDER obb.equity.profile, symbol, provider=PROVIDER
) )
return first_or_empty(result) return first_or_empty(result)
async def get_metrics(symbol: str) -> dict: async def get_metrics(symbol: str) -> dict[str, Any]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.fundamental.metrics, symbol, provider=PROVIDER obb.equity.fundamental.metrics, symbol, provider=PROVIDER
) )
return first_or_empty(result) return first_or_empty(result)
async def get_income(symbol: str) -> list[dict]: async def get_income(symbol: str) -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.fundamental.income, symbol, provider=PROVIDER obb.equity.fundamental.income, symbol, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_balance(symbol: str) -> list[dict]: async def get_balance(symbol: str) -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.fundamental.balance, symbol, provider=PROVIDER obb.equity.fundamental.balance, symbol, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_cash_flow(symbol: str) -> list[dict]: async def get_cash_flow(symbol: str) -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.fundamental.cash, symbol, provider=PROVIDER obb.equity.fundamental.cash, symbol, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_financials(symbol: str) -> dict: async def get_financials(symbol: str) -> dict[str, Any]:
income, balance, cash_flow = await asyncio.gather( income, balance, cash_flow = await asyncio.gather(
get_income(symbol), get_income(symbol),
get_balance(symbol), get_balance(symbol),
@@ -87,8 +88,6 @@ async def get_financials(symbol: str) -> dict:
async def get_price_target(symbol: str) -> float | None: async def get_price_target(symbol: str) -> float | None:
"""Get consensus analyst price target via yfinance.""" """Get consensus analyst price target via yfinance."""
import yfinance as yf
def _fetch() -> float | None: def _fetch() -> float | None:
t = yf.Ticker(symbol) t = yf.Ticker(symbol)
return t.info.get("targetMeanPrice") return t.info.get("targetMeanPrice")
@@ -100,14 +99,14 @@ async def get_price_target(symbol: str) -> float | None:
return None return None
async def get_news(symbol: str) -> list[dict]: async def get_news(symbol: str) -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.news.company, symbol, provider=PROVIDER obb.news.company, symbol, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_summary(symbol: str) -> dict: async def get_summary(symbol: str) -> dict[str, Any]:
quote, profile, metrics, financials = await asyncio.gather( quote, profile, metrics, financials = await asyncio.gather(
get_quote(symbol), get_quote(symbol),
get_profile(symbol), get_profile(symbol),
@@ -122,45 +121,45 @@ async def get_summary(symbol: str) -> dict:
} }
async def get_gainers() -> list[dict]: async def get_gainers() -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.discovery.gainers, provider=PROVIDER obb.equity.discovery.gainers, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_losers() -> list[dict]: async def get_losers() -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.discovery.losers, provider=PROVIDER obb.equity.discovery.losers, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_active() -> list[dict]: async def get_active() -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.discovery.active, provider=PROVIDER obb.equity.discovery.active, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_undervalued() -> list[dict]: async def get_undervalued() -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.discovery.undervalued_large_caps, provider=PROVIDER obb.equity.discovery.undervalued_large_caps, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_growth() -> list[dict]: async def get_growth() -> list[dict[str, Any]]:
result = await asyncio.to_thread( result = await asyncio.to_thread(
obb.equity.discovery.growth_tech, provider=PROVIDER obb.equity.discovery.growth_tech, provider=PROVIDER
) )
return to_list(result) return to_list(result)
async def get_upgrades_downgrades(symbol: str, limit: int = 20) -> list[dict]: async def get_upgrades_downgrades(
symbol: str, limit: int = 20,
) -> list[dict[str, Any]]:
"""Get analyst upgrades/downgrades via yfinance.""" """Get analyst upgrades/downgrades via yfinance."""
import yfinance as yf
def _fetch() -> list[dict[str, Any]]: def _fetch() -> list[dict[str, Any]]:
t = yf.Ticker(symbol) t = yf.Ticker(symbol)
df = t.upgrades_downgrades df = t.upgrades_downgrades
@@ -187,32 +186,54 @@ async def get_upgrades_downgrades(symbol: str, limit: int = 20) -> list[dict]:
# --- Equity Fundamentals Extended (Group B) --- # --- Equity Fundamentals Extended (Group B) ---
async def get_management(symbol: str) -> list[dict]: async def get_management(symbol: str) -> list[dict[str, Any]]:
"""Get executive team info (name, title, compensation).""" """Get executive team info (name, title, compensation)."""
result = await asyncio.to_thread( try:
obb.equity.fundamental.management, symbol, provider=PROVIDER result = await asyncio.to_thread(
) obb.equity.fundamental.management, symbol, provider=PROVIDER
return to_list(result) )
return to_list(result)
except Exception:
logger.warning("Management failed for %s", symbol, exc_info=True)
return []
async def get_dividends(symbol: str) -> list[dict]: async def get_dividends(symbol: str) -> list[dict[str, Any]]:
"""Get historical dividend records.""" """Get historical dividend records."""
result = await asyncio.to_thread( try:
obb.equity.fundamental.dividends, symbol, provider=PROVIDER result = await asyncio.to_thread(
) obb.equity.fundamental.dividends, symbol, provider=PROVIDER
return to_list(result) )
return to_list(result)
except Exception:
logger.warning("Dividends failed for %s", symbol, exc_info=True)
return []
async def get_filings(symbol: str, form_type: str | None = None) -> list[dict]: async def get_filings(
symbol: str, form_type: str | None = None,
) -> list[dict[str, Any]]:
"""Get SEC filings (10-K, 10-Q, 8-K, etc.).""" """Get SEC filings (10-K, 10-Q, 8-K, etc.)."""
kwargs: dict[str, Any] = {"symbol": symbol, "provider": "sec"} try:
if form_type: kwargs: dict[str, Any] = {"symbol": symbol, "provider": "sec"}
kwargs["type"] = form_type if form_type is not None:
result = await asyncio.to_thread(obb.equity.fundamental.filings, **kwargs) kwargs["type"] = form_type
return to_list(result) result = await asyncio.to_thread(
obb.equity.fundamental.filings, **kwargs
)
return to_list(result)
except Exception:
logger.warning("Filings failed for %s", symbol, exc_info=True)
return []
async def search_company(query: str) -> list[dict]: async def search_company(query: str) -> list[dict[str, Any]]:
"""Search for companies by name.""" """Search for companies by name."""
result = await asyncio.to_thread(obb.equity.search, query, provider="sec") try:
return to_list(result) result = await asyncio.to_thread(
obb.equity.search, query, provider="sec"
)
return to_list(result)
except Exception:
logger.warning("Company search failed for %s", query, exc_info=True)
return []

View File

@@ -167,7 +167,7 @@ async def get_rolling_stat(
"""Compute a rolling statistic (variance, stdev, mean, skew, kurtosis, quantile).""" """Compute a rolling statistic (variance, stdev, mean, skew, kurtosis, quantile)."""
valid_stats = {"variance", "stdev", "mean", "skew", "kurtosis", "quantile"} valid_stats = {"variance", "stdev", "mean", "skew", "kurtosis", "quantile"}
if stat not in valid_stats: if stat not in valid_stats:
return {"symbol": symbol, "error": f"Invalid stat: {stat}. Use: {', '.join(sorted(valid_stats))}"} return {"symbol": symbol, "error": f"Invalid stat. Valid options: {', '.join(sorted(valid_stats))}"}
fetch_days = max(days, PERF_DAYS) fetch_days = max(days, PERF_DAYS)
hist = await fetch_historical(symbol, fetch_days) hist = await fetch_historical(symbol, fetch_days)

View File

@@ -23,10 +23,10 @@ async def macro_cpi(country: str = Query(default="united_states", max_length=50,
@router.get("/macro/gdp", response_model=ApiResponse) @router.get("/macro/gdp", response_model=ApiResponse)
@safe @safe
async def macro_gdp( async def macro_gdp(
type: str = Query(default="real", pattern="^(nominal|real|forecast)$"), gdp_type: str = Query(default="real", pattern="^(nominal|real|forecast)$"),
): ):
"""GDP: nominal, real, or forecast.""" """GDP: nominal, real, or forecast."""
data = await economy_service.get_gdp(gdp_type=type) data = await economy_service.get_gdp(gdp_type=gdp_type)
return ApiResponse(data=data) return ApiResponse(data=data)

View File

@@ -411,13 +411,12 @@ async def get_relative_rotation(
Returns RS-Ratio and RS-Momentum for each symbol, indicating Returns RS-Ratio and RS-Momentum for each symbol, indicating
which RRG quadrant they occupy (Leading/Weakening/Lagging/Improving). which RRG quadrant they occupy (Leading/Weakening/Lagging/Improving).
""" """
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone as tz
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d") start = (datetime.now(tz=tz.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
all_symbols = ",".join(symbols + [benchmark]) all_symbols = ",".join(symbols + [benchmark])
try: try:
# Fetch multi-symbol historical data in one call
hist = await asyncio.to_thread( hist = await asyncio.to_thread(
obb.equity.price.historical, obb.equity.price.historical,
all_symbols, all_symbols,
@@ -434,26 +433,17 @@ async def get_relative_rotation(
study=study, study=study,
) )
items = to_list(result) items = to_list(result)
# Return the latest data point per symbol
latest_by_symbol: dict[str, dict] = {} latest_by_symbol: dict[str, dict[str, Any]] = {}
for item in items: for item in items:
sym = item.get("symbol") sym = item.get("symbol")
if sym and sym != benchmark: if sym and sym != benchmark:
latest_by_symbol[sym] = item latest_by_symbol[sym] = item
entries = list(latest_by_symbol.values()) entries = [
for entry in entries: {**item, "quadrant": _classify_rrg_quadrant(item)}
rs_ratio = entry.get("rs_ratio") for item in latest_by_symbol.values()
rs_momentum = entry.get("rs_momentum") ]
if rs_ratio is not None and rs_momentum is not None:
if rs_ratio > 100 and rs_momentum > 100:
entry["quadrant"] = "Leading"
elif rs_ratio > 100 and rs_momentum <= 100:
entry["quadrant"] = "Weakening"
elif rs_ratio <= 100 and rs_momentum <= 100:
entry["quadrant"] = "Lagging"
else:
entry["quadrant"] = "Improving"
return { return {
"symbols": symbols, "symbols": symbols,
@@ -466,6 +456,21 @@ async def get_relative_rotation(
return {"symbols": symbols, "error": "Failed to compute relative rotation"} return {"symbols": symbols, "error": "Failed to compute relative rotation"}
def _classify_rrg_quadrant(item: dict[str, Any]) -> str | None:
"""Classify RRG quadrant from RS-Ratio and RS-Momentum."""
rs_ratio = item.get("rs_ratio")
rs_momentum = item.get("rs_momentum")
if rs_ratio is None or rs_momentum is None:
return None
if rs_ratio > 100 and rs_momentum > 100:
return "Leading"
if rs_ratio > 100:
return "Weakening"
if rs_momentum <= 100:
return "Lagging"
return "Improving"
async def get_cones(symbol: str, days: int = 365) -> dict[str, Any]: async def get_cones(symbol: str, days: int = 365) -> dict[str, Any]:
"""Volatility Cones -- realized volatility quantiles for options analysis.""" """Volatility Cones -- realized volatility quantiles for options analysis."""
hist = await fetch_historical(symbol, days) hist = await fetch_historical(symbol, days)