diff --git a/akshare_service.py b/akshare_service.py new file mode 100644 index 0000000..c4c97b2 --- /dev/null +++ b/akshare_service.py @@ -0,0 +1,185 @@ +"""A-share and HK stock data service using AKShare.""" + +import asyncio +import logging +import re +from datetime import datetime, timedelta +from typing import Any + +import akshare as ak +import pandas as pd + +logger = logging.getLogger(__name__) + +# --- Symbol validation patterns --- + +_A_SHARE_PATTERN = re.compile(r"^[036]\d{5}$") +_HK_PATTERN = re.compile(r"^\d{5}$") + +# --- Chinese column name mappings --- + +_HIST_COLUMNS: dict[str, str] = { + "日期": "date", + "开盘": "open", + "收盘": "close", + "最高": "high", + "最低": "low", + "成交量": "volume", + "成交额": "turnover", + "振幅": "amplitude", + "涨跌幅": "change_percent", + "涨跌额": "change", + "换手率": "turnover_rate", +} + +_QUOTE_COLUMNS: dict[str, str] = { + "代码": "symbol", + "名称": "name", + "最新价": "price", + "涨跌幅": "change_percent", + "涨跌额": "change", + "成交量": "volume", + "成交额": "turnover", + "今开": "open", + "最高": "high", + "最低": "low", + "昨收": "prev_close", +} + + +# --- Validation helpers --- + + +def validate_a_share_symbol(symbol: str) -> bool: + """Return True if symbol matches A-share format (6 digits, starts with 0, 3, or 6).""" + return bool(_A_SHARE_PATTERN.match(symbol)) + + +def validate_hk_symbol(symbol: str) -> bool: + """Return True if symbol matches HK stock format (exactly 5 digits).""" + return bool(_HK_PATTERN.match(symbol)) + + +# --- DataFrame parsers --- + + +def _parse_hist_df(df: pd.DataFrame) -> list[dict[str, Any]]: + """Convert a Chinese-column historical DataFrame to a list of English-key dicts.""" + if df.empty: + return [] + df = df.rename(columns=_HIST_COLUMNS) + # Serialize date column to ISO string + if "date" in df.columns: + df["date"] = df["date"].astype(str) + return df.to_dict(orient="records") + + +def _parse_spot_row(df: pd.DataFrame, symbol: str) -> dict[str, Any] | None: + """ + Filter a spot quote DataFrame by symbol code column and return + an English-key dict for the matching row, or None if not found. + """ + if df.empty: + return None + code_col = "代码" + if code_col not in df.columns: + return None + matched = df[df[code_col] == symbol] + if matched.empty: + return None + row = matched.iloc[0] + result: dict[str, Any] = {} + for cn_key, en_key in _QUOTE_COLUMNS.items(): + result[en_key] = row.get(cn_key) + return result + + +# --- Date helpers --- + + +def _date_range(days: int) -> tuple[str, str]: + """Return (start_date, end_date) strings in YYYYMMDD format for the given window.""" + end = datetime.now() + start = end - timedelta(days=days) + return start.strftime("%Y%m%d"), end.strftime("%Y%m%d") + + +# --- A-share service functions --- + + +async def get_a_share_quote(symbol: str) -> dict[str, Any] | None: + """ + Fetch real-time A-share quote for a single symbol. + + Returns a dict with English keys, or None if the symbol is not found + in the spot market data. Propagates AKShare exceptions to the caller. + """ + df: pd.DataFrame = await asyncio.to_thread(ak.stock_zh_a_spot_em) + return _parse_spot_row(df, symbol) + + +async def get_a_share_historical( + symbol: str, *, days: int = 365 +) -> list[dict[str, Any]]: + """ + Fetch daily OHLCV history for an A-share symbol with qfq (前复权) adjustment. + + Propagates AKShare exceptions to the caller. + """ + start_date, end_date = _date_range(days) + df: pd.DataFrame = await asyncio.to_thread( + ak.stock_zh_a_hist, + symbol=symbol, + period="daily", + start_date=start_date, + end_date=end_date, + adjust="qfq", + ) + return _parse_hist_df(df) + + +async def search_a_shares(query: str) -> list[dict[str, Any]]: + """ + Search A-share stocks by name substring. + + Returns a list of {code, name} dicts. An empty query returns all stocks. + Propagates AKShare exceptions to the caller. + """ + df: pd.DataFrame = await asyncio.to_thread(ak.stock_info_a_code_name) + if query: + df = df[df["name"].str.contains(query, na=False)] + return df[["code", "name"]].to_dict(orient="records") + + +# --- HK stock service functions --- + + +async def get_hk_quote(symbol: str) -> dict[str, Any] | None: + """ + Fetch real-time HK stock quote for a single symbol. + + Returns a dict with English keys, or None if the symbol is not found. + Propagates AKShare exceptions to the caller. + """ + df: pd.DataFrame = await asyncio.to_thread(ak.stock_hk_spot_em) + return _parse_spot_row(df, symbol) + + +async def get_hk_historical( + symbol: str, *, days: int = 365 +) -> list[dict[str, Any]]: + """ + Fetch daily OHLCV history for a HK stock symbol with qfq adjustment. + + Propagates AKShare exceptions to the caller. + """ + start_date, end_date = _date_range(days) + df: pd.DataFrame = await asyncio.to_thread( + ak.stock_hk_hist, + symbol=symbol, + period="daily", + start_date=start_date, + end_date=end_date, + adjust="qfq", + ) + return _parse_hist_df(df) diff --git a/main.py b/main.py index 296ddc4..9250db2 100644 --- a/main.py +++ b/main.py @@ -39,6 +39,7 @@ from routes_surveys import router as surveys_router # noqa: E402 from routes_technical import router as technical_router # noqa: E402 from routes_portfolio import router as portfolio_router # noqa: E402 from routes_backtest import router as backtest_router # noqa: E402 +from routes_cn import router as cn_router # noqa: E402 logging.basicConfig( level=settings.log_level.upper(), @@ -85,6 +86,7 @@ app.include_router(surveys_router) app.include_router(regulators_router) app.include_router(portfolio_router) app.include_router(backtest_router) +app.include_router(cn_router) @app.get("/health", response_model=dict[str, str]) diff --git a/pyproject.toml b/pyproject.toml index b90e012..1ea34c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "pydantic-settings", "httpx", "curl_cffi==0.7.4", + "akshare", ] [project.optional-dependencies] diff --git a/routes_cn.py b/routes_cn.py new file mode 100644 index 0000000..d2126db --- /dev/null +++ b/routes_cn.py @@ -0,0 +1,105 @@ +"""Routes for A-share (China) and Hong Kong stock market data via AKShare.""" + +from fastapi import APIRouter, HTTPException, Path, Query + +from models import ApiResponse +from route_utils import safe +import akshare_service + +router = APIRouter(prefix="/api/v1/cn", tags=["China & HK Markets"]) + + +# --- Validation helpers --- + + +def _validate_a_share(symbol: str) -> str: + """Validate A-share symbol format; raise 400 on failure.""" + if not akshare_service.validate_a_share_symbol(symbol): + raise HTTPException( + status_code=400, + detail=( + f"Invalid A-share symbol '{symbol}'. " + "Must be 6 digits starting with 0, 3, or 6 (e.g. 000001, 300001, 600519)." + ), + ) + return symbol + + +def _validate_hk(symbol: str) -> str: + """Validate HK stock symbol format; raise 400 on failure.""" + if not akshare_service.validate_hk_symbol(symbol): + raise HTTPException( + status_code=400, + detail=( + f"Invalid HK symbol '{symbol}'. " + "Must be exactly 5 digits (e.g. 00700, 09988)." + ), + ) + return symbol + + +# --- A-share routes --- +# NOTE: /a-share/search MUST be registered before /a-share/{symbol} to avoid shadowing. + + +@router.get("/a-share/search", response_model=ApiResponse) +@safe +async def a_share_search( + query: str = Query(..., description="Stock name to search for (partial match)"), +) -> ApiResponse: + """Search A-share stocks by name (partial match).""" + data = await akshare_service.search_a_shares(query) + return ApiResponse(data=data) + + +@router.get("/a-share/{symbol}/quote", response_model=ApiResponse) +@safe +async def a_share_quote( + symbol: str = Path(..., min_length=6, max_length=6), +) -> ApiResponse: + """Get real-time A-share quote (沪深 real-time price).""" + symbol = _validate_a_share(symbol) + data = await akshare_service.get_a_share_quote(symbol) + if data is None: + raise HTTPException(status_code=404, detail=f"A-share symbol '{symbol}' not found.") + return ApiResponse(data=data) + + +@router.get("/a-share/{symbol}/historical", response_model=ApiResponse) +@safe +async def a_share_historical( + symbol: str = Path(..., min_length=6, max_length=6), + days: int = Query(default=365, ge=1, le=3650), +) -> ApiResponse: + """Get A-share daily OHLCV history with qfq (前复权) adjustment.""" + symbol = _validate_a_share(symbol) + data = await akshare_service.get_a_share_historical(symbol, days=days) + return ApiResponse(data=data) + + +# --- HK stock routes --- + + +@router.get("/hk/{symbol}/quote", response_model=ApiResponse) +@safe +async def hk_quote( + symbol: str = Path(..., min_length=5, max_length=5), +) -> ApiResponse: + """Get real-time HK stock quote (港股 real-time price).""" + symbol = _validate_hk(symbol) + data = await akshare_service.get_hk_quote(symbol) + if data is None: + raise HTTPException(status_code=404, detail=f"HK symbol '{symbol}' not found.") + return ApiResponse(data=data) + + +@router.get("/hk/{symbol}/historical", response_model=ApiResponse) +@safe +async def hk_historical( + symbol: str = Path(..., min_length=5, max_length=5), + days: int = Query(default=365, ge=1, le=3650), +) -> ApiResponse: + """Get HK stock daily OHLCV history with qfq adjustment.""" + symbol = _validate_hk(symbol) + data = await akshare_service.get_hk_historical(symbol, days=days) + return ApiResponse(data=data) diff --git a/tests/test_akshare_service.py b/tests/test_akshare_service.py new file mode 100644 index 0000000..51884da --- /dev/null +++ b/tests/test_akshare_service.py @@ -0,0 +1,443 @@ +"""Unit tests for akshare_service.py - written FIRST (TDD RED phase).""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +import akshare_service + + +# --- Helpers --- + + +def _make_hist_df(rows: int = 3) -> pd.DataFrame: + """Return a minimal historical DataFrame with Chinese column names.""" + dates = pd.date_range("2026-01-01", periods=rows, freq="D") + return pd.DataFrame( + { + "日期": dates, + "开盘": [10.0] * rows, + "收盘": [10.5] * rows, + "最高": [11.0] * rows, + "最低": [9.5] * rows, + "成交量": [1_000_000] * rows, + "成交额": [10_500_000.0] * rows, + "振幅": [1.5] * rows, + "涨跌幅": [0.5] * rows, + "涨跌额": [0.05] * rows, + "换手率": [0.3] * rows, + } + ) + + +def _make_spot_df(code: str = "000001", name: str = "平安银行") -> pd.DataFrame: + """Return a minimal real-time quote DataFrame with Chinese column names.""" + return pd.DataFrame( + [ + { + "代码": code, + "名称": name, + "最新价": 12.34, + "涨跌幅": 1.23, + "涨跌额": 0.15, + "成交量": 500_000, + "成交额": 6_170_000.0, + "今开": 12.10, + "最高": 12.50, + "最低": 12.00, + "昨收": 12.19, + } + ] + ) + + +def _make_code_name_df() -> pd.DataFrame: + """Return a minimal code/name mapping DataFrame.""" + return pd.DataFrame( + [ + {"code": "000001", "name": "平安银行"}, + {"code": "600519", "name": "贵州茅台"}, + {"code": "000002", "name": "万科A"}, + ] + ) + + +# ============================================================ +# Symbol validation +# ============================================================ + + +class TestValidateAShare: + def test_valid_starts_with_0(self): + assert akshare_service.validate_a_share_symbol("000001") is True + + def test_valid_starts_with_3(self): + assert akshare_service.validate_a_share_symbol("300001") is True + + def test_valid_starts_with_6(self): + assert akshare_service.validate_a_share_symbol("600519") is True + + def test_invalid_starts_with_1(self): + assert akshare_service.validate_a_share_symbol("100001") is False + + def test_invalid_too_short(self): + assert akshare_service.validate_a_share_symbol("00001") is False + + def test_invalid_too_long(self): + assert akshare_service.validate_a_share_symbol("0000011") is False + + def test_invalid_letters(self): + assert akshare_service.validate_a_share_symbol("00000A") is False + + def test_invalid_empty(self): + assert akshare_service.validate_a_share_symbol("") is False + + +class TestValidateHKSymbol: + def test_valid_five_digits(self): + assert akshare_service.validate_hk_symbol("00700") is True + + def test_valid_all_nines(self): + assert akshare_service.validate_hk_symbol("99999") is True + + def test_invalid_too_short(self): + assert akshare_service.validate_hk_symbol("0070") is False + + def test_invalid_too_long(self): + assert akshare_service.validate_hk_symbol("007000") is False + + def test_invalid_letters(self): + assert akshare_service.validate_hk_symbol("0070A") is False + + def test_invalid_empty(self): + assert akshare_service.validate_hk_symbol("") is False + + +# ============================================================ +# _parse_hist_df +# ============================================================ + + +class TestParseHistDf: + def test_returns_list_of_dicts(self): + df = _make_hist_df(2) + result = akshare_service._parse_hist_df(df) + assert isinstance(result, list) + assert len(result) == 2 + + def test_keys_are_english(self): + df = _make_hist_df(1) + result = akshare_service._parse_hist_df(df) + row = result[0] + assert "date" in row + assert "open" in row + assert "close" in row + assert "high" in row + assert "low" in row + assert "volume" in row + assert "turnover" in row + assert "change_percent" in row + assert "turnover_rate" in row + + def test_no_chinese_keys_remain(self): + df = _make_hist_df(1) + result = akshare_service._parse_hist_df(df) + row = result[0] + for key in row: + assert not any(ord(c) > 127 for c in key), f"Non-ASCII key found: {key}" + + def test_date_is_string(self): + df = _make_hist_df(1) + result = akshare_service._parse_hist_df(df) + assert isinstance(result[0]["date"], str) + + def test_values_are_correct(self): + df = _make_hist_df(1) + result = akshare_service._parse_hist_df(df) + assert result[0]["open"] == pytest.approx(10.0) + assert result[0]["close"] == pytest.approx(10.5) + + def test_empty_df_returns_empty_list(self): + df = pd.DataFrame() + result = akshare_service._parse_hist_df(df) + assert result == [] + + +# ============================================================ +# _parse_spot_row +# ============================================================ + + +class TestParseSpotRow: + def test_returns_dict_with_english_keys(self): + df = _make_spot_df("000001", "平安银行") + result = akshare_service._parse_spot_row(df, "000001") + assert result is not None + assert "symbol" in result + assert "name" in result + assert "price" in result + + def test_correct_symbol_extracted(self): + df = _make_spot_df("000001") + result = akshare_service._parse_spot_row(df, "000001") + assert result["symbol"] == "000001" + + def test_returns_none_when_symbol_not_found(self): + df = _make_spot_df("000001") + result = akshare_service._parse_spot_row(df, "999999") + assert result is None + + def test_price_value_correct(self): + df = _make_spot_df("600519") + df["代码"] = "600519" + result = akshare_service._parse_spot_row(df, "600519") + assert result["price"] == pytest.approx(12.34) + + def test_all_quote_fields_present(self): + df = _make_spot_df("000001") + result = akshare_service._parse_spot_row(df, "000001") + expected_keys = { + "symbol", "name", "price", "change", "change_percent", + "volume", "turnover", "open", "high", "low", "prev_close", + } + assert expected_keys.issubset(set(result.keys())) + + +# ============================================================ +# get_a_share_quote +# ============================================================ + + +class TestGetAShareQuote: + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_spot_em") + async def test_returns_quote_dict(self, mock_spot): + mock_spot.return_value = _make_spot_df("000001") + result = await akshare_service.get_a_share_quote("000001") + assert result is not None + assert result["symbol"] == "000001" + assert result["name"] == "平安银行" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_spot_em") + async def test_returns_none_for_unknown_symbol(self, mock_spot): + mock_spot.return_value = _make_spot_df("000001") + result = await akshare_service.get_a_share_quote("999999") + assert result is None + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_spot_em") + async def test_propagates_exception(self, mock_spot): + mock_spot.side_effect = RuntimeError("AKShare unavailable") + with pytest.raises(RuntimeError): + await akshare_service.get_a_share_quote("000001") + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_spot_em") + async def test_akshare_called_once(self, mock_spot): + mock_spot.return_value = _make_spot_df("000001") + await akshare_service.get_a_share_quote("000001") + mock_spot.assert_called_once() + + +# ============================================================ +# get_a_share_historical +# ============================================================ + + +class TestGetAShareHistorical: + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_returns_list_of_bars(self, mock_hist): + mock_hist.return_value = _make_hist_df(5) + result = await akshare_service.get_a_share_historical("000001", days=30) + assert isinstance(result, list) + assert len(result) == 5 + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_bars_have_english_keys(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + result = await akshare_service.get_a_share_historical("000001", days=30) + assert "date" in result[0] + assert "open" in result[0] + assert "close" in result[0] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_called_with_correct_symbol(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + await akshare_service.get_a_share_historical("600519", days=90) + call_kwargs = mock_hist.call_args + assert call_kwargs.kwargs.get("symbol") == "600519" or call_kwargs.args[0] == "600519" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_adjust_is_qfq(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + await akshare_service.get_a_share_historical("000001", days=30) + call_kwargs = mock_hist.call_args + assert call_kwargs.kwargs.get("adjust") == "qfq" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_empty_df_returns_empty_list(self, mock_hist): + mock_hist.return_value = pd.DataFrame() + result = await akshare_service.get_a_share_historical("000001", days=30) + assert result == [] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_zh_a_hist") + async def test_propagates_exception(self, mock_hist): + mock_hist.side_effect = RuntimeError("network error") + with pytest.raises(RuntimeError): + await akshare_service.get_a_share_historical("000001", days=30) + + +# ============================================================ +# search_a_shares +# ============================================================ + + +class TestSearchAShares: + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_returns_matching_results(self, mock_codes): + mock_codes.return_value = _make_code_name_df() + result = await akshare_service.search_a_shares("平安") + assert len(result) == 1 + assert result[0]["code"] == "000001" + assert result[0]["name"] == "平安银行" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_returns_empty_list_when_no_match(self, mock_codes): + mock_codes.return_value = _make_code_name_df() + result = await akshare_service.search_a_shares("NONEXISTENT") + assert result == [] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_returns_multiple_matches(self, mock_codes): + mock_codes.return_value = _make_code_name_df() + result = await akshare_service.search_a_shares("万") + assert len(result) == 1 + assert result[0]["code"] == "000002" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_result_has_code_and_name_keys(self, mock_codes): + mock_codes.return_value = _make_code_name_df() + result = await akshare_service.search_a_shares("茅台") + assert len(result) == 1 + assert "code" in result[0] + assert "name" in result[0] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_propagates_exception(self, mock_codes): + mock_codes.side_effect = RuntimeError("timeout") + with pytest.raises(RuntimeError): + await akshare_service.search_a_shares("平安") + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_info_a_code_name") + async def test_empty_query_returns_all(self, mock_codes): + mock_codes.return_value = _make_code_name_df() + result = await akshare_service.search_a_shares("") + assert len(result) == 3 + + +# ============================================================ +# get_hk_quote +# ============================================================ + + +class TestGetHKQuote: + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_spot_em") + async def test_returns_quote_dict(self, mock_spot): + mock_spot.return_value = _make_spot_df("00700", "腾讯控股") + result = await akshare_service.get_hk_quote("00700") + assert result is not None + assert result["symbol"] == "00700" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_spot_em") + async def test_returns_none_for_unknown_symbol(self, mock_spot): + mock_spot.return_value = _make_spot_df("00700") + result = await akshare_service.get_hk_quote("99999") + assert result is None + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_spot_em") + async def test_propagates_exception(self, mock_spot): + mock_spot.side_effect = RuntimeError("AKShare unavailable") + with pytest.raises(RuntimeError): + await akshare_service.get_hk_quote("00700") + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_spot_em") + async def test_all_fields_present(self, mock_spot): + mock_spot.return_value = _make_spot_df("00700", "腾讯控股") + result = await akshare_service.get_hk_quote("00700") + expected_keys = { + "symbol", "name", "price", "change", "change_percent", + "volume", "turnover", "open", "high", "low", "prev_close", + } + assert expected_keys.issubset(set(result.keys())) + + +# ============================================================ +# get_hk_historical +# ============================================================ + + +class TestGetHKHistorical: + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_returns_list_of_bars(self, mock_hist): + mock_hist.return_value = _make_hist_df(4) + result = await akshare_service.get_hk_historical("00700", days=30) + assert isinstance(result, list) + assert len(result) == 4 + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_bars_have_english_keys(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + result = await akshare_service.get_hk_historical("00700", days=30) + assert "date" in result[0] + assert "close" in result[0] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_called_with_correct_symbol(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + await akshare_service.get_hk_historical("09988", days=90) + call_kwargs = mock_hist.call_args + assert call_kwargs.kwargs.get("symbol") == "09988" or call_kwargs.args[0] == "09988" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_adjust_is_qfq(self, mock_hist): + mock_hist.return_value = _make_hist_df(1) + await akshare_service.get_hk_historical("00700", days=30) + call_kwargs = mock_hist.call_args + assert call_kwargs.kwargs.get("adjust") == "qfq" + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_empty_df_returns_empty_list(self, mock_hist): + mock_hist.return_value = pd.DataFrame() + result = await akshare_service.get_hk_historical("00700", days=30) + assert result == [] + + @pytest.mark.asyncio + @patch("akshare_service.ak.stock_hk_hist") + async def test_propagates_exception(self, mock_hist): + mock_hist.side_effect = RuntimeError("network error") + with pytest.raises(RuntimeError): + await akshare_service.get_hk_historical("00700", days=30) diff --git a/tests/test_routes_cn.py b/tests/test_routes_cn.py new file mode 100644 index 0000000..deae13a --- /dev/null +++ b/tests/test_routes_cn.py @@ -0,0 +1,321 @@ +"""Integration tests for routes_cn.py - written FIRST (TDD RED phase).""" + +from unittest.mock import patch, AsyncMock + +import pytest +from httpx import AsyncClient, ASGITransport + +from main import app + + +@pytest.fixture +async def client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +# ============================================================ +# A-share quote GET /api/v1/cn/a-share/{symbol}/quote +# ============================================================ + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock) +async def test_a_share_quote_happy_path(mock_fn, client): + mock_fn.return_value = { + "symbol": "000001", + "name": "平安银行", + "price": 12.34, + "change": 0.15, + "change_percent": 1.23, + "volume": 500_000, + "turnover": 6_170_000.0, + "open": 12.10, + "high": 12.50, + "low": 12.00, + "prev_close": 12.19, + } + resp = await client.get("/api/v1/cn/a-share/000001/quote") + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["data"]["symbol"] == "000001" + assert data["data"]["name"] == "平安银行" + assert data["data"]["price"] == pytest.approx(12.34) + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock) +async def test_a_share_quote_not_found_returns_404(mock_fn, client): + mock_fn.return_value = None + resp = await client.get("/api/v1/cn/a-share/000001/quote") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_quote", new_callable=AsyncMock) +async def test_a_share_quote_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("AKShare down") + resp = await client.get("/api/v1/cn/a-share/000001/quote") + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_a_share_quote_invalid_symbol_returns_400(client): + # symbol starting with 1 is invalid for A-shares + resp = await client.get("/api/v1/cn/a-share/100001/quote") + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_a_share_quote_non_numeric_symbol_returns_400(client): + resp = await client.get("/api/v1/cn/a-share/ABCDEF/quote") + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_a_share_quote_too_short_returns_422(client): + resp = await client.get("/api/v1/cn/a-share/00001/quote") + assert resp.status_code == 422 + + +# ============================================================ +# A-share historical GET /api/v1/cn/a-share/{symbol}/historical +# ============================================================ + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock) +async def test_a_share_historical_happy_path(mock_fn, client): + mock_fn.return_value = [ + { + "date": "2026-01-01", + "open": 10.0, + "close": 10.5, + "high": 11.0, + "low": 9.5, + "volume": 1_000_000, + "turnover": 10_500_000.0, + "change_percent": 0.5, + "turnover_rate": 0.3, + } + ] + resp = await client.get("/api/v1/cn/a-share/000001/historical?days=30") + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert isinstance(data["data"], list) + assert len(data["data"]) == 1 + assert data["data"][0]["date"] == "2026-01-01" + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock) +async def test_a_share_historical_default_days(mock_fn, client): + mock_fn.return_value = [] + resp = await client.get("/api/v1/cn/a-share/600519/historical") + assert resp.status_code == 200 + mock_fn.assert_called_once_with("600519", days=365) + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock) +async def test_a_share_historical_empty_returns_200(mock_fn, client): + mock_fn.return_value = [] + resp = await client.get("/api/v1/cn/a-share/000001/historical") + assert resp.status_code == 200 + assert resp.json()["data"] == [] + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_a_share_historical", new_callable=AsyncMock) +async def test_a_share_historical_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("AKShare down") + resp = await client.get("/api/v1/cn/a-share/000001/historical") + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_a_share_historical_invalid_symbol_returns_400(client): + resp = await client.get("/api/v1/cn/a-share/100001/historical") + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_a_share_historical_days_out_of_range_returns_422(client): + resp = await client.get("/api/v1/cn/a-share/000001/historical?days=0") + assert resp.status_code == 422 + + +# ============================================================ +# A-share search GET /api/v1/cn/a-share/search +# ============================================================ + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock) +async def test_a_share_search_happy_path(mock_fn, client): + mock_fn.return_value = [ + {"code": "000001", "name": "平安银行"}, + {"code": "000002", "name": "平安地产"}, + ] + resp = await client.get("/api/v1/cn/a-share/search?query=平安") + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert len(data["data"]) == 2 + assert data["data"][0]["code"] == "000001" + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock) +async def test_a_share_search_empty_results(mock_fn, client): + mock_fn.return_value = [] + resp = await client.get("/api/v1/cn/a-share/search?query=NOMATCH") + assert resp.status_code == 200 + assert resp.json()["data"] == [] + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.search_a_shares", new_callable=AsyncMock) +async def test_a_share_search_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("AKShare down") + resp = await client.get("/api/v1/cn/a-share/search?query=test") + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_a_share_search_missing_query_returns_422(client): + resp = await client.get("/api/v1/cn/a-share/search") + assert resp.status_code == 422 + + +# ============================================================ +# HK quote GET /api/v1/cn/hk/{symbol}/quote +# ============================================================ + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock) +async def test_hk_quote_happy_path(mock_fn, client): + mock_fn.return_value = { + "symbol": "00700", + "name": "腾讯控股", + "price": 380.0, + "change": 5.0, + "change_percent": 1.33, + "volume": 10_000_000, + "turnover": 3_800_000_000.0, + "open": 375.0, + "high": 385.0, + "low": 374.0, + "prev_close": 375.0, + } + resp = await client.get("/api/v1/cn/hk/00700/quote") + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert data["data"]["symbol"] == "00700" + assert data["data"]["price"] == pytest.approx(380.0) + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock) +async def test_hk_quote_not_found_returns_404(mock_fn, client): + mock_fn.return_value = None + resp = await client.get("/api/v1/cn/hk/00700/quote") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_quote", new_callable=AsyncMock) +async def test_hk_quote_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("AKShare down") + resp = await client.get("/api/v1/cn/hk/00700/quote") + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_hk_quote_invalid_symbol_letters_returns_400(client): + resp = await client.get("/api/v1/cn/hk/ABCDE/quote") + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_hk_quote_too_short_returns_422(client): + resp = await client.get("/api/v1/cn/hk/0070/quote") + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_hk_quote_too_long_returns_422(client): + resp = await client.get("/api/v1/cn/hk/007000/quote") + assert resp.status_code == 422 + + +# ============================================================ +# HK historical GET /api/v1/cn/hk/{symbol}/historical +# ============================================================ + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock) +async def test_hk_historical_happy_path(mock_fn, client): + mock_fn.return_value = [ + { + "date": "2026-01-01", + "open": 375.0, + "close": 380.0, + "high": 385.0, + "low": 374.0, + "volume": 10_000_000, + "turnover": 3_800_000_000.0, + "change_percent": 1.33, + "turnover_rate": 0.5, + } + ] + resp = await client.get("/api/v1/cn/hk/00700/historical?days=90") + assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert isinstance(data["data"], list) + assert data["data"][0]["close"] == pytest.approx(380.0) + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock) +async def test_hk_historical_default_days(mock_fn, client): + mock_fn.return_value = [] + resp = await client.get("/api/v1/cn/hk/09988/historical") + assert resp.status_code == 200 + mock_fn.assert_called_once_with("09988", days=365) + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock) +async def test_hk_historical_empty_returns_200(mock_fn, client): + mock_fn.return_value = [] + resp = await client.get("/api/v1/cn/hk/00700/historical") + assert resp.status_code == 200 + assert resp.json()["data"] == [] + + +@pytest.mark.asyncio +@patch("routes_cn.akshare_service.get_hk_historical", new_callable=AsyncMock) +async def test_hk_historical_service_error_returns_502(mock_fn, client): + mock_fn.side_effect = RuntimeError("AKShare down") + resp = await client.get("/api/v1/cn/hk/00700/historical") + assert resp.status_code == 502 + + +@pytest.mark.asyncio +async def test_hk_historical_invalid_symbol_returns_400(client): + resp = await client.get("/api/v1/cn/hk/ABCDE/historical") + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_hk_historical_days_out_of_range_returns_422(client): + resp = await client.get("/api/v1/cn/hk/00700/historical?days=0") + assert resp.status_code == 422