feat: add t-SNE stock clustering and similarity search (TDD)
2 new endpoints: - POST /portfolio/cluster - t-SNE + KMeans clustering by return similarity. Maps stocks to 2D coordinates with cluster labels. - POST /portfolio/similar - find most/least similar stocks by return correlation against a target symbol. Implementation: - sklearn TSNE (method=exact) + KMeans with auto n_clusters - Jitter handling for identical returns edge case - 33 new tests (17 service unit + 16 route integration) - All 503 tests passing
This commit is contained in:
@@ -261,3 +261,301 @@ async def test_fetch_historical_prices_skips_none(mock_fetch_hist):
|
||||
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert df.empty
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cluster_stocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_prices(symbols: list[str], n_days: int = 60) -> "pd.DataFrame":
|
||||
"""Build a deterministic price DataFrame with enough rows for t-SNE."""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
data = {}
|
||||
for sym in symbols:
|
||||
prices = 100.0 + np.cumsum(rng.normal(0, 1, n_days))
|
||||
data[sym] = prices
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_happy_path(mock_fetch):
|
||||
"""cluster_stocks returns valid structure for 6 symbols."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.cluster_stocks(symbols, days=180)
|
||||
|
||||
assert result["method"] == "t-SNE + KMeans"
|
||||
assert result["days"] == 180
|
||||
assert set(result["symbols"]) == set(symbols)
|
||||
|
||||
coords = result["coordinates"]
|
||||
assert len(coords) == len(symbols)
|
||||
for c in coords:
|
||||
assert "symbol" in c
|
||||
assert "x" in c
|
||||
assert "y" in c
|
||||
assert "cluster" in c
|
||||
assert isinstance(c["x"], float)
|
||||
assert isinstance(c["y"], float)
|
||||
assert isinstance(c["cluster"], int)
|
||||
|
||||
clusters = result["clusters"]
|
||||
assert isinstance(clusters, dict)
|
||||
all_in_clusters = []
|
||||
for members in clusters.values():
|
||||
all_in_clusters.extend(members)
|
||||
assert set(all_in_clusters) == set(symbols)
|
||||
|
||||
assert "n_clusters" in result
|
||||
assert result["n_clusters"] >= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_custom_n_clusters(mock_fetch):
|
||||
"""Custom n_clusters is respected in the output."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.cluster_stocks(symbols, days=180, n_clusters=3)
|
||||
|
||||
assert result["n_clusters"] == 3
|
||||
assert len(result["clusters"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_minimum_three_symbols(mock_fetch):
|
||||
"""cluster_stocks works correctly with exactly 3 symbols (minimum)."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.cluster_stocks(symbols, days=180)
|
||||
|
||||
assert len(result["coordinates"]) == 3
|
||||
assert set(result["symbols"]) == set(symbols)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cluster_stocks_too_few_symbols_raises():
|
||||
"""cluster_stocks raises ValueError when fewer than 3 symbols are provided."""
|
||||
import portfolio_service
|
||||
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
await portfolio_service.cluster_stocks(["AAPL", "MSFT"], days=180)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cluster_stocks_empty_symbols_raises():
|
||||
"""cluster_stocks raises ValueError for empty symbol list."""
|
||||
import portfolio_service
|
||||
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
await portfolio_service.cluster_stocks([], days=180)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_no_data_raises(mock_fetch):
|
||||
"""cluster_stocks raises ValueError when fetch returns empty DataFrame."""
|
||||
import pandas as pd
|
||||
import portfolio_service
|
||||
|
||||
mock_fetch.return_value = pd.DataFrame()
|
||||
|
||||
with pytest.raises(ValueError, match="No price data"):
|
||||
await portfolio_service.cluster_stocks(["AAPL", "MSFT", "GOOGL"], days=180)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_identical_returns_still_works(mock_fetch):
|
||||
"""t-SNE should not raise even when all symbols have identical returns."""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import portfolio_service
|
||||
|
||||
# All columns identical — edge case for t-SNE
|
||||
flat = pd.DataFrame(
|
||||
{
|
||||
"AAPL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
|
||||
"MSFT": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
|
||||
"GOOGL": [100.0, 101.0, 102.0, 103.0, 104.0] * 12,
|
||||
}
|
||||
)
|
||||
|
||||
mock_fetch.return_value = flat
|
||||
|
||||
result = await portfolio_service.cluster_stocks(
|
||||
["AAPL", "MSFT", "GOOGL"], days=180
|
||||
)
|
||||
|
||||
assert len(result["coordinates"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_coordinates_are_floats(mock_fetch):
|
||||
"""x and y coordinates must be Python floats (JSON-serializable)."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.cluster_stocks(symbols, days=180)
|
||||
|
||||
for c in result["coordinates"]:
|
||||
assert type(c["x"]) is float
|
||||
assert type(c["y"]) is float
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_cluster_stocks_clusters_key_is_str(mock_fetch):
|
||||
"""clusters dict keys must be strings (JSON object keys)."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.cluster_stocks(symbols, days=180)
|
||||
|
||||
for key in result["clusters"]:
|
||||
assert isinstance(key, str), f"Expected str key, got {type(key)}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_similar_stocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_happy_path(mock_fetch):
|
||||
"""most_similar is sorted descending by correlation; least_similar ascending."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.find_similar_stocks(
|
||||
"AAPL", ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"], days=180, top_n=3
|
||||
)
|
||||
|
||||
assert result["symbol"] == "AAPL"
|
||||
most = result["most_similar"]
|
||||
least = result["least_similar"]
|
||||
|
||||
assert len(most) <= 3
|
||||
assert len(least) <= 3
|
||||
|
||||
# most_similar sorted descending
|
||||
corrs_most = [e["correlation"] for e in most]
|
||||
assert corrs_most == sorted(corrs_most, reverse=True)
|
||||
|
||||
# least_similar sorted ascending
|
||||
corrs_least = [e["correlation"] for e in least]
|
||||
assert corrs_least == sorted(corrs_least)
|
||||
|
||||
# Each entry has symbol and correlation
|
||||
for entry in most + least:
|
||||
assert "symbol" in entry
|
||||
assert "correlation" in entry
|
||||
assert isinstance(entry["correlation"], float)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_top_n_larger_than_universe(mock_fetch):
|
||||
"""top_n larger than universe size is handled gracefully (returns all)."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.find_similar_stocks(
|
||||
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=10
|
||||
)
|
||||
|
||||
# Should return at most len(universe) entries, not crash
|
||||
assert len(result["most_similar"]) <= 2
|
||||
assert len(result["least_similar"]) <= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_no_overlap_with_most_and_least(mock_fetch):
|
||||
"""most_similar and least_similar should not contain the target symbol."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.find_similar_stocks(
|
||||
"AAPL", ["MSFT", "GOOGL", "AMZN", "JPM"], days=180, top_n=2
|
||||
)
|
||||
|
||||
all_symbols = [e["symbol"] for e in result["most_similar"] + result["least_similar"]]
|
||||
assert "AAPL" not in all_symbols
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_no_data_raises(mock_fetch):
|
||||
"""find_similar_stocks raises ValueError when no price data is returned."""
|
||||
import pandas as pd
|
||||
import portfolio_service
|
||||
|
||||
mock_fetch.return_value = pd.DataFrame()
|
||||
|
||||
with pytest.raises(ValueError, match="No price data"):
|
||||
await portfolio_service.find_similar_stocks(
|
||||
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_target_not_in_data_raises(mock_fetch):
|
||||
"""find_similar_stocks raises ValueError when target symbol has no data."""
|
||||
import pandas as pd
|
||||
import portfolio_service
|
||||
|
||||
# Only universe symbols have data, not the target
|
||||
mock_fetch.return_value = _make_prices(["MSFT", "GOOGL"])
|
||||
|
||||
with pytest.raises(ValueError, match="AAPL"):
|
||||
await portfolio_service.find_similar_stocks(
|
||||
"AAPL", ["MSFT", "GOOGL"], days=180, top_n=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
|
||||
async def test_find_similar_stocks_default_top_n(mock_fetch):
|
||||
"""Default top_n=5 returns at most 5 entries in most_similar."""
|
||||
import portfolio_service
|
||||
|
||||
symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"]
|
||||
mock_fetch.return_value = _make_prices(symbols)
|
||||
|
||||
result = await portfolio_service.find_similar_stocks(
|
||||
"AAPL",
|
||||
["MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"],
|
||||
days=180,
|
||||
)
|
||||
|
||||
assert len(result["most_similar"]) <= 5
|
||||
assert len(result["least_similar"]) <= 5
|
||||
|
||||
Reference in New Issue
Block a user