feat: add t-SNE stock clustering and similarity search (TDD)

2 new endpoints:
- POST /portfolio/cluster - t-SNE + KMeans clustering by return
  similarity. Maps stocks to 2D coordinates with cluster labels.
- POST /portfolio/similar - find most/least similar stocks by
  return correlation against a target symbol.

Implementation:
- sklearn TSNE (method=exact) + KMeans with auto n_clusters
- Jitter handling for identical returns edge case
- 33 new tests (17 service unit + 16 route integration)
- All 503 tests passing
This commit is contained in:
Yaojia Wang
2026-03-19 22:53:27 +01:00
parent 9ee3ec9b4e
commit 4915f1bae4
4 changed files with 759 additions and 1 deletions

View File

@@ -261,3 +261,301 @@ async def test_fetch_historical_prices_skips_none(mock_fetch_hist):
assert isinstance(df, pd.DataFrame)
assert df.empty
# ---------------------------------------------------------------------------
# cluster_stocks
# ---------------------------------------------------------------------------
def _make_prices(symbols: list[str], n_days: int = 60) -> "pd.DataFrame":
    """Return a reproducible frame of synthetic prices, one column per symbol.

    Each column is a Gaussian random walk (unit-variance steps) offset by
    100.0. The generator is seeded with 42, so repeated calls with the same
    arguments produce identical data.

    Args:
        symbols: Column names to generate, in order.
        n_days: Number of rows (price observations) per column.
    """
    import numpy as np
    import pandas as pd

    generator = np.random.default_rng(42)
    # The comprehension is evaluated in iteration order, so every symbol
    # consumes the same slice of the random stream as a sequential loop would.
    walks = {
        ticker: 100.0 + np.cumsum(generator.normal(0, 1, n_days))
        for ticker in symbols
    }
    return pd.DataFrame(walks)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_happy_path(mock_fetch):
    """Six symbols yield a complete, well-typed clustering payload."""
    import portfolio_service

    tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(tickers)

    outcome = await portfolio_service.cluster_stocks(tickers, days=180)

    # Top-level metadata echoes the request.
    assert outcome["method"] == "t-SNE + KMeans"
    assert outcome["days"] == 180
    assert set(outcome["symbols"]) == set(tickers)

    # Exactly one coordinate record per symbol, each with the expected fields.
    points = outcome["coordinates"]
    assert len(points) == len(tickers)
    for point in points:
        for field in ("symbol", "x", "y", "cluster"):
            assert field in point
        assert isinstance(point["x"], float)
        assert isinstance(point["y"], float)
        assert isinstance(point["cluster"], int)

    # The cluster membership lists together cover every requested symbol.
    grouping = outcome["clusters"]
    assert isinstance(grouping, dict)
    grouped_symbols = [sym for group in grouping.values() for sym in group]
    assert set(grouped_symbols) == set(tickers)

    assert "n_clusters" in outcome
    assert outcome["n_clusters"] >= 2
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_custom_n_clusters(mock_fetch):
    """An explicit n_clusters shows up in both the count and the grouping."""
    import portfolio_service

    tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(tickers)

    outcome = await portfolio_service.cluster_stocks(
        tickers, days=180, n_clusters=3
    )

    assert outcome["n_clusters"] == 3
    assert len(outcome["clusters"]) == 3
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_minimum_three_symbols(mock_fetch):
    """Exactly three symbols (the smallest accepted input) still clusters."""
    import portfolio_service

    tickers = ["AAPL", "MSFT", "GOOGL"]
    mock_fetch.return_value = _make_prices(tickers)

    outcome = await portfolio_service.cluster_stocks(tickers, days=180)

    assert len(outcome["coordinates"]) == 3
    assert set(outcome["symbols"]) == set(tickers)
@pytest.mark.asyncio
async def test_cluster_stocks_too_few_symbols_raises():
    """Two symbols are rejected with a ValueError mentioning 'at least 3'."""
    import portfolio_service

    too_few = ["AAPL", "MSFT"]
    with pytest.raises(ValueError, match="at least 3"):
        await portfolio_service.cluster_stocks(too_few, days=180)
