From 260ff04a2ceca16a933527f9e8d03602a0dc47d7 Mon Sep 17 00:00:00 2001 From: Alessandro Date: Thu, 30 Apr 2026 18:19:06 +0200 Subject: [PATCH 1/3] add alpaca historicals provider --- pyproject.toml | 6 +- src/ml4t/data/providers/AGENTS.md | 1 + src/ml4t/data/providers/__init__.py | 9 + src/ml4t/data/providers/alpaca.py | 493 ++++++++++++++++++++++++++++ tests/test_alpaca_provider_unit.py | 315 ++++++++++++++++++ uv.lock | 91 ++++- 6 files changed, 913 insertions(+), 2 deletions(-) create mode 100644 src/ml4t/data/providers/alpaca.py create mode 100644 tests/test_alpaca_provider_unit.py diff --git a/pyproject.toml b/pyproject.toml index 19376bd..bbfc00e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,9 +110,12 @@ oanda = [ cot = [ "cot-reports>=0.1.0", # CFTC Commitment of Traders data ] +alpaca = [ + "alpaca-py>=0.30.0", +] all-providers = [ - "ml4t-data[yahoo,databento,oanda,cot]", + "ml4t-data[yahoo,databento,oanda,cot,alpaca]", ] # Development dependencies @@ -171,6 +174,7 @@ dev = [ "yfinance>=0.2.0", "oandapyV20>=0.7.0", "xlsxwriter>=3.1.0", + "alpaca-py>=0.30.0", ] [tool.pytest.ini_options] diff --git a/src/ml4t/data/providers/AGENTS.md b/src/ml4t/data/providers/AGENTS.md index f0e3bc2..c3b437d 100644 --- a/src/ml4t/data/providers/AGENTS.md +++ b/src/ml4t/data/providers/AGENTS.md @@ -13,6 +13,7 @@ | File | Lines | Purpose | |------|-------|---------| | yahoo.py | 603 | Yahoo Finance (free) | +| alpaca.py | — | Alpaca Markets OHLCV | | binance_api.py | 410 | Binance REST API | | binance_bulk.py | 1430 | Binance bulk historical archive | | eodhd.py | 464 | EOD Historical Data | diff --git a/src/ml4t/data/providers/__init__.py b/src/ml4t/data/providers/__init__.py index 9a66521..005ceb6 100644 --- a/src/ml4t/data/providers/__init__.py +++ b/src/ml4t/data/providers/__init__.py @@ -5,6 +5,7 @@ Available Providers (20 live + 3 synthetic/testing): - BaseProvider: Abstract base class for all providers - YahooFinanceProvider: Yahoo Finance (free, no API key) + - AlpacaProvider: Alpaca Markets bars (optional: pip install 'ml4t-data[alpaca]') - TiingoProvider: Tiingo stocks (free tier: 1000 req/day, 500 symbols/month) - FinnhubProvider: Finnhub multi-asset data (free tier: 60 req/min) - EODHDProvider: EODHD global equities (free tier: 500 req/day, 1 year depth) @@ -48,6 +49,12 @@ except ImportError: YahooFinanceProvider = None # type: ignore +try: + from ml4t.data.providers.alpaca import AlpacaHistoricalProvider, AlpacaProvider +except ImportError: + AlpacaHistoricalProvider = None # type: ignore + AlpacaProvider = None # type: ignore + try: from ml4t.data.providers.tiingo import TiingoProvider except ImportError: @@ -148,6 +155,8 @@ "Provider", # Equity providers "YahooFinanceProvider", + "AlpacaHistoricalProvider", + "AlpacaProvider", "TiingoProvider", "FinnhubProvider", "EODHDProvider", diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py new file mode 100644 index 0000000..2a26c2b --- /dev/null +++ b/src/ml4t/data/providers/alpaca.py @@ -0,0 +1,493 @@ +"""Alpaca Markets historical OHLCV bars via alpaca-py. + +Equities and options require API keys. Crypto historical data does not (higher +rate limits if keys are provided). See +https://alpaca.markets/sdks/python/market_data.html + +Authentication: + For stocks/options, set ``ALPACA_API_KEY`` and ``ALPACA_SECRET_KEY`` or pass + ``api_key`` / ``secret_key`` to :class:`AlpacaProvider`. + + Optional: set ``ALPACA_STOCK_FEED`` (e.g. ``iex``, ``sip``) when ``feed`` is not + passed to the constructor; useful for free-tier IEX access. + +Date strings ``start`` / ``end`` may include stray whitespace (e.g. ``2024-12- 31``); +it is normalized before parsing. + +This module requires ``alpaca-py`` (``pip install 'ml4t-data[alpaca]'``). If the +package is not installed, ``import ml4t.data.providers.alpaca`` raises +``ImportError``; :mod:`ml4t.data.providers` still loads and exposes +``AlpacaProvider`` as ``None`` when the extra is missing. + +Example: + >>> from ml4t.data.providers.alpaca import AlpacaProvider + >>> p = AlpacaProvider() # crypto only, no keys + >>> # p.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-31", "daily") + >>> p2 = AlpacaProvider(api_key="...", secret_key="...") + >>> # p2.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-31", "daily") +""" + +from __future__ import annotations + +import asyncio +import os +import re +import time +from collections import defaultdict +from datetime import UTC, datetime, timedelta +from datetime import time as dt_time +from typing import ClassVar, Literal + +import pandas as pd +import polars as pl +from alpaca.data.enums import Adjustment, DataFeed +from alpaca.data.historical import ( + CryptoHistoricalDataClient, + OptionHistoricalDataClient, + StockHistoricalDataClient, +) +from alpaca.data.requests import CryptoBarsRequest, OptionBarsRequest, StockBarsRequest +from alpaca.data.timeframe import TimeFrame, TimeFrameUnit + +from ml4t.data.core.exceptions import ( + AuthenticationError, + DataValidationError, + SymbolNotFoundError, +) +from ml4t.data.providers.base import BaseProvider + +__all__ = ["AlpacaHistoricalProvider", "AlpacaProvider"] + + +def _chunks(lst: list[str], n: int): + """Yield successive n-sized chunks from lst (same pattern as Yahoo).""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +_OPTION_RE = re.compile(r"^[A-Z]{1,5}\d{6,7}[CP]\d{8}$") +Kind = Literal["crypto", "stock", "option"] + + +def _normalize_iso_date(s: str) -> str: + """Collapse whitespace in YYYY-MM-DD strings so values like ``2024-12- 31`` parse.""" + return "".join(s.split()) + + +def _optional_stock_feed_from_env() -> DataFeed | None: + """Resolve ``ALPACA_STOCK_FEED`` to :class:`~alpaca.data.enums.DataFeed` if set.""" + raw = os.getenv("ALPACA_STOCK_FEED", "").strip() + if not raw: + return None + try: + return DataFeed(raw.lower()) + except ValueError: + return None + + +def _infer_kind(symbol: str) -> Kind: + if "/" in symbol: + return "crypto" + if _OPTION_RE.match(symbol): + return "option" + return "stock" + + +def _timeframe_for(frequency: str) -> TimeFrame: + key = frequency.lower() + match key: + case "minute" | "1minute": + return TimeFrame(1, TimeFrameUnit.Minute) + case "5minute": + return TimeFrame(5, TimeFrameUnit.Minute) + case "15minute": + return TimeFrame(15, TimeFrameUnit.Minute) + case "30minute": + return TimeFrame(30, TimeFrameUnit.Minute) + case "hourly" | "1hour": + return TimeFrame(1, TimeFrameUnit.Hour) + case "daily" | "1day": + return TimeFrame(1, TimeFrameUnit.Day) + case "weekly" | "1week": + return TimeFrame(1, TimeFrameUnit.Week) + case "monthly" | "1month": + return TimeFrame(1, TimeFrameUnit.Month) + case _: + raise DataValidationError( + "alpaca", + f"Unsupported frequency for Alpaca: {frequency!r}", + field="frequency", + value=frequency, + ) + + +def _utc_range(start: str, end: str) -> tuple[datetime, datetime, datetime]: + """Parse YYYY-MM-DD range to Alpaca UTC window (end exclusive for API, inclusive filter).""" + start_utc = datetime.combine( + datetime.strptime(start, "%Y-%m-%d").date(), dt_time.min, tzinfo=UTC + ) + end_day = datetime.strptime(end, "%Y-%m-%d").date() + end_excl = datetime.combine(end_day + timedelta(days=1), dt_time.min, tzinfo=UTC) + end_inclusive = datetime.combine(end_day, dt_time(23, 59, 59, 999999), tzinfo=UTC) + return start_utc, end_excl, end_inclusive + + +def _bars_pandas_to_polars( + df_pd: pd.DataFrame, + symbol: str, + end_inclusive: datetime, +) -> pl.DataFrame: + """Map Alpaca ``BarSet.df`` (pandas) to ml4t OHLCV Polars schema (single symbol).""" + if isinstance(df_pd.index, pd.MultiIndex) and "symbol" in (df_pd.index.names or []): + df_pd = df_pd.droplevel("symbol") + df_pd = df_pd.reset_index() + time_col = "timestamp" if "timestamp" in df_pd.columns else df_pd.columns[0] + out = pl.from_pandas(df_pd[[time_col, "open", "high", "low", "close", "volume"]]) + out = out.rename({time_col: "timestamp"}) + return ( + out.with_columns( + [ + pl.col("timestamp").cast(pl.Datetime(time_zone="UTC")), + pl.col("open").cast(pl.Float64), + pl.col("high").cast(pl.Float64), + pl.col("low").cast(pl.Float64), + pl.col("close").cast(pl.Float64), + pl.col("volume").cast(pl.Float64), + ] + ) + .filter(pl.col("timestamp") <= pl.lit(end_inclusive)) + .with_columns(pl.lit(symbol).alias("symbol")) + .select(["timestamp", "symbol", "open", "high", "low", "close", "volume"]) + .sort("timestamp") + ) + + +def _bars_pandas_to_polars_batch(df_pd: pd.DataFrame, end_inclusive: datetime) -> pl.DataFrame: + """Map Alpaca multi-symbol ``BarSet.df`` to long-format canonical Polars.""" + if df_pd.empty: + return pl.DataFrame( + schema={ + "timestamp": pl.Datetime(time_zone="UTC"), + "symbol": pl.Utf8, + "open": pl.Float64, + "high": pl.Float64, + "low": pl.Float64, + "close": pl.Float64, + "volume": pl.Float64, + } + ) + df_pd = df_pd.reset_index() + if "symbol" not in df_pd.columns: + raise DataValidationError( + "alpaca", + "Expected multi-symbol Alpaca bars to include a symbol column after reset_index", + field="bars.df", + ) + out = pl.from_pandas( + df_pd[["timestamp", "symbol", "open", "high", "low", "close", "volume"]] + ) + return ( + out.with_columns( + [ + pl.col("timestamp").cast(pl.Datetime(time_zone="UTC")), + pl.col("symbol").cast(pl.Utf8), + pl.col("open").cast(pl.Float64), + pl.col("high").cast(pl.Float64), + pl.col("low").cast(pl.Float64), + pl.col("close").cast(pl.Float64), + pl.col("volume").cast(pl.Float64), + ] + ) + .filter(pl.col("timestamp") <= pl.lit(end_inclusive)) + .select(["timestamp", "symbol", "open", "high", "low", "close", "volume"]) + .sort(["symbol", "timestamp"]) + ) + + +class AlpacaProvider(BaseProvider): + """Alpaca historical bars; output matches :class:`BaseProvider` OHLCV contract. + + Install: ``pip install 'ml4t-data[alpaca]'`` + """ + + FREQUENCY_MAP: ClassVar[dict[str, str]] = { + "minute": "1minute", + "1minute": "1minute", + "5minute": "5minute", + "15minute": "15minute", + "30minute": "30minute", + "hourly": "1hour", + "1hour": "1hour", + "daily": "1day", + "1day": "1day", + "weekly": "1week", + "1week": "1week", + "monthly": "1month", + "1month": "1month", + } + + def __init__( + self, + api_key: str | None = None, + secret_key: str | None = None, + *, + feed: DataFeed | None = None, + adjustment: Adjustment | None = None, + ) -> None: + """Initialize Alpaca provider. + + Args: + api_key: Alpaca API key; defaults to ``ALPACA_API_KEY``. + secret_key: Alpaca secret key; defaults to ``ALPACA_SECRET_KEY``. + feed: Stock bar feed (e.g. SIP); optional. + adjustment: Corporate-action adjustment for stock bars; optional. + """ + super().__init__(rate_limit=None) + self._api_key = api_key or os.getenv("ALPACA_API_KEY") + self._secret_key = secret_key or os.getenv("ALPACA_SECRET_KEY") + self._feed = feed + self._adjustment = adjustment + self._crypto: CryptoHistoricalDataClient | None = None + self._stock: StockHistoricalDataClient | None = None + self._option: OptionHistoricalDataClient | None = None + + @property + def name(self) -> str: + """Return provider name.""" + return "alpaca" + + def _client(self, kind: Kind): + if kind == "crypto": + if self._crypto is None: + self._crypto = CryptoHistoricalDataClient() + return self._crypto + if self._api_key is None or self._secret_key is None: + raise AuthenticationError( + "alpaca", + "api_key and secret_key are required for stock and option bars. " + "Set ALPACA_API_KEY and ALPACA_SECRET_KEY environment variables " + "or pass api_key and secret_key.", + ) + if kind == "stock": + if self._stock is None: + self._stock = StockHistoricalDataClient(self._api_key, self._secret_key) + return self._stock + if self._option is None: + self._option = OptionHistoricalDataClient(self._api_key, self._secret_key) + return self._option + + def _fetch_and_transform_data( + self, symbol: str, start: str, end: str, frequency: str + ) -> pl.DataFrame: + kind = _infer_kind(symbol) + tf = _timeframe_for(self.FREQUENCY_MAP.get(frequency.lower(), frequency)) + client = self._client(kind) + start_utc, end_excl, end_inclusive = _utc_range(start, end) + + self.logger.info( + "Fetching Alpaca bars", + symbol=symbol, + start=start, + end=end, + frequency=frequency, + kind=kind, + ) + + try: + if kind == "crypto": + req = CryptoBarsRequest( + symbol_or_symbols=symbol, + timeframe=tf, + start=start_utc, + end=end_excl, + ) + bars = client.get_crypto_bars(req) + elif kind == "stock": + req = StockBarsRequest( + symbol_or_symbols=symbol, + timeframe=tf, + start=start_utc, + end=end_excl, + feed=self._feed, + adjustment=self._adjustment, + ) + bars = client.get_stock_bars(req) + else: + req = OptionBarsRequest( + symbol_or_symbols=symbol, + timeframe=tf, + start=start_utc, + end=end_excl, + ) + bars = client.get_option_bars(req) + except Exception as e: + raise DataValidationError( + "alpaca", + f"Failed to fetch {symbol}: {e}", + details={"symbol": symbol, "error": str(e)}, + ) from e + + df_pd = bars.df + if df_pd.empty: + raise SymbolNotFoundError( + "alpaca", + symbol, + details={"start": start, "end": end, "frequency": frequency}, + ) + + out = _bars_pandas_to_polars(df_pd, symbol, end_inclusive) + self.logger.info("Fetched Alpaca bars", symbol=symbol, rows=len(out)) + return out + + def fetch_batch_ohlcv( + self, + symbols: list[str], + start: str, + end: str, + frequency: str = "daily", + chunk_size: int = 50, + delay_seconds: float = 0.0, + ) -> pl.DataFrame: + """Fetch OHLCV for multiple symbols using Alpaca multi-symbol bar requests. + + Symbols are grouped by asset class (crypto / stock / option). Each chunk + calls the API with ``symbol_or_symbols=[...]`` (same idea as Yahoo batch). + + Args: + symbols: Ticker list (e.g. ``[\"AAPL\", \"MSFT\"]`` or crypto ``BTC/USD``). + start: Start date ``YYYY-MM-DD``. + end: End date ``YYYY-MM-DD`` (inclusive). + frequency: Same names as :meth:`fetch_ohlcv`. + chunk_size: Max symbols per API request per asset class. + delay_seconds: Pause between chunks (rate limiting). + + Returns: + Long-format Polars frame, sorted by ``symbol``, ``timestamp``. + """ + if not symbols: + return self._create_empty_dataframe() + + tf = _timeframe_for(self.FREQUENCY_MAP.get(frequency.lower(), frequency)) + start_utc, end_excl, end_inclusive = _utc_range(start, end) + + by_kind: dict[Kind, list[str]] = defaultdict(list) + for s in symbols: + by_kind[_infer_kind(s)].append(s) + + if ("stock" in by_kind or "option" in by_kind) and ( + self._api_key is None or self._secret_key is None + ): + raise AuthenticationError( + "alpaca", + "api_key and secret_key are required when batch includes " + "stock or option symbols. Set ALPACA_API_KEY / ALPACA_SECRET_KEY " + "or pass keys to the constructor.", + ) + + self.logger.info( + "Starting Alpaca batch download", + total_symbols=len(symbols), + chunk_size=chunk_size, + start=start, + end=end, + ) + + all_parts: list[pl.DataFrame] = [] + failed: list[str] = [] + + for kind in ("crypto", "stock", "option"): + syms = by_kind.get(kind) + if not syms: + continue + client = self._client(kind) + n_chunks = (len(syms) + chunk_size - 1) // chunk_size + for i, chunk in enumerate(_chunks(syms, chunk_size), 1): + try: + if kind == "crypto": + req = CryptoBarsRequest( + symbol_or_symbols=chunk, + timeframe=tf, + start=start_utc, + end=end_excl, + ) + bars = client.get_crypto_bars(req) + elif kind == "stock": + req = StockBarsRequest( + symbol_or_symbols=chunk, + timeframe=tf, + start=start_utc, + end=end_excl, + feed=self._feed, + adjustment=self._adjustment, + ) + bars = client.get_stock_bars(req) + else: + req = OptionBarsRequest( + symbol_or_symbols=chunk, + timeframe=tf, + start=start_utc, + end=end_excl, + ) + bars = client.get_option_bars(req) + + df_pd = bars.df + if df_pd.empty: + self.logger.warning("Empty Alpaca batch chunk", chunk=i, symbols=chunk) + failed.extend(chunk) + continue + all_parts.append(_bars_pandas_to_polars_batch(df_pd, end_inclusive)) + except Exception as e: + self.logger.error( + "Alpaca batch chunk failed", chunk=i, symbols=chunk, error=str(e) + ) + failed.extend(chunk) + + if i < n_chunks and delay_seconds > 0: + time.sleep(delay_seconds) + + if not all_parts: + self.logger.error("No Alpaca batch data", failed_symbols=failed) + return self._create_empty_dataframe() + + result = pl.concat(all_parts).sort(["symbol", "timestamp"]) + self.logger.info( + "Alpaca batch download complete", + rows=len(result), + failed_symbols=len(failed), + ) + if failed: + self.logger.warning("Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10]) + return result + + async def fetch_batch_ohlcv_async( + self, + symbols: list[str], + start: str, + end: str, + frequency: str = "daily", + chunk_size: int = 50, + delay_seconds: float = 0.0, + ) -> pl.DataFrame: + """Async batch fetch (sync Alpaca client via :func:`asyncio.to_thread`).""" + return await asyncio.to_thread( + self.fetch_batch_ohlcv, + symbols, + start, + end, + frequency, + chunk_size, + delay_seconds, + ) + + async def close_async(self) -> None: + """No network session to close (SDK owns HTTP).""" + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close_async() + + +AlpacaHistoricalProvider = AlpacaProvider diff --git a/tests/test_alpaca_provider_unit.py b/tests/test_alpaca_provider_unit.py new file mode 100644 index 0000000..b6b38ee --- /dev/null +++ b/tests/test_alpaca_provider_unit.py @@ -0,0 +1,315 @@ +"""Tests for Alpaca OHLCV provider (requires ``alpaca-py`` / ``ml4t-data[alpaca]``).""" + +import asyncio +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +import pandas as pd +import polars as pl +import pytest + +pytest.importorskip("alpaca") + +from ml4t.data.core.exceptions import AuthenticationError, DataValidationError +from ml4t.data.providers.alpaca import ( + AlpacaHistoricalProvider, + AlpacaProvider, + _bars_pandas_to_polars, + _bars_pandas_to_polars_batch, + _infer_kind, + _timeframe_for, +) + + +class TestInferKind: + """Tests for asset-class inference from symbol.""" + + def test_crypto_slash(self): + assert _infer_kind("BTC/USD") == "crypto" + assert _infer_kind("ETH/USD") == "crypto" + + def test_option_oc(self): + assert _infer_kind("AAPL241220C00300000") == "option" + + def test_stock(self): + assert _infer_kind("AAPL") == "stock" + assert _infer_kind("MSFT") == "stock" + + +class TestFrequencyMap: + """FREQUENCY_MAP mirrors Yahoo-style aliases.""" + + def test_keys_cover_aliases(self): + p = AlpacaProvider() + for key in ( + "daily", + "1day", + "minute", + "5minute", + "hourly", + "monthly", + ): + assert key in p.FREQUENCY_MAP + + def test_map_normalizes_to_internal_names(self): + p = AlpacaProvider() + assert p.FREQUENCY_MAP["daily"] == "1day" + assert p.FREQUENCY_MAP["minute"] == "1minute" + assert p.FREQUENCY_MAP["5minute"] == "5minute" + + +class TestTimeframeFor: + """Tests for Alpaca TimeFrame mapping.""" + + def test_daily(self): + from alpaca.data.timeframe import TimeFrameUnit + + tf = _timeframe_for("daily") + assert tf.amount == 1 + assert tf.unit == TimeFrameUnit.Day + + def test_five_minute(self): + from alpaca.data.timeframe import TimeFrameUnit + + tf = _timeframe_for("5minute") + assert tf.amount == 5 + assert tf.unit == TimeFrameUnit.Minute + + def test_invalid_frequency_raises_data_validation(self): + with pytest.raises(DataValidationError, match="Unsupported frequency"): + _timeframe_for("not_a_real_freq") + + +class TestBarsPandasToPolars: + """Pandas BarSet → Polars (single- and multi-symbol).""" + + def test_multiindex_symbol_timestamp(self): + idx = pd.MultiIndex.from_tuples( + [("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC"))], + names=["symbol", "timestamp"], + ) + df_pd = pd.DataFrame( + {"open": [1.0], "high": [2.0], "low": [0.5], "close": [1.5], "volume": [100.0]}, + index=idx, + ) + end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC) + out = _bars_pandas_to_polars(df_pd, "BTC/USD", end_inc) + assert len(out) == 1 + assert out.columns == [ + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + ] + assert out["symbol"][0] == "BTC/USD" + assert out["close"][0] == 1.5 + + def test_end_filter_excludes_bar_after_range(self): + idx = pd.MultiIndex.from_tuples( + [ + ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")), + ("BTC/USD", pd.Timestamp("2024-01-05", tz="UTC")), + ], + names=["symbol", "timestamp"], + ) + df_pd = pd.DataFrame( + { + "open": [1.0, 2.0], + "high": [2.0, 3.0], + "low": [0.5, 1.5], + "close": [1.5, 2.5], + "volume": [100.0, 200.0], + }, + index=idx, + ) + end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC) + out = _bars_pandas_to_polars(df_pd, "BTC/USD", end_inc) + assert len(out) == 1 + + def test_batch_two_symbols(self): + idx = pd.MultiIndex.from_tuples( + [ + ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")), + ("ETH/USD", pd.Timestamp("2024-01-02", tz="UTC")), + ], + names=["symbol", "timestamp"], + ) + df_pd = pd.DataFrame( + { + "open": [1.0, 2.0], + "high": [2.0, 3.0], + "low": [0.5, 1.5], + "close": [1.5, 2.5], + "volume": [100.0, 200.0], + }, + index=idx, + ) + end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC) + out = _bars_pandas_to_polars_batch(df_pd, end_inc) + assert len(out) == 2 + assert set(out["symbol"].to_list()) == {"BTC/USD", "ETH/USD"} + + +class TestCreateEmptyDataframe: + """Empty OHLCV schema (BaseProvider).""" + + def test_empty_dataframe_columns(self): + provider = AlpacaProvider() + df = provider._create_empty_dataframe() + assert list(df.columns) == [ + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + ] + + def test_empty_dataframe_length(self): + provider = AlpacaProvider() + assert len(provider._create_empty_dataframe()) == 0 + + +class TestAuthentication: + """Stock/option paths require credentials.""" + + def test_stock_without_keys_raises(self): + with patch.dict("os.environ", {}, clear=True): + p = AlpacaProvider() + with pytest.raises(AuthenticationError, match="api_key and secret_key"): + p.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-05", "daily") + + def test_option_without_keys_raises(self): + with patch.dict("os.environ", {}, clear=True): + p = AlpacaProvider() + with pytest.raises(AuthenticationError, match="api_key and secret_key"): + p.fetch_ohlcv("AAPL241220C00300000", "2024-12-01", "2024-12-05", "daily") + + def test_batch_stock_without_keys_raises(self): + with patch.dict("os.environ", {}, clear=True): + p = AlpacaProvider() + with pytest.raises(AuthenticationError, match="stock or option"): + p.fetch_batch_ohlcv(["AAPL", "MSFT"], "2024-01-01", "2024-01-05", "daily") + + def test_init_reads_env_keys(self): + with patch.dict( + "os.environ", + {"ALPACA_API_KEY": "pk-x", "ALPACA_SECRET_KEY": "sec-y"}, + ): + p = AlpacaProvider() + assert p._api_key == "pk-x" + assert p._secret_key == "sec-y" + + +class TestFetchCryptoMocked: + """Offline tests with mocked Alpaca client.""" + + def test_fetch_crypto_returns_canonical_schema(self): + idx = pd.MultiIndex.from_tuples( + [("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC"))], + names=["symbol", "timestamp"], + ) + df_pd = pd.DataFrame( + {"open": [1.0], "high": [2.0], "low": [0.5], "close": [1.5], "volume": [100.0]}, + index=idx, + ) + barset = MagicMock() + barset.df = df_pd + + with patch("ml4t.data.providers.alpaca.CryptoHistoricalDataClient") as client_cls: + inst = MagicMock() + client_cls.return_value = inst + inst.get_crypto_bars.return_value = barset + p = AlpacaProvider() + out = p._fetch_and_transform_data("BTC/USD", "2024-01-01", "2024-01-03", "daily") + assert len(out) == 1 + assert out["close"][0] == 1.5 + + def test_fetch_batch_crypto_two_symbols_mocked(self): + idx = pd.MultiIndex.from_tuples( + [ + ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")), + ("ETH/USD", pd.Timestamp("2024-01-02", tz="UTC")), + ], + names=["symbol", "timestamp"], + ) + df_pd = pd.DataFrame( + { + "open": [1.0, 2.0], + "high": [2.0, 3.0], + "low": [0.5, 1.5], + "close": [1.5, 2.5], + "volume": [100.0, 200.0], + }, + index=idx, + ) + barset = MagicMock() + barset.df = df_pd + + with patch("ml4t.data.providers.alpaca.CryptoHistoricalDataClient") as client_cls: + inst = MagicMock() + client_cls.return_value = inst + inst.get_crypto_bars.return_value = barset + p = AlpacaProvider() + out = p.fetch_batch_ohlcv( + ["BTC/USD", "ETH/USD"], "2024-01-01", "2024-01-03", "daily", chunk_size=10 + ) + assert len(out) == 2 + assert out.columns == [ + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + ] + + def test_fetch_batch_empty_symbols(self): + p = AlpacaProvider() + out = p.fetch_batch_ohlcv([], "2024-01-01", "2024-01-03") + assert len(out) == 0 + + +class TestAlpacaProviderInit: + """Naming and backward-compat alias.""" + + def test_name(self): + assert AlpacaProvider().name == "alpaca" + + def test_historical_alias(self): + assert AlpacaHistoricalProvider is AlpacaProvider + + +class TestInvalidFrequencyInFetch: + """Unsupported frequency → DataValidationError before auth.""" + + def test_fetch_bad_frequency(self): + p = AlpacaProvider() + with pytest.raises(DataValidationError, match="Unsupported frequency"): + p.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-03", "bad_freq_xyz") + + def test_stock_bad_frequency_without_keys_is_validation_not_auth(self): + with patch.dict("os.environ", {}, clear=True): + p = AlpacaProvider() + with pytest.raises(DataValidationError, match="Unsupported frequency"): + p.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-03", "bad_freq_xyz") + + +class TestFetchBatchAsync: + """Mirror Yahoo's asyncio.to_thread batch wrapper.""" + + def test_fetch_batch_ohlcv_async_delegates(self): + p = AlpacaProvider() + with patch.object(p, "fetch_batch_ohlcv", return_value=pl.DataFrame()) as mock_fb: + out = asyncio.run( + p.fetch_batch_ohlcv_async( + ["BTC/USD"], "2024-01-01", "2024-01-02", frequency="daily" + ) + ) + mock_fb.assert_called_once() + assert len(out) == 0 diff --git a/uv.lock b/uv.lock index 77b62d1..cb13e31 100644 --- a/uv.lock +++ b/uv.lock @@ -139,6 +139,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alpaca-py" +version = "0.43.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "msgpack" }, + { name = "pandas" }, + { name = "pydantic" }, + { name = "pytz" }, + { name = "requests" }, + { name = "sseclient-py" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/9d/3003f661c15b8003655c447c187aec10f0843647e5c98b391701b04ac3d8/alpaca_py-0.43.4.tar.gz", hash = "sha256:7d529b3654d4e817d9fd7ab461131c4f06a315c736b6a9e4a87d5406bb71114a", size = 97990, upload-time = "2026-04-29T08:41:48.775Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/d5/1f57cc03e7b5925a927cb7f8e7ee5f873e22632633778d28d5d23681c871/alpaca_py-0.43.4-py3-none-any.whl", hash = "sha256:dd49ac30e0f2a8f38550ef1f27a58e7fd8f3f3875deaa4e757443cdbd033a1b4", size = 122534, upload-time = "2026-04-29T08:41:50.149Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1576,6 +1594,7 @@ dependencies = [ [package.optional-dependencies] all = [ + { name = "alpaca-py" }, { name = "cot-reports" }, { name = "databento" }, { name = "hypothesis" }, @@ -1598,11 +1617,15 @@ all = [ { name = "yfinance" }, ] all-providers = [ + { name = "alpaca-py" }, { name = "cot-reports" }, { name = "databento" }, { name = "oandapyv20" }, { name = "yfinance" }, ] +alpaca = [ + { name = "alpaca-py" }, +] cot = [ { name = "cot-reports" }, ] @@ -1639,6 +1662,7 @@ yahoo = [ [package.dev-dependencies] dev = [ + { name = "alpaca-py" }, { name = "databento" }, { name = "hypothesis" }, { name = "oandapyv20" }, @@ -1658,6 +1682,9 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiofiles", specifier = ">=23.0.0" }, + { name = "alpaca-py", marker = "extra == 'all'", specifier = ">=0.30.0" }, + { name = "alpaca-py", marker = "extra == 'all-providers'", specifier = ">=0.30.0" }, + { name = "alpaca-py", marker = "extra == 'alpaca'", specifier = ">=0.30.0" }, { name = "click", specifier = ">=8.0.0" }, { name = "cot-reports", marker = "extra == 'all'", specifier = ">=0.1.0" }, { name = "cot-reports", marker = "extra == 'all-providers'", specifier = ">=0.1.0" }, @@ -1723,10 +1750,11 @@ requires-dist = [ { name = "yfinance", marker = "extra == 'all-providers'", specifier = ">=0.2.0" }, { name = "yfinance", marker = "extra == 'yahoo'", specifier = ">=0.2.0" }, ] -provides-extras = ["all", "all-providers", "cot", "databento", "dev", "docs", "oanda", "yahoo"] +provides-extras = ["all", "all-providers", "alpaca", "cot", "databento", "dev", "docs", "oanda", "yahoo"] [package.metadata.requires-dev] dev = [ + { name = "alpaca-py", specifier = ">=0.30.0" }, { name = "databento", specifier = ">=0.38.0" }, { name = "hypothesis", specifier = ">=6.80.0" }, { name = "oandapyv20", specifier = ">=0.7.0" }, @@ -1752,6 +1780,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" }, ] +[[package]] +name = "msgpack" +version = "1.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/f2/bfb55a6236ed8725a96b0aa3acbd0ec17588e6a2c3b62a93eb513ed8783f/msgpack-1.1.2.tar.gz", hash = "sha256:3b60763c1373dd60f398488069bcdc703cd08a711477b5d480eecc9f9626f47e", size = 173581, upload-time = "2025-10-08T09:15:56.596Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/97/560d11202bcd537abca693fd85d81cebe2107ba17301de42b01ac1677b69/msgpack-1.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2e86a607e558d22985d856948c12a3fa7b42efad264dca8a3ebbcfa2735d786c", size = 82271, upload-time = "2025-10-08T09:14:49.967Z" }, + { url = "https://files.pythonhosted.org/packages/83/04/28a41024ccbd67467380b6fb440ae916c1e4f25e2cd4c63abe6835ac566e/msgpack-1.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:283ae72fc89da59aa004ba147e8fc2f766647b1251500182fac0350d8af299c0", size = 84914, upload-time = "2025-10-08T09:14:50.958Z" }, + { url = "https://files.pythonhosted.org/packages/71/46/b817349db6886d79e57a966346cf0902a426375aadc1e8e7a86a75e22f19/msgpack-1.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:61c8aa3bd513d87c72ed0b37b53dd5c5a0f58f2ff9f26e1555d3bd7948fb7296", size = 416962, upload-time = "2025-10-08T09:14:51.997Z" }, + { url = "https://files.pythonhosted.org/packages/da/e0/6cc2e852837cd6086fe7d8406af4294e66827a60a4cf60b86575a4a65ca8/msgpack-1.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:454e29e186285d2ebe65be34629fa0e8605202c60fbc7c4c650ccd41870896ef", size = 426183, upload-time = "2025-10-08T09:14:53.477Z" }, + { url = "https://files.pythonhosted.org/packages/25/98/6a19f030b3d2ea906696cedd1eb251708e50a5891d0978b012cb6107234c/msgpack-1.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7bc8813f88417599564fafa59fd6f95be417179f76b40325b500b3c98409757c", size = 411454, upload-time = "2025-10-08T09:14:54.648Z" }, + { url = "https://files.pythonhosted.org/packages/b7/cd/9098fcb6adb32187a70b7ecaabf6339da50553351558f37600e53a4a2a23/msgpack-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bafca952dc13907bdfdedfc6a5f579bf4f292bdd506fadb38389afa3ac5b208e", size = 422341, upload-time = "2025-10-08T09:14:56.328Z" }, + { url = "https://files.pythonhosted.org/packages/e6/ae/270cecbcf36c1dc85ec086b33a51a4d7d08fc4f404bdbc15b582255d05ff/msgpack-1.1.2-cp311-cp311-win32.whl", hash = "sha256:602b6740e95ffc55bfb078172d279de3773d7b7db1f703b2f1323566b878b90e", size = 64747, upload-time = "2025-10-08T09:14:57.882Z" }, + { url = "https://files.pythonhosted.org/packages/2a/79/309d0e637f6f37e83c711f547308b91af02b72d2326ddd860b966080ef29/msgpack-1.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:d198d275222dc54244bf3327eb8cbe00307d220241d9cec4d306d49a44e85f68", size = 71633, upload-time = "2025-10-08T09:14:59.177Z" }, + { url = "https://files.pythonhosted.org/packages/73/4d/7c4e2b3d9b1106cd0aa6cb56cc57c6267f59fa8bfab7d91df5adc802c847/msgpack-1.1.2-cp311-cp311-win_arm64.whl", hash = "sha256:86f8136dfa5c116365a8a651a7d7484b65b13339731dd6faebb9a0242151c406", size = 64755, upload-time = "2025-10-08T09:15:00.48Z" }, + { url = "https://files.pythonhosted.org/packages/ad/bd/8b0d01c756203fbab65d265859749860682ccd2a59594609aeec3a144efa/msgpack-1.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:70a0dff9d1f8da25179ffcf880e10cf1aad55fdb63cd59c9a49a1b82290062aa", size = 81939, upload-time = "2025-10-08T09:15:01.472Z" }, + { url = "https://files.pythonhosted.org/packages/34/68/ba4f155f793a74c1483d4bdef136e1023f7bcba557f0db4ef3db3c665cf1/msgpack-1.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:446abdd8b94b55c800ac34b102dffd2f6aa0ce643c55dfc017ad89347db3dbdb", size = 85064, upload-time = "2025-10-08T09:15:03.764Z" }, + { url = "https://files.pythonhosted.org/packages/f2/60/a064b0345fc36c4c3d2c743c82d9100c40388d77f0b48b2f04d6041dbec1/msgpack-1.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c63eea553c69ab05b6747901b97d620bb2a690633c77f23feb0c6a947a8a7b8f", size = 417131, upload-time = "2025-10-08T09:15:05.136Z" }, + { url = "https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:372839311ccf6bdaf39b00b61288e0557916c3729529b301c52c2d88842add42", size = 427556, upload-time = "2025-10-08T09:15:06.837Z" }, + { url = "https://files.pythonhosted.org/packages/f5/87/ffe21d1bf7d9991354ad93949286f643b2bb6ddbeab66373922b44c3b8cc/msgpack-1.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2929af52106ca73fcb28576218476ffbb531a036c2adbcf54a3664de124303e9", size = 404920, upload-time = "2025-10-08T09:15:08.179Z" }, + { url = "https://files.pythonhosted.org/packages/ff/41/8543ed2b8604f7c0d89ce066f42007faac1eaa7d79a81555f206a5cdb889/msgpack-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:be52a8fc79e45b0364210eef5234a7cf8d330836d0a64dfbb878efa903d84620", size = 415013, upload-time = "2025-10-08T09:15:09.83Z" }, + { url = "https://files.pythonhosted.org/packages/41/0d/2ddfaa8b7e1cee6c490d46cb0a39742b19e2481600a7a0e96537e9c22f43/msgpack-1.1.2-cp312-cp312-win32.whl", hash = "sha256:1fff3d825d7859ac888b0fbda39a42d59193543920eda9d9bea44d958a878029", size = 65096, upload-time = "2025-10-08T09:15:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/8c/ec/d431eb7941fb55a31dd6ca3404d41fbb52d99172df2e7707754488390910/msgpack-1.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:1de460f0403172cff81169a30b9a92b260cb809c4cb7e2fc79ae8d0510c78b6b", size = 72708, upload-time = "2025-10-08T09:15:12.554Z" }, + { url = "https://files.pythonhosted.org/packages/c5/31/5b1a1f70eb0e87d1678e9624908f86317787b536060641d6798e3cf70ace/msgpack-1.1.2-cp312-cp312-win_arm64.whl", hash = "sha256:be5980f3ee0e6bd44f3a9e9dea01054f175b50c3e6cdb692bc9424c0bbb8bf69", size = 64119, upload-time = "2025-10-08T09:15:13.589Z" }, + { url = "https://files.pythonhosted.org/packages/6b/31/b46518ecc604d7edf3a4f94cb3bf021fc62aa301f0cb849936968164ef23/msgpack-1.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4efd7b5979ccb539c221a4c4e16aac1a533efc97f3b759bb5a5ac9f6d10383bf", size = 81212, upload-time = "2025-10-08T09:15:14.552Z" }, + { url = "https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:42eefe2c3e2af97ed470eec850facbe1b5ad1d6eacdbadc42ec98e7dcf68b4b7", size = 84315, upload-time = "2025-10-08T09:15:15.543Z" }, + { url = "https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fdf7d83102bf09e7ce3357de96c59b627395352a4024f6e2458501f158bf999", size = 412721, upload-time = "2025-10-08T09:15:16.567Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fac4be746328f90caa3cd4bc67e6fe36ca2bf61d5c6eb6d895b6527e3f05071e", size = 424657, upload-time = "2025-10-08T09:15:17.825Z" }, + { url = "https://files.pythonhosted.org/packages/38/f8/4398c46863b093252fe67368b44edc6c13b17f4e6b0e4929dbf0bdb13f23/msgpack-1.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:fffee09044073e69f2bad787071aeec727183e7580443dfeb8556cbf1978d162", size = 402668, upload-time = "2025-10-08T09:15:19.003Z" }, + { url = "https://files.pythonhosted.org/packages/28/ce/698c1eff75626e4124b4d78e21cca0b4cc90043afb80a507626ea354ab52/msgpack-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5928604de9b032bc17f5099496417f113c45bc6bc21b5c6920caf34b3c428794", size = 419040, upload-time = "2025-10-08T09:15:20.183Z" }, + { url = "https://files.pythonhosted.org/packages/67/32/f3cd1667028424fa7001d82e10ee35386eea1408b93d399b09fb0aa7875f/msgpack-1.1.2-cp313-cp313-win32.whl", hash = "sha256:a7787d353595c7c7e145e2331abf8b7ff1e6673a6b974ded96e6d4ec09f00c8c", size = 65037, upload-time = "2025-10-08T09:15:21.416Z" }, + { url = "https://files.pythonhosted.org/packages/74/07/1ed8277f8653c40ebc65985180b007879f6a836c525b3885dcc6448ae6cb/msgpack-1.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:a465f0dceb8e13a487e54c07d04ae3ba131c7c5b95e2612596eafde1dccf64a9", size = 72631, upload-time = "2025-10-08T09:15:22.431Z" }, + { url = "https://files.pythonhosted.org/packages/e5/db/0314e4e2db56ebcf450f277904ffd84a7988b9e5da8d0d61ab2d057df2b6/msgpack-1.1.2-cp313-cp313-win_arm64.whl", hash = "sha256:e69b39f8c0aa5ec24b57737ebee40be647035158f14ed4b40e6f150077e21a84", size = 64118, upload-time = "2025-10-08T09:15:23.402Z" }, + { url = "https://files.pythonhosted.org/packages/22/71/201105712d0a2ff07b7873ed3c220292fb2ea5120603c00c4b634bcdafb3/msgpack-1.1.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e23ce8d5f7aa6ea6d2a2b326b4ba46c985dbb204523759984430db7114f8aa00", size = 81127, upload-time = "2025-10-08T09:15:24.408Z" }, + { url = "https://files.pythonhosted.org/packages/1b/9f/38ff9e57a2eade7bf9dfee5eae17f39fc0e998658050279cbb14d97d36d9/msgpack-1.1.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:6c15b7d74c939ebe620dd8e559384be806204d73b4f9356320632d783d1f7939", size = 84981, upload-time = "2025-10-08T09:15:25.812Z" }, + { url = "https://files.pythonhosted.org/packages/8e/a9/3536e385167b88c2cc8f4424c49e28d49a6fc35206d4a8060f136e71f94c/msgpack-1.1.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e2cb7b9031568a2a5c73aa077180f93dd2e95b4f8d3b8e14a73ae94a9e667e", size = 411885, upload-time = "2025-10-08T09:15:27.22Z" }, + { url = "https://files.pythonhosted.org/packages/2f/40/dc34d1a8d5f1e51fc64640b62b191684da52ca469da9cd74e84936ffa4a6/msgpack-1.1.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:180759d89a057eab503cf62eeec0aa61c4ea1200dee709f3a8e9397dbb3b6931", size = 419658, upload-time = "2025-10-08T09:15:28.4Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ef/2b92e286366500a09a67e03496ee8b8ba00562797a52f3c117aa2b29514b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:04fb995247a6e83830b62f0b07bf36540c213f6eac8e851166d8d86d83cbd014", size = 403290, upload-time = "2025-10-08T09:15:29.764Z" }, + { url = "https://files.pythonhosted.org/packages/78/90/e0ea7990abea5764e4655b8177aa7c63cdfa89945b6e7641055800f6c16b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8e22ab046fa7ede9e36eeb4cfad44d46450f37bb05d5ec482b02868f451c95e2", size = 415234, upload-time = "2025-10-08T09:15:31.022Z" }, + { url = "https://files.pythonhosted.org/packages/72/4e/9390aed5db983a2310818cd7d3ec0aecad45e1f7007e0cda79c79507bb0d/msgpack-1.1.2-cp314-cp314-win32.whl", hash = "sha256:80a0ff7d4abf5fecb995fcf235d4064b9a9a8a40a3ab80999e6ac1e30b702717", size = 66391, upload-time = "2025-10-08T09:15:32.265Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f1/abd09c2ae91228c5f3998dbd7f41353def9eac64253de3c8105efa2082f7/msgpack-1.1.2-cp314-cp314-win_amd64.whl", hash = "sha256:9ade919fac6a3e7260b7f64cea89df6bec59104987cbea34d34a2fa15d74310b", size = 73787, upload-time = "2025-10-08T09:15:33.219Z" }, + { url = "https://files.pythonhosted.org/packages/6a/b0/9d9f667ab48b16ad4115c1935d94023b82b3198064cb84a123e97f7466c1/msgpack-1.1.2-cp314-cp314-win_arm64.whl", hash = "sha256:59415c6076b1e30e563eb732e23b994a61c159cec44deaf584e5cc1dd662f2af", size = 66453, upload-time = "2025-10-08T09:15:34.225Z" }, + { url = "https://files.pythonhosted.org/packages/16/67/93f80545eb1792b61a217fa7f06d5e5cb9e0055bed867f43e2b8e012e137/msgpack-1.1.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:897c478140877e5307760b0ea66e0932738879e7aa68144d9b78ea4c8302a84a", size = 85264, upload-time = "2025-10-08T09:15:35.61Z" }, + { url = "https://files.pythonhosted.org/packages/87/1c/33c8a24959cf193966ef11a6f6a2995a65eb066bd681fd085afd519a57ce/msgpack-1.1.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a668204fa43e6d02f89dbe79a30b0d67238d9ec4c5bd8a940fc3a004a47b721b", size = 89076, upload-time = "2025-10-08T09:15:36.619Z" }, + { url = "https://files.pythonhosted.org/packages/fc/6b/62e85ff7193663fbea5c0254ef32f0c77134b4059f8da89b958beb7696f3/msgpack-1.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5559d03930d3aa0f3aacb4c42c776af1a2ace2611871c84a75afe436695e6245", size = 435242, upload-time = "2025-10-08T09:15:37.647Z" }, + { url = "https://files.pythonhosted.org/packages/c1/47/5c74ecb4cc277cf09f64e913947871682ffa82b3b93c8dad68083112f412/msgpack-1.1.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70c5a7a9fea7f036b716191c29047374c10721c389c21e9ffafad04df8c52c90", size = 432509, upload-time = "2025-10-08T09:15:38.794Z" }, + { url = "https://files.pythonhosted.org/packages/24/a4/e98ccdb56dc4e98c929a3f150de1799831c0a800583cde9fa022fa90602d/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f2cb069d8b981abc72b41aea1c580ce92d57c673ec61af4c500153a626cb9e20", size = 415957, upload-time = "2025-10-08T09:15:40.238Z" }, + { url = "https://files.pythonhosted.org/packages/da/28/6951f7fb67bc0a4e184a6b38ab71a92d9ba58080b27a77d3e2fb0be5998f/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d62ce1f483f355f61adb5433ebfd8868c5f078d1a52d042b0a998682b4fa8c27", size = 422910, upload-time = "2025-10-08T09:15:41.505Z" }, + { url = "https://files.pythonhosted.org/packages/f0/03/42106dcded51f0a0b5284d3ce30a671e7bd3f7318d122b2ead66ad289fed/msgpack-1.1.2-cp314-cp314t-win32.whl", hash = "sha256:1d1418482b1ee984625d88aa9585db570180c286d942da463533b238b98b812b", size = 75197, upload-time = "2025-10-08T09:15:42.954Z" }, + { url = "https://files.pythonhosted.org/packages/15/86/d0071e94987f8db59d4eeb386ddc64d0bb9b10820a8d82bcd3e53eeb2da6/msgpack-1.1.2-cp314-cp314t-win_amd64.whl", hash = "sha256:5a46bf7e831d09470ad92dff02b8b1ac92175ca36b087f904a0519857c6be3ff", size = 85772, upload-time = "2025-10-08T09:15:43.954Z" }, + { url = "https://files.pythonhosted.org/packages/81/f2/08ace4142eb281c12701fc3b93a10795e4d4dc7f753911d836675050f886/msgpack-1.1.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d99ef64f349d5ec3293688e91486c5fdb925ed03807f64d98d205d2713c60b46", size = 70868, upload-time = "2025-10-08T09:15:44.959Z" }, +] + [[package]] name = "multidict" version = "6.7.0" @@ -2946,6 +3027,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/f3/b67d6ea49ca9154453b6d70b34ea22f3996b9fa55da105a79d8732227adc/soupsieve-2.8.1-py3-none-any.whl", hash = "sha256:a11fe2a6f3d76ab3cf2de04eb339c1be5b506a8a47f2ceb6d139803177f85434", size = 36710, upload-time = "2025-12-18T13:50:33.267Z" }, ] +[[package]] +name = "sseclient-py" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/2e/59920f7d66b7f9932a3d83dd0ec53fab001be1e058bf582606fe414a5198/sseclient_py-1.9.0-py3-none-any.whl", hash = "sha256:340062b1587fc2880892811e2ab5b176d98ef3eee98b3672ff3a3ba1e8ed0f6f", size = 8351, upload-time = "2026-01-02T23:39:30.995Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" From c5dc657ed1754477829c5f845cb503fc5856a301 Mon Sep 17 00:00:00 2001 From: Alessandro Date: Thu, 30 Apr 2026 18:25:32 +0200 Subject: [PATCH 2/3] improve --- src/ml4t/data/providers/alpaca.py | 31 ++++++++----- tests/integration/test_alpaca.py | 70 ++++++++++++++++++++++++++++++ tests/test_alpaca_provider_unit.py | 47 ++++++++++++++++++++ 3 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 tests/integration/test_alpaca.py diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index 2a26c2b..785b39e 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -123,10 +123,20 @@ def _timeframe_for(frequency: str) -> TimeFrame: def _utc_range(start: str, end: str) -> tuple[datetime, datetime, datetime]: """Parse YYYY-MM-DD range to Alpaca UTC window (end exclusive for API, inclusive filter).""" - start_utc = datetime.combine( - datetime.strptime(start, "%Y-%m-%d").date(), dt_time.min, tzinfo=UTC - ) - end_day = datetime.strptime(end, "%Y-%m-%d").date() + start_n = _normalize_iso_date(start) + end_n = _normalize_iso_date(end) + try: + start_utc = datetime.combine( + datetime.strptime(start_n, "%Y-%m-%d").date(), dt_time.min, tzinfo=UTC + ) + end_day = datetime.strptime(end_n, "%Y-%m-%d").date() + except ValueError as e: + raise DataValidationError( + "alpaca", + f"Invalid start/end date (expected YYYY-MM-DD): {e}", + field="start/end", + value=f"{start!r} / {end!r}", + ) from e end_excl = datetime.combine(end_day + timedelta(days=1), dt_time.min, tzinfo=UTC) end_inclusive = datetime.combine(end_day, dt_time(23, 59, 59, 999999), tzinfo=UTC) return start_utc, end_excl, end_inclusive @@ -183,9 +193,7 @@ def _bars_pandas_to_polars_batch(df_pd: pd.DataFrame, end_inclusive: datetime) - "Expected multi-symbol Alpaca bars to include a symbol column after reset_index", field="bars.df", ) - out = pl.from_pandas( - df_pd[["timestamp", "symbol", "open", "high", "low", "close", "volume"]] - ) + out = pl.from_pandas(df_pd[["timestamp", "symbol", "open", "high", "low", "close", "volume"]]) return ( out.with_columns( [ @@ -239,13 +247,14 @@ def __init__( Args: api_key: Alpaca API key; defaults to ``ALPACA_API_KEY``. secret_key: Alpaca secret key; defaults to ``ALPACA_SECRET_KEY``. - feed: Stock bar feed (e.g. SIP); optional. + feed: Stock bar feed (e.g. SIP); optional. If omitted, uses ``ALPACA_STOCK_FEED`` + when set, otherwise lets the SDK default apply. adjustment: Corporate-action adjustment for stock bars; optional. """ super().__init__(rate_limit=None) self._api_key = api_key or os.getenv("ALPACA_API_KEY") self._secret_key = secret_key or os.getenv("ALPACA_SECRET_KEY") - self._feed = feed + self._feed = feed if feed is not None else _optional_stock_feed_from_env() self._adjustment = adjustment self._crypto: CryptoHistoricalDataClient | None = None self._stock: StockHistoricalDataClient | None = None @@ -456,7 +465,9 @@ def fetch_batch_ohlcv( failed_symbols=len(failed), ) if failed: - self.logger.warning("Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10]) + self.logger.warning( + "Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10] + ) return result async def fetch_batch_ohlcv_async( diff --git a/tests/integration/test_alpaca.py b/tests/integration/test_alpaca.py new file mode 100644 index 0000000..e28384e --- /dev/null +++ b/tests/integration/test_alpaca.py @@ -0,0 +1,70 @@ +"""Live Alpaca API checks when ``ALPACA_API_KEY`` and ``ALPACA_SECRET_KEY`` are set.""" + +from __future__ import annotations + +import os + +import polars as pl +import pytest + +pytest.importorskip("alpaca") + +from ml4t.data.providers.alpaca import AlpacaProvider + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def alpaca_keys(): + """Skip unless both Alpaca credentials are present.""" + key = os.getenv("ALPACA_API_KEY") + secret = os.getenv("ALPACA_SECRET_KEY") + if not key or not secret: + pytest.skip("ALPACA_API_KEY and ALPACA_SECRET_KEY not both set") + return key, secret + + +class TestAlpacaLiveOhlcv: + """Minimal real calls to validate batch and single-symbol paths.""" + + def test_fetch_batch_ohlcv_stocks_5minute(self, alpaca_keys): + """Short window + two symbols to limit payload and rate impact.""" + p = AlpacaProvider() + df = p.fetch_batch_ohlcv( + ["AAPL", "MSFT"], + start="2024-01-02", + end="2024-01-04", + frequency="5minute", + ) + assert isinstance(df, pl.DataFrame) + assert not df.is_empty() + assert set(df.columns) == { + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + } + syms = set(df["symbol"].unique().to_list()) + assert syms <= {"AAPL", "MSFT"} + assert len(syms) >= 1 + + def test_fetch_ohlcv_single_daily(self, alpaca_keys): + p = AlpacaProvider() + df = p.fetch_ohlcv("AAPL", "2024-01-02", "2024-01-10", frequency="daily") + assert isinstance(df, pl.DataFrame) + assert not df.is_empty() + assert df["symbol"].unique().to_list() == ["AAPL"] + + def test_batch_tolerates_whitespace_in_end_date(self, alpaca_keys): + """Regression: ``2024-12- 31`` style typos should not raise ValueError.""" + p = AlpacaProvider() + df = p.fetch_batch_ohlcv( + ["AAPL"], + start="2024-12-30", + end="2024-12- 31", + frequency="daily", + ) + assert len(df) >= 1 diff --git a/tests/test_alpaca_provider_unit.py b/tests/test_alpaca_provider_unit.py index b6b38ee..d83c2ca 100644 --- a/tests/test_alpaca_provider_unit.py +++ b/tests/test_alpaca_provider_unit.py @@ -17,7 +17,10 @@ _bars_pandas_to_polars, _bars_pandas_to_polars_batch, _infer_kind, + _normalize_iso_date, + _optional_stock_feed_from_env, _timeframe_for, + _utc_range, ) @@ -58,6 +61,50 @@ def test_map_normalizes_to_internal_names(self): assert p.FREQUENCY_MAP["5minute"] == "5minute" +class TestNormalizeIsoDate: + """Whitespace-tolerant YYYY-MM-DD inputs (common copy-paste typo).""" + + def test_collapses_internal_space(self): + assert _normalize_iso_date("2024-12- 31") == "2024-12-31" + + def test_strips_edges(self): + assert _normalize_iso_date(" 2024-01-01 ") == "2024-01-01" + + +class TestUtcRange: + """UTC window parsing for Alpaca requests.""" + + def test_accepts_date_with_stray_whitespace(self): + start_utc, end_excl, end_inc = _utc_range("2024-01-01", "2024-12- 31") + assert start_utc.year == 2024 and start_utc.month == 1 + assert end_inc.month == 12 and end_inc.day == 31 + + def test_invalid_date_raises_data_validation(self): + with pytest.raises(DataValidationError, match="Invalid start/end date"): + _utc_range("2024-13-01", "2024-12-31") + + +class TestStockFeedFromEnv: + """ALPACA_STOCK_FEED mirrors metafreq / Alpaca tier defaults.""" + + def test_unset_returns_none(self): + with patch.dict("os.environ", {}, clear=True): + assert _optional_stock_feed_from_env() is None + + def test_iex_parsed(self): + with patch.dict("os.environ", {"ALPACA_STOCK_FEED": "iex"}): + from alpaca.data.enums import DataFeed + + assert _optional_stock_feed_from_env() == DataFeed.IEX + + def test_init_picks_feed_from_env_when_param_omitted(self): + with patch.dict("os.environ", {"ALPACA_STOCK_FEED": "iex"}): + from alpaca.data.enums import DataFeed + + p = AlpacaProvider() + assert p._feed == DataFeed.IEX + + class TestTimeframeFor: """Tests for Alpaca TimeFrame mapping.""" From 7a48023f69a21e364a2f8518bfd7752d58f9f5c5 Mon Sep 17 00:00:00 2001 From: Alessandro Date: Sat, 2 May 2026 21:06:57 +0200 Subject: [PATCH 3/3] add test and register to provider --- src/ml4t/data/managers/provider_manager.py | 10 ++++++++++ tests/test_provider_registration.py | 11 +++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/ml4t/data/managers/provider_manager.py b/src/ml4t/data/managers/provider_manager.py index 1178d5a..e24183d 100644 --- a/src/ml4t/data/managers/provider_manager.py +++ b/src/ml4t/data/managers/provider_manager.py @@ -215,6 +215,13 @@ def _get_provider_classes(cls) -> dict[str, type]: except ImportError: pass + try: + from ml4t.data.providers.alpaca import AlpacaProvider + + provider_classes["alpaca"] = AlpacaProvider + except ImportError: + pass + cls._PROVIDER_CLASSES = provider_classes return provider_classes @@ -248,6 +255,9 @@ def _detect_available_providers(self) -> None: if free_provider not in self._available_providers: self._available_providers.append(free_provider) + if "alpaca" in self._provider_classes and "alpaca" not in self._available_providers: + self._available_providers.append("alpaca") + @property def available_providers(self) -> list[str]: """Get list of available provider names.""" diff --git a/tests/test_provider_registration.py b/tests/test_provider_registration.py index d76b0c8..34ae456 100644 --- a/tests/test_provider_registration.py +++ b/tests/test_provider_registration.py @@ -1,6 +1,7 @@ """Test provider registration in DataManager.""" from ml4t.data.data_manager import DataManager +from ml4t.data.providers.alpaca import AlpacaProvider from ml4t.data.providers.binance import BinanceProvider from ml4t.data.providers.binance_public import BinancePublicProvider from ml4t.data.providers.cryptocompare import CryptoCompareProvider @@ -13,7 +14,7 @@ def test_all_providers_registered(): - """Test that all 9 OHLCV providers are registered in DataManager. + """Test that all 10 OHLCV providers are registered in DataManager. Note: Specialized providers (factor, prediction market) are not in DataManager. - Factor providers: aqr, fama_french (standalone, different API) @@ -30,10 +31,11 @@ def test_all_providers_registered(): "okx": OKXProvider, "synthetic": SyntheticProvider, "yahoo": YahooFinanceProvider, + "alpaca": AlpacaProvider, } assert expected_providers == DataManager.PROVIDER_CLASSES - assert len(DataManager.PROVIDER_CLASSES) == 9 + assert len(DataManager.PROVIDER_CLASSES) == len(expected_providers) def test_provider_imports_work(): @@ -48,6 +50,7 @@ def test_provider_imports_work(): OandaProvider, SyntheticProvider, YahooFinanceProvider, + AlpacaProvider, ] for provider_class in providers: @@ -98,11 +101,10 @@ def test_provider_count(): - Prediction markets: kalshi, polymarket - Historical: wiki_prices """ - assert len(DataManager.PROVIDER_CLASSES) == 9, ( + assert len(DataManager.PROVIDER_CLASSES) == 10, ( f"Expected 9 providers, got {len(DataManager.PROVIDER_CLASSES)}" ) - # List all provider names for clarity provider_names = list(DataManager.PROVIDER_CLASSES.keys()) assert "binance" in provider_names assert "binance_public" in provider_names @@ -113,6 +115,7 @@ def test_provider_count(): assert "okx" in provider_names assert "synthetic" in provider_names assert "yahoo" in provider_names + assert "alpaca" in provider_names if __name__ == "__main__":