refactor: address python review findings
- Move FRED credential registration to FastAPI lifespan (was fragile import-order-dependent side-effect) - Add noqa E402 annotations for imports after curl_cffi patch - Fix all return type hints: bare dict -> dict[str, Any] - Move yfinance import to module level (was inline in functions) - Fix datetime.now() -> datetime.now(tz=timezone.utc) in openbb_service - Add try/except error handling to Group B service functions - Fix dict mutation in relative_rotation (immutable pattern) - Extract _classify_rrg_quadrant helper function - Fix type builtin shadow in routes_economy (type -> gdp_type) - Fix falsy int guard (if year: -> if year is not None:) - Remove user input echo from error messages
This commit is contained in:
@@ -174,7 +174,7 @@ async def get_fomc_documents(year: int | None = None) -> list[dict[str, Any]]:
|
||||
"""Get FOMC meeting documents (minutes, projections, etc.)."""
|
||||
try:
|
||||
kwargs: dict[str, Any] = {"provider": "federal_reserve"}
|
||||
if year:
|
||||
if year is not None:
|
||||
kwargs["year"] = year
|
||||
result = await asyncio.to_thread(
|
||||
obb.economy.fomc_documents, **kwargs
|
||||
|
||||
46
main.py
46
main.py
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@@ -12,41 +13,52 @@ import curl_cffi.requests as _cffi_requests
|
||||
|
||||
_orig_session_init = _cffi_requests.Session.__init__
|
||||
|
||||
|
||||
def _patched_session_init(self, *args, **kwargs):
|
||||
if kwargs.get("impersonate") == "chrome":
|
||||
kwargs["impersonate"] = "safari"
|
||||
_orig_session_init(self, *args, **kwargs)
|
||||
|
||||
|
||||
_cffi_requests.Session.__init__ = _patched_session_init
|
||||
|
||||
from openbb import obb
|
||||
from config import settings
|
||||
from openbb import obb # noqa: E402 - must be after curl_cffi patch
|
||||
|
||||
# Register optional provider credentials with OpenBB
|
||||
if settings.fred_api_key:
|
||||
obb.user.credentials.fred_api_key = settings.fred_api_key
|
||||
from routes import router
|
||||
from routes_sentiment import router as sentiment_router
|
||||
from routes_macro import router as macro_router
|
||||
from routes_technical import router as technical_router
|
||||
from routes_quantitative import router as quantitative_router
|
||||
from routes_calendar import router as calendar_router
|
||||
from routes_market import router as market_router
|
||||
from routes_shorts import router as shorts_router
|
||||
from routes_fixed_income import router as fixed_income_router
|
||||
from routes_economy import router as economy_router
|
||||
from routes_surveys import router as surveys_router
|
||||
from routes_regulators import router as regulators_router
|
||||
from config import settings # noqa: E402
|
||||
from routes import router # noqa: E402
|
||||
from routes_calendar import router as calendar_router # noqa: E402
|
||||
from routes_economy import router as economy_router # noqa: E402
|
||||
from routes_fixed_income import router as fixed_income_router # noqa: E402
|
||||
from routes_macro import router as macro_router # noqa: E402
|
||||
from routes_market import router as market_router # noqa: E402
|
||||
from routes_quantitative import router as quantitative_router # noqa: E402
|
||||
from routes_regulators import router as regulators_router # noqa: E402
|
||||
from routes_sentiment import router as sentiment_router # noqa: E402
|
||||
from routes_shorts import router as shorts_router # noqa: E402
|
||||
from routes_surveys import router as surveys_router # noqa: E402
|
||||
from routes_technical import router as technical_router # noqa: E402
|
||||
|
||||
logging.basicConfig(
|
||||
level=settings.log_level.upper(),
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Register provider credentials once at startup."""
|
||||
if settings.fred_api_key:
|
||||
obb.user.credentials.fred_api_key = settings.fred_api_key
|
||||
logger.info("FRED API key registered")
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="OpenBB Investment Analysis API",
|
||||
version="0.1.0",
|
||||
description="REST API for stock data and rule-based investment analysis, powered by OpenBB SDK.",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import yfinance as yf
|
||||
from openbb import obb
|
||||
|
||||
from obb_utils import to_list, first_or_empty
|
||||
@@ -12,15 +13,15 @@ logger = logging.getLogger(__name__)
|
||||
PROVIDER = "yfinance"
|
||||
|
||||
|
||||
async def get_quote(symbol: str) -> dict:
|
||||
async def get_quote(symbol: str) -> dict[str, Any]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.price.quote, symbol, provider=PROVIDER
|
||||
)
|
||||
return first_or_empty(result)
|
||||
|
||||
|
||||
async def get_historical(symbol: str, days: int = 365) -> list[dict]:
|
||||
start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
async def get_historical(symbol: str, days: int = 365) -> list[dict[str, Any]]:
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.price.historical,
|
||||
symbol,
|
||||
@@ -36,42 +37,42 @@ async def get_historical(symbol: str, days: int = 365) -> list[dict]:
|
||||
]
|
||||
|
||||
|
||||
async def get_profile(symbol: str) -> dict:
|
||||
async def get_profile(symbol: str) -> dict[str, Any]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.profile, symbol, provider=PROVIDER
|
||||
)
|
||||
return first_or_empty(result)
|
||||
|
||||
|
||||
async def get_metrics(symbol: str) -> dict:
|
||||
async def get_metrics(symbol: str) -> dict[str, Any]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.metrics, symbol, provider=PROVIDER
|
||||
)
|
||||
return first_or_empty(result)
|
||||
|
||||
|
||||
async def get_income(symbol: str) -> list[dict]:
|
||||
async def get_income(symbol: str) -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.income, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_balance(symbol: str) -> list[dict]:
|
||||
async def get_balance(symbol: str) -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.balance, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_cash_flow(symbol: str) -> list[dict]:
|
||||
async def get_cash_flow(symbol: str) -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.cash, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_financials(symbol: str) -> dict:
|
||||
async def get_financials(symbol: str) -> dict[str, Any]:
|
||||
income, balance, cash_flow = await asyncio.gather(
|
||||
get_income(symbol),
|
||||
get_balance(symbol),
|
||||
@@ -87,8 +88,6 @@ async def get_financials(symbol: str) -> dict:
|
||||
|
||||
async def get_price_target(symbol: str) -> float | None:
|
||||
"""Get consensus analyst price target via yfinance."""
|
||||
import yfinance as yf
|
||||
|
||||
def _fetch() -> float | None:
|
||||
t = yf.Ticker(symbol)
|
||||
return t.info.get("targetMeanPrice")
|
||||
@@ -100,14 +99,14 @@ async def get_price_target(symbol: str) -> float | None:
|
||||
return None
|
||||
|
||||
|
||||
async def get_news(symbol: str) -> list[dict]:
|
||||
async def get_news(symbol: str) -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.news.company, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_summary(symbol: str) -> dict:
|
||||
async def get_summary(symbol: str) -> dict[str, Any]:
|
||||
quote, profile, metrics, financials = await asyncio.gather(
|
||||
get_quote(symbol),
|
||||
get_profile(symbol),
|
||||
@@ -122,45 +121,45 @@ async def get_summary(symbol: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def get_gainers() -> list[dict]:
|
||||
async def get_gainers() -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.discovery.gainers, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_losers() -> list[dict]:
|
||||
async def get_losers() -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.discovery.losers, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_active() -> list[dict]:
|
||||
async def get_active() -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.discovery.active, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_undervalued() -> list[dict]:
|
||||
async def get_undervalued() -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.discovery.undervalued_large_caps, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_growth() -> list[dict]:
|
||||
async def get_growth() -> list[dict[str, Any]]:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.discovery.growth_tech, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
|
||||
|
||||
async def get_upgrades_downgrades(symbol: str, limit: int = 20) -> list[dict]:
|
||||
async def get_upgrades_downgrades(
|
||||
symbol: str, limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get analyst upgrades/downgrades via yfinance."""
|
||||
import yfinance as yf
|
||||
|
||||
def _fetch() -> list[dict[str, Any]]:
|
||||
t = yf.Ticker(symbol)
|
||||
df = t.upgrades_downgrades
|
||||
@@ -187,32 +186,54 @@ async def get_upgrades_downgrades(symbol: str, limit: int = 20) -> list[dict]:
|
||||
# --- Equity Fundamentals Extended (Group B) ---
|
||||
|
||||
|
||||
async def get_management(symbol: str) -> list[dict]:
|
||||
async def get_management(symbol: str) -> list[dict[str, Any]]:
|
||||
"""Get executive team info (name, title, compensation)."""
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.management, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.management, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
except Exception:
|
||||
logger.warning("Management failed for %s", symbol, exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_dividends(symbol: str) -> list[dict]:
|
||||
async def get_dividends(symbol: str) -> list[dict[str, Any]]:
|
||||
"""Get historical dividend records."""
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.dividends, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.dividends, symbol, provider=PROVIDER
|
||||
)
|
||||
return to_list(result)
|
||||
except Exception:
|
||||
logger.warning("Dividends failed for %s", symbol, exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def get_filings(symbol: str, form_type: str | None = None) -> list[dict]:
|
||||
async def get_filings(
|
||||
symbol: str, form_type: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get SEC filings (10-K, 10-Q, 8-K, etc.)."""
|
||||
kwargs: dict[str, Any] = {"symbol": symbol, "provider": "sec"}
|
||||
if form_type:
|
||||
kwargs["type"] = form_type
|
||||
result = await asyncio.to_thread(obb.equity.fundamental.filings, **kwargs)
|
||||
return to_list(result)
|
||||
try:
|
||||
kwargs: dict[str, Any] = {"symbol": symbol, "provider": "sec"}
|
||||
if form_type is not None:
|
||||
kwargs["type"] = form_type
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.fundamental.filings, **kwargs
|
||||
)
|
||||
return to_list(result)
|
||||
except Exception:
|
||||
logger.warning("Filings failed for %s", symbol, exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def search_company(query: str) -> list[dict]:
|
||||
async def search_company(query: str) -> list[dict[str, Any]]:
|
||||
"""Search for companies by name."""
|
||||
result = await asyncio.to_thread(obb.equity.search, query, provider="sec")
|
||||
return to_list(result)
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
obb.equity.search, query, provider="sec"
|
||||
)
|
||||
return to_list(result)
|
||||
except Exception:
|
||||
logger.warning("Company search failed for %s", query, exc_info=True)
|
||||
return []
|
||||
|
||||
@@ -167,7 +167,7 @@ async def get_rolling_stat(
|
||||
"""Compute a rolling statistic (variance, stdev, mean, skew, kurtosis, quantile)."""
|
||||
valid_stats = {"variance", "stdev", "mean", "skew", "kurtosis", "quantile"}
|
||||
if stat not in valid_stats:
|
||||
return {"symbol": symbol, "error": f"Invalid stat: {stat}. Use: {', '.join(sorted(valid_stats))}"}
|
||||
return {"symbol": symbol, "error": f"Invalid stat. Valid options: {', '.join(sorted(valid_stats))}"}
|
||||
|
||||
fetch_days = max(days, PERF_DAYS)
|
||||
hist = await fetch_historical(symbol, fetch_days)
|
||||
|
||||
@@ -23,10 +23,10 @@ async def macro_cpi(country: str = Query(default="united_states", max_length=50,
|
||||
@router.get("/macro/gdp", response_model=ApiResponse)
|
||||
@safe
|
||||
async def macro_gdp(
|
||||
type: str = Query(default="real", pattern="^(nominal|real|forecast)$"),
|
||||
gdp_type: str = Query(default="real", pattern="^(nominal|real|forecast)$"),
|
||||
):
|
||||
"""GDP: nominal, real, or forecast."""
|
||||
data = await economy_service.get_gdp(gdp_type=type)
|
||||
data = await economy_service.get_gdp(gdp_type=gdp_type)
|
||||
return ApiResponse(data=data)
|
||||
|
||||
|
||||
|
||||
@@ -411,13 +411,12 @@ async def get_relative_rotation(
|
||||
Returns RS-Ratio and RS-Momentum for each symbol, indicating
|
||||
which RRG quadrant they occupy (Leading/Weakening/Lagging/Improving).
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timedelta, timezone as tz
|
||||
|
||||
start = (datetime.now(tz=timezone.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
start = (datetime.now(tz=tz.utc) - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
all_symbols = ",".join(symbols + [benchmark])
|
||||
|
||||
try:
|
||||
# Fetch multi-symbol historical data in one call
|
||||
hist = await asyncio.to_thread(
|
||||
obb.equity.price.historical,
|
||||
all_symbols,
|
||||
@@ -434,26 +433,17 @@ async def get_relative_rotation(
|
||||
study=study,
|
||||
)
|
||||
items = to_list(result)
|
||||
# Return the latest data point per symbol
|
||||
latest_by_symbol: dict[str, dict] = {}
|
||||
|
||||
latest_by_symbol: dict[str, dict[str, Any]] = {}
|
||||
for item in items:
|
||||
sym = item.get("symbol")
|
||||
if sym and sym != benchmark:
|
||||
latest_by_symbol[sym] = item
|
||||
|
||||
entries = list(latest_by_symbol.values())
|
||||
for entry in entries:
|
||||
rs_ratio = entry.get("rs_ratio")
|
||||
rs_momentum = entry.get("rs_momentum")
|
||||
if rs_ratio is not None and rs_momentum is not None:
|
||||
if rs_ratio > 100 and rs_momentum > 100:
|
||||
entry["quadrant"] = "Leading"
|
||||
elif rs_ratio > 100 and rs_momentum <= 100:
|
||||
entry["quadrant"] = "Weakening"
|
||||
elif rs_ratio <= 100 and rs_momentum <= 100:
|
||||
entry["quadrant"] = "Lagging"
|
||||
else:
|
||||
entry["quadrant"] = "Improving"
|
||||
entries = [
|
||||
{**item, "quadrant": _classify_rrg_quadrant(item)}
|
||||
for item in latest_by_symbol.values()
|
||||
]
|
||||
|
||||
return {
|
||||
"symbols": symbols,
|
||||
@@ -466,6 +456,21 @@ async def get_relative_rotation(
|
||||
return {"symbols": symbols, "error": "Failed to compute relative rotation"}
|
||||
|
||||
|
||||
def _classify_rrg_quadrant(item: dict[str, Any]) -> str | None:
|
||||
"""Classify RRG quadrant from RS-Ratio and RS-Momentum."""
|
||||
rs_ratio = item.get("rs_ratio")
|
||||
rs_momentum = item.get("rs_momentum")
|
||||
if rs_ratio is None or rs_momentum is None:
|
||||
return None
|
||||
if rs_ratio > 100 and rs_momentum > 100:
|
||||
return "Leading"
|
||||
if rs_ratio > 100:
|
||||
return "Weakening"
|
||||
if rs_momentum <= 100:
|
||||
return "Lagging"
|
||||
return "Improving"
|
||||
|
||||
|
||||
async def get_cones(symbol: str, days: int = 365) -> dict[str, Any]:
|
||||
"""Volatility Cones -- realized volatility quantiles for options analysis."""
|
||||
hist = await fetch_historical(symbol, days)
|
||||
|
||||
Reference in New Issue
Block a user