@pytest.mark.asyncio
async def test_cluster_stocks_empty_symbols_raises():
    """An empty symbol list is rejected with the same 'at least 3' error."""
    import portfolio_service

    nothing = []
    with pytest.raises(ValueError, match="at least 3"):
        await portfolio_service.cluster_stocks(nothing, days=180)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_no_data_raises(mock_fetch):
    """An empty price frame from the fetcher surfaces as a ValueError."""
    import pandas as pd
    import portfolio_service

    empty_frame = pd.DataFrame()
    mock_fetch.return_value = empty_frame

    tickers = ["AAPL", "MSFT", "GOOGL"]
    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.cluster_stocks(tickers, days=180)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_identical_returns_still_works(mock_fetch):
    """cluster_stocks must not raise when every symbol has the same prices.

    Identical price columns give identical return series, which is a
    degenerate input for t-SNE; the call is only required to complete and
    return one coordinate per symbol.
    """
    import pandas as pd
    import portfolio_service

    symbols = ["AAPL", "MSFT", "GOOGL"]
    # One 60-row series shared by every column, so the fixture cannot drift
    # out of sync between symbols.
    series = [100.0, 101.0, 102.0, 103.0, 104.0] * 12
    mock_fetch.return_value = pd.DataFrame({sym: series for sym in symbols})

    result = await portfolio_service.cluster_stocks(symbols, days=180)

    assert len(result["coordinates"]) == 3
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_coordinates_are_floats(mock_fetch):
    """Both coordinates of every point are exactly the built-in float type."""
    import portfolio_service

    tickers = ["AAPL", "MSFT", "GOOGL", "AMZN"]
    mock_fetch.return_value = _make_prices(tickers)

    outcome = await portfolio_service.cluster_stocks(tickers, days=180)

    for point in outcome["coordinates"]:
        # Exact type identity on purpose: float subclasses must not pass.
        for axis in ("x", "y"):
            assert type(point[axis]) is float
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_cluster_stocks_clusters_key_is_str(mock_fetch):
    """Every key of the clusters mapping is a str, as JSON objects require."""
    import portfolio_service

    tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices(tickers)

    outcome = await portfolio_service.cluster_stocks(tickers, days=180)

    grouping = outcome["clusters"]
    for key in grouping:
        assert isinstance(key, str), f"Expected str key, got {type(key)}"
# ---------------------------------------------------------------------------
# find_similar_stocks
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_happy_path(mock_fetch):
    """Result lists are bounded by top_n, correctly ordered and well-formed."""
    import portfolio_service

    target = "AAPL"
    universe = ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"]
    mock_fetch.return_value = _make_prices([target, *universe])

    outcome = await portfolio_service.find_similar_stocks(
        target, universe, days=180, top_n=3
    )

    assert outcome["symbol"] == "AAPL"
    closest = outcome["most_similar"]
    farthest = outcome["least_similar"]
    assert len(closest) <= 3
    assert len(farthest) <= 3

    # most_similar runs from highest correlation to lowest.
    closest_corrs = [row["correlation"] for row in closest]
    assert closest_corrs == sorted(closest_corrs, reverse=True)

    # least_similar runs from lowest correlation to highest.
    farthest_corrs = [row["correlation"] for row in farthest]
    assert farthest_corrs == sorted(farthest_corrs)

    # Every row carries a symbol and a float correlation.
    for row in closest + farthest:
        assert "symbol" in row
        assert "correlation" in row
        assert isinstance(row["correlation"], float)
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_top_n_larger_than_universe(mock_fetch):
    """A top_n exceeding the universe size does not crash or over-return."""
    import portfolio_service

    target = "AAPL"
    universe = ["MSFT", "GOOGL"]
    mock_fetch.return_value = _make_prices([target, *universe])

    outcome = await portfolio_service.find_similar_stocks(
        target, universe, days=180, top_n=10
    )

    # Neither list may contain more rows than there are universe symbols.
    for bucket in ("most_similar", "least_similar"):
        assert len(outcome[bucket]) <= 2
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_overlap_with_most_and_least(mock_fetch):
    """The target symbol never appears in either similarity list."""
    import portfolio_service

    target = "AAPL"
    universe = ["MSFT", "GOOGL", "AMZN", "JPM"]
    mock_fetch.return_value = _make_prices([target, *universe])

    outcome = await portfolio_service.find_similar_stocks(
        target, universe, days=180, top_n=2
    )

    rows = outcome["most_similar"] + outcome["least_similar"]
    listed = [row["symbol"] for row in rows]
    assert "AAPL" not in listed
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_no_data_raises(mock_fetch):
    """An empty price frame from the fetcher surfaces as a ValueError."""
    import pandas as pd
    import portfolio_service

    empty_frame = pd.DataFrame()
    mock_fetch.return_value = empty_frame

    universe = ["MSFT", "GOOGL"]
    with pytest.raises(ValueError, match="No price data"):
        await portfolio_service.find_similar_stocks(
            "AAPL", universe, days=180, top_n=5
        )
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_target_not_in_data_raises(mock_fetch):
    """find_similar_stocks raises ValueError when the target has no prices.

    The mocked fetch returns columns for the universe only, so the target
    symbol is absent; the error message is expected to name it.
    """
    import portfolio_service

    universe = ["MSFT", "GOOGL"]
    # Only universe symbols have data, not the target.
    mock_fetch.return_value = _make_prices(universe)

    with pytest.raises(ValueError, match="AAPL"):
        await portfolio_service.find_similar_stocks(
            "AAPL", universe, days=180, top_n=5
        )
@pytest.mark.asyncio
@patch("portfolio_service.fetch_historical_prices", new_callable=AsyncMock)
async def test_find_similar_stocks_default_top_n(mock_fetch):
    """With top_n omitted, each list holds no more than five rows."""
    import portfolio_service

    target = "AAPL"
    universe = ["MSFT", "GOOGL", "AMZN", "JPM", "BAC", "WFC", "GS"]
    mock_fetch.return_value = _make_prices([target, *universe])

    outcome = await portfolio_service.find_similar_stocks(
        target, universe, days=180
    )

    for bucket in ("most_similar", "least_similar"):
        assert len(outcome[bucket]) <= 5

View File

@@ -223,3 +223,271 @@ async def test_portfolio_risk_parity_default_days(mock_fn, client):
)
assert resp.status_code == 200
mock_fn.assert_called_once_with(["AAPL"], days=365)
# ---------------------------------------------------------------------------
# POST /api/v1/portfolio/cluster
# ---------------------------------------------------------------------------
# Canned service payload returned by the mocked cluster_stocks in the route
# tests below: four symbols in cluster 0 and two in cluster 1.
_CLUSTER_RESULT = {
    "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
    "coordinates": [
        {"symbol": sym, "x": x, "y": y, "cluster": label}
        for sym, x, y, label in (
            ("AAPL", 12.5, -3.2, 0),
            ("MSFT", 11.8, -2.9, 0),
            ("GOOGL", 10.1, -1.5, 0),
            ("AMZN", 9.5, -0.8, 0),
            ("JPM", -5.1, 8.3, 1),
            ("BAC", -4.9, 7.9, 1),
        )
    ],
    "clusters": {
        "0": ["AAPL", "MSFT", "GOOGL", "AMZN"],
        "1": ["JPM", "BAC"],
    },
    "method": "t-SNE + KMeans",
    "n_clusters": 2,
    "days": 180,
}
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_happy_path(mock_fn, client):
    """A valid /cluster request returns 200 and wraps the service result."""
    mock_fn.return_value = _CLUSTER_RESULT
    payload = {
        "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        "days": 180,
    }

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    result = body["data"]
    assert result["method"] == "t-SNE + KMeans"
    assert "coordinates" in result
    assert "clusters" in result
    # n_clusters was omitted from the request, so the route forwards None.
    mock_fn.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        days=180,
        n_clusters=None,
    )
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_with_custom_n_clusters(mock_fn, client):
    """A request-supplied n_clusters is passed through to the service."""
    mock_fn.return_value = _CLUSTER_RESULT
    payload = {
        "symbols": ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        "days": 180,
        "n_clusters": 3,
    }

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 200
    mock_fn.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        days=180,
        n_clusters=3,
    )
@pytest.mark.asyncio
async def test_portfolio_cluster_too_few_symbols_returns_422(client):
    """Two symbols fail request validation with a 422."""
    payload = {"symbols": ["AAPL", "MSFT"], "days": 180}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_missing_symbols_returns_422(client):
    """Omitting the symbols field fails request validation with a 422."""
    payload = {"days": 180}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_too_many_symbols_returns_422(client):
    """Fifty-one symbols (one over the limit of 50) are rejected with a 422."""
    oversized = [f"SYM{i}" for i in range(51)]
    payload = {"symbols": oversized, "days": 180}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_days_below_minimum_returns_422(client):
    """A days value of 10 (below the minimum of 30) is rejected with a 422."""
    payload = {"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 10}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_cluster_n_clusters_below_minimum_returns_422(client):
    """An n_clusters of 1 (below the minimum of 2) is rejected with a 422."""
    payload = {
        "symbols": ["AAPL", "MSFT", "GOOGL"],
        "days": 180,
        "n_clusters": 1,
    }

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_value_error_returns_400(mock_fn, client):
    """A ValueError raised by the service is reported as HTTP 400."""
    mock_fn.side_effect = ValueError("at least 3 symbols required")
    payload = {"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_upstream_error_returns_502(mock_fn, client):
    """A non-ValueError failure in the service is reported as HTTP 502."""
    mock_fn.side_effect = RuntimeError("upstream failure")
    payload = {"symbols": ["AAPL", "MSFT", "GOOGL"], "days": 180}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.cluster_stocks", new_callable=AsyncMock)
async def test_portfolio_cluster_default_days(mock_fn, client):
    """When days is omitted the service is called with days=180."""
    mock_fn.return_value = _CLUSTER_RESULT
    payload = {"symbols": ["AAPL", "MSFT", "GOOGL"]}

    resp = await client.post("/api/v1/portfolio/cluster", json=payload)

    assert resp.status_code == 200
    mock_fn.assert_called_once_with(
        ["AAPL", "MSFT", "GOOGL"],
        days=180,
        n_clusters=None,
    )
# ---------------------------------------------------------------------------
# POST /api/v1/portfolio/similar
# ---------------------------------------------------------------------------
# Canned service payload returned by the mocked find_similar_stocks in the
# route tests below.
_SIMILAR_RESULT = {
    "symbol": "AAPL",
    "most_similar": [
        {"symbol": sym, "correlation": corr}
        for sym, corr in (("MSFT", 0.85), ("GOOGL", 0.78))
    ],
    "least_similar": [
        {"symbol": sym, "correlation": corr}
        for sym, corr in (("JPM", 0.32), ("BAC", 0.28))
    ],
}
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_happy_path(mock_fn, client):
    """A valid /similar request returns 200 and wraps the service result."""
    mock_fn.return_value = _SIMILAR_RESULT
    payload = {
        "symbol": "AAPL",
        "universe": ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        "days": 180,
        "top_n": 2,
    }

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 200
    body = resp.json()
    assert body["success"] is True
    result = body["data"]
    assert result["symbol"] == "AAPL"
    assert "most_similar" in result
    assert "least_similar" in result
    # Symbol and universe are forwarded positionally, the rest by keyword.
    mock_fn.assert_called_once_with(
        "AAPL",
        ["MSFT", "GOOGL", "AMZN", "JPM", "BAC"],
        days=180,
        top_n=2,
    )
@pytest.mark.asyncio
async def test_portfolio_similar_missing_symbol_returns_422(client):
    """Omitting the symbol field fails request validation with a 422."""
    payload = {"universe": ["MSFT", "GOOGL"], "days": 180}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_missing_universe_returns_422(client):
    """Omitting the universe field fails request validation with a 422."""
    payload = {"symbol": "AAPL", "days": 180}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_universe_too_small_returns_422(client):
    """A single-entry universe (fewer than 2) is rejected with a 422."""
    payload = {"symbol": "AAPL", "universe": ["MSFT"], "days": 180}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
async def test_portfolio_similar_top_n_below_minimum_returns_422(client):
    """A top_n of 0 (below the minimum of 1) is rejected with a 422."""
    payload = {
        "symbol": "AAPL",
        "universe": ["MSFT", "GOOGL"],
        "days": 180,
        "top_n": 0,
    }

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 422
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_value_error_returns_400(mock_fn, client):
    """A ValueError raised by the service is reported as HTTP 400."""
    mock_fn.side_effect = ValueError("AAPL not found in price data")
    payload = {"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 400
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_upstream_error_returns_502(mock_fn, client):
    """A non-ValueError failure in the service is reported as HTTP 502."""
    mock_fn.side_effect = RuntimeError("upstream failure")
    payload = {"symbol": "AAPL", "universe": ["MSFT", "GOOGL"], "days": 180}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 502
@pytest.mark.asyncio
@patch("routes_portfolio.portfolio_service.find_similar_stocks", new_callable=AsyncMock)
async def test_portfolio_similar_default_top_n(mock_fn, client):
    """When days and top_n are omitted the service gets days=180, top_n=5."""
    mock_fn.return_value = _SIMILAR_RESULT
    payload = {"symbol": "AAPL", "universe": ["MSFT", "GOOGL", "AMZN"]}

    resp = await client.post("/api/v1/portfolio/similar", json=payload)

    assert resp.status_code == 200
    mock_fn.assert_called_once_with(
        "AAPL",
        ["MSFT", "GOOGL", "AMZN"],
        days=180,
        top_n=5,
    )