From 260ff04a2ceca16a933527f9e8d03602a0dc47d7 Mon Sep 17 00:00:00 2001
From: Alessandro
Date: Thu, 30 Apr 2026 18:19:06 +0200
Subject: [PATCH 1/3] add alpaca historicals provider
---
pyproject.toml | 6 +-
src/ml4t/data/providers/AGENTS.md | 1 +
src/ml4t/data/providers/__init__.py | 9 +
src/ml4t/data/providers/alpaca.py | 493 ++++++++++++++++++++++++++++
tests/test_alpaca_provider_unit.py | 315 ++++++++++++++++++
uv.lock | 91 ++++-
6 files changed, 913 insertions(+), 2 deletions(-)
create mode 100644 src/ml4t/data/providers/alpaca.py
create mode 100644 tests/test_alpaca_provider_unit.py
diff --git a/pyproject.toml b/pyproject.toml
index 19376bd..bbfc00e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -110,9 +110,12 @@ oanda = [
cot = [
"cot-reports>=0.1.0", # CFTC Commitment of Traders data
]
+alpaca = [
+ "alpaca-py>=0.30.0",
+]
all-providers = [
- "ml4t-data[yahoo,databento,oanda,cot]",
+ "ml4t-data[yahoo,databento,oanda,cot,alpaca]",
]
# Development dependencies
@@ -171,6 +174,7 @@ dev = [
"yfinance>=0.2.0",
"oandapyV20>=0.7.0",
"xlsxwriter>=3.1.0",
+ "alpaca-py>=0.30.0",
]
[tool.pytest.ini_options]
diff --git a/src/ml4t/data/providers/AGENTS.md b/src/ml4t/data/providers/AGENTS.md
index f0e3bc2..c3b437d 100644
--- a/src/ml4t/data/providers/AGENTS.md
+++ b/src/ml4t/data/providers/AGENTS.md
@@ -13,6 +13,7 @@
| File | Lines | Purpose |
|------|-------|---------|
| yahoo.py | 603 | Yahoo Finance (free) |
+| alpaca.py | — | Alpaca Markets OHLCV |
| binance_api.py | 410 | Binance REST API |
| binance_bulk.py | 1430 | Binance bulk historical archive |
| eodhd.py | 464 | EOD Historical Data |
diff --git a/src/ml4t/data/providers/__init__.py b/src/ml4t/data/providers/__init__.py
index 9a66521..005ceb6 100644
--- a/src/ml4t/data/providers/__init__.py
+++ b/src/ml4t/data/providers/__init__.py
@@ -5,6 +5,7 @@
Available Providers (20 live + 3 synthetic/testing):
- BaseProvider: Abstract base class for all providers
- YahooFinanceProvider: Yahoo Finance (free, no API key)
+ - AlpacaProvider: Alpaca Markets bars (optional: pip install 'ml4t-data[alpaca]')
- TiingoProvider: Tiingo stocks (free tier: 1000 req/day, 500 symbols/month)
- FinnhubProvider: Finnhub multi-asset data (free tier: 60 req/min)
- EODHDProvider: EODHD global equities (free tier: 500 req/day, 1 year depth)
@@ -48,6 +49,12 @@
except ImportError:
YahooFinanceProvider = None # type: ignore
+try:
+ from ml4t.data.providers.alpaca import AlpacaHistoricalProvider, AlpacaProvider
+except ImportError:
+ AlpacaHistoricalProvider = None # type: ignore
+ AlpacaProvider = None # type: ignore
+
try:
from ml4t.data.providers.tiingo import TiingoProvider
except ImportError:
@@ -148,6 +155,8 @@
"Provider",
# Equity providers
"YahooFinanceProvider",
+ "AlpacaHistoricalProvider",
+ "AlpacaProvider",
"TiingoProvider",
"FinnhubProvider",
"EODHDProvider",
diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py
new file mode 100644
index 0000000..2a26c2b
--- /dev/null
+++ b/src/ml4t/data/providers/alpaca.py
@@ -0,0 +1,493 @@
+"""Alpaca Markets historical OHLCV bars via alpaca-py.
+
+Equities and options require API keys. Crypto historical data does not (higher
+rate limits if keys are provided). See
+https://alpaca.markets/sdks/python/market_data.html
+
+Authentication:
+ For stocks/options, set ``ALPACA_API_KEY`` and ``ALPACA_SECRET_KEY`` or pass
+ ``api_key`` / ``secret_key`` to :class:`AlpacaProvider`.
+
+ Optional: set ``ALPACA_STOCK_FEED`` (e.g. ``iex``, ``sip``) when ``feed`` is not
+ passed to the constructor; useful for free-tier IEX access.
+
+Date strings ``start`` / ``end`` may include stray whitespace (e.g. ``2024-12- 31``);
+it is normalized before parsing.
+
+This module requires ``alpaca-py`` (``pip install 'ml4t-data[alpaca]'``). If the
+package is not installed, ``import ml4t.data.providers.alpaca`` raises
+``ImportError``; :mod:`ml4t.data.providers` still loads and exposes
+``AlpacaProvider`` as ``None`` when the extra is missing.
+
+Example:
+ >>> from ml4t.data.providers.alpaca import AlpacaProvider
+ >>> p = AlpacaProvider() # crypto only, no keys
+ >>> # p.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-31", "daily")
+ >>> p2 = AlpacaProvider(api_key="...", secret_key="...")
+ >>> # p2.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-31", "daily")
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import re
+import time
+from collections import defaultdict
+from datetime import UTC, datetime, timedelta
+from datetime import time as dt_time
+from typing import ClassVar, Literal
+
+import pandas as pd
+import polars as pl
+from alpaca.data.enums import Adjustment, DataFeed
+from alpaca.data.historical import (
+ CryptoHistoricalDataClient,
+ OptionHistoricalDataClient,
+ StockHistoricalDataClient,
+)
+from alpaca.data.requests import CryptoBarsRequest, OptionBarsRequest, StockBarsRequest
+from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
+
+from ml4t.data.core.exceptions import (
+ AuthenticationError,
+ DataValidationError,
+ SymbolNotFoundError,
+)
+from ml4t.data.providers.base import BaseProvider
+
+__all__ = ["AlpacaHistoricalProvider", "AlpacaProvider"]
+
+
+def _chunks(lst: list[str], n: int):
+    """Generate consecutive slices of ``lst`` holding at most ``n`` items each."""
+    starts = range(0, len(lst), n)
+    yield from (lst[lo : lo + n] for lo in starts)
+
+
+_OPTION_RE = re.compile(r"^[A-Z]{1,5}\d{6,7}[CP]\d{8}$")  # 1-5 letters, 6-7 digits, C/P, 8 digits
+Kind = Literal["crypto", "stock", "option"]  # asset classes returned by _infer_kind
+
+
+def _normalize_iso_date(s: str) -> str:
+    """Strip every whitespace character from ``s`` (``2024-12- 31`` → ``2024-12-31``)."""
+    return "".join(ch for ch in s if not ch.isspace())
+
+
+def _optional_stock_feed_from_env() -> DataFeed | None:
+    """Read ``ALPACA_STOCK_FEED``; return its ``DataFeed`` or ``None`` if unset/invalid."""
+    feed_name = os.getenv("ALPACA_STOCK_FEED", "").strip().lower()
+    if feed_name:
+        try:
+            return DataFeed(feed_name)
+        except ValueError:
+            pass  # unrecognized feed name is treated the same as an unset variable
+    return None
+
+
+def _infer_kind(symbol: str) -> Kind:
+    """Classify ``symbol``: contains "/" → crypto; matches ``_OPTION_RE`` → option; else stock."""
+    if "/" not in symbol:
+        # Anything that does not look like an option contract is treated as a stock.
+        return "option" if _OPTION_RE.match(symbol) else "stock"
+    return "crypto"
+
+
+def _timeframe_for(frequency: str) -> TimeFrame:
+    """Build the Alpaca ``TimeFrame`` for a frequency name (case-insensitive).
+
+    Raises ``DataValidationError`` when the name is not a supported frequency.
+    """
+    # Every accepted spelling mapped to its (amount, unit) pair.
+    known = {
+        "minute": (1, TimeFrameUnit.Minute),
+        "1minute": (1, TimeFrameUnit.Minute),
+        "5minute": (5, TimeFrameUnit.Minute),
+        "15minute": (15, TimeFrameUnit.Minute),
+        "30minute": (30, TimeFrameUnit.Minute),
+        "hourly": (1, TimeFrameUnit.Hour),
+        "1hour": (1, TimeFrameUnit.Hour),
+        "daily": (1, TimeFrameUnit.Day),
+        "1day": (1, TimeFrameUnit.Day),
+        "weekly": (1, TimeFrameUnit.Week),
+        "1week": (1, TimeFrameUnit.Week),
+        "monthly": (1, TimeFrameUnit.Month),
+        "1month": (1, TimeFrameUnit.Month),
+    }
+    spec = known.get(frequency.lower())
+    if spec is not None:
+        return TimeFrame(*spec)
+    msg = f"Unsupported frequency for Alpaca: {frequency!r}"
+    raise DataValidationError("alpaca", msg, field="frequency", value=frequency)
+
+
+def _utc_range(start: str, end: str) -> tuple[datetime, datetime, datetime]:
+    """Return ``(start, exclusive API end, inclusive filter end)`` in UTC for a YYYY-MM-DD range."""
+    # Tolerate stray whitespace (``2024-12- 31``), as the module docstring promises.
+    first_day = datetime.strptime(_normalize_iso_date(start), "%Y-%m-%d").date()
+    end_day = datetime.strptime(_normalize_iso_date(end), "%Y-%m-%d").date()
+    start_utc = datetime.combine(first_day, dt_time.min, tzinfo=UTC)
+    end_excl = datetime.combine(end_day + timedelta(days=1), dt_time.min, tzinfo=UTC)
+    end_inclusive = datetime.combine(end_day, dt_time(23, 59, 59, 999999), tzinfo=UTC)
+    return start_utc, end_excl, end_inclusive
+
+
+def _bars_pandas_to_polars(
+    df_pd: pd.DataFrame,
+    symbol: str,
+    end_inclusive: datetime,
+) -> pl.DataFrame:
+    """Convert a single-symbol Alpaca ``BarSet.df`` into the canonical OHLCV Polars frame.
+
+    Rows with a timestamp after ``end_inclusive`` are dropped; every row is tagged
+    with ``symbol`` and the result is sorted by ``timestamp``.
+    """
+    price_cols = ["open", "high", "low", "close", "volume"]
+    index_names = df_pd.index.names or []
+    if isinstance(df_pd.index, pd.MultiIndex) and "symbol" in index_names:
+        df_pd = df_pd.droplevel("symbol")
+    flat = df_pd.reset_index()
+    # Fall back to the first column when no column is named "timestamp".
+    time_col = "timestamp" if "timestamp" in flat.columns else flat.columns[0]
+    frame = pl.from_pandas(flat[[time_col, *price_cols]]).rename({time_col: "timestamp"})
+    casts = [pl.col("timestamp").cast(pl.Datetime(time_zone="UTC"))]
+    casts += [pl.col(name).cast(pl.Float64) for name in price_cols]
+    in_range = pl.col("timestamp") <= pl.lit(end_inclusive)
+    return (
+        frame.with_columns(casts)
+        .filter(in_range)
+        .with_columns(pl.lit(symbol).alias("symbol"))
+        .select(["timestamp", "symbol", *price_cols])
+        .sort("timestamp")
+    )
+
+
+def _bars_pandas_to_polars_batch(df_pd: pd.DataFrame, end_inclusive: datetime) -> pl.DataFrame:
+    """Convert a multi-symbol Alpaca ``BarSet.df`` into long-format canonical Polars.
+
+    An empty input yields an empty frame with the canonical schema. Rows after
+    ``end_inclusive`` are dropped and output is sorted by ``symbol``, ``timestamp``.
+
+    Raises:
+        DataValidationError: If no ``symbol`` column exists after ``reset_index``.
+    """
+    float_cols = ["open", "high", "low", "close", "volume"]
+    ordered = ["timestamp", "symbol", *float_cols]
+    utc_ts = pl.Datetime(time_zone="UTC")
+
+    if df_pd.empty:
+        # Typed, row-less frame with the canonical column order.
+        empty_schema = {"timestamp": utc_ts, "symbol": pl.Utf8}
+        empty_schema.update({name: pl.Float64 for name in float_cols})
+        return pl.DataFrame(schema=empty_schema)
+
+    flat = df_pd.reset_index()
+    if "symbol" not in flat.columns:
+        raise DataValidationError(
+            "alpaca",
+            "Expected multi-symbol Alpaca bars to include a symbol column after reset_index",
+            field="bars.df",
+        )
+
+    casts = [
+        pl.col("timestamp").cast(utc_ts),
+        pl.col("symbol").cast(pl.Utf8),
+        *(pl.col(name).cast(pl.Float64) for name in float_cols),
+    ]
+    in_range = pl.col("timestamp") <= pl.lit(end_inclusive)
+    return (
+        pl.from_pandas(flat[ordered])
+        .with_columns(casts)
+        .filter(in_range)
+        .select(ordered)
+        .sort(["symbol", "timestamp"])
+    )
+
+
+class AlpacaProvider(BaseProvider):
+    """Alpaca historical bars; output matches :class:`BaseProvider` OHLCV contract.
+
+    The asset class (crypto / stock / option) is inferred from each symbol, and
+    one SDK client per asset class is created lazily on first use.
+
+    Install: ``pip install 'ml4t-data[alpaca]'``
+    """
+
+    FREQUENCY_MAP: ClassVar[dict[str, str]] = {
+        "minute": "1minute",
+        "1minute": "1minute",
+        "5minute": "5minute",
+        "15minute": "15minute",
+        "30minute": "30minute",
+        "hourly": "1hour",
+        "1hour": "1hour",
+        "daily": "1day",
+        "1day": "1day",
+        "weekly": "1week",
+        "1week": "1week",
+        "monthly": "1month",
+        "1month": "1month",
+    }
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        secret_key: str | None = None,
+        *,
+        feed: DataFeed | None = None,
+        adjustment: Adjustment | None = None,
+    ) -> None:
+        """Initialize Alpaca provider.
+
+        Args:
+            api_key: Alpaca API key; defaults to ``ALPACA_API_KEY``.
+            secret_key: Alpaca secret key; defaults to ``ALPACA_SECRET_KEY``.
+            feed: Stock bar feed (e.g. SIP). When omitted, ``ALPACA_STOCK_FEED`` is
+                used if it names a valid feed; otherwise no feed is sent.
+            adjustment: Corporate-action adjustment for stock bars; optional.
+        """
+        super().__init__(rate_limit=None)
+        self._api_key = api_key or os.getenv("ALPACA_API_KEY")
+        self._secret_key = secret_key or os.getenv("ALPACA_SECRET_KEY")
+        # An explicit ``feed`` wins; otherwise apply the documented env fallback.
+        self._feed = feed if feed is not None else _optional_stock_feed_from_env()
+        self._adjustment = adjustment
+        self._crypto: CryptoHistoricalDataClient | None = None
+        self._stock: StockHistoricalDataClient | None = None
+        self._option: OptionHistoricalDataClient | None = None
+
+    @property
+    def name(self) -> str:
+        """Return provider name."""
+        return "alpaca"
+
+    def _client(self, kind: Kind):
+        """Return the SDK client for ``kind``, creating it on first use.
+
+        Raises:
+            AuthenticationError: Stock/option client requested without both keys.
+        """
+        if kind == "crypto":
+            if self._crypto is None:
+                # NOTE(review): configured keys are not forwarded to this client.
+                self._crypto = CryptoHistoricalDataClient()
+            return self._crypto
+        if self._api_key is None or self._secret_key is None:
+            raise AuthenticationError(
+                "alpaca",
+                "api_key and secret_key are required for stock and option bars. "
+                "Set ALPACA_API_KEY and ALPACA_SECRET_KEY environment variables "
+                "or pass api_key and secret_key.",
+            )
+        if kind == "stock":
+            if self._stock is None:
+                self._stock = StockHistoricalDataClient(self._api_key, self._secret_key)
+            return self._stock
+        if self._option is None:
+            self._option = OptionHistoricalDataClient(self._api_key, self._secret_key)
+        return self._option
+
+    def _request_bars(self, client, kind: Kind, symbols, tf, start_utc, end_excl):
+        """Send the bars request matching ``kind`` and return the SDK response.
+
+        ``symbols`` may be one symbol or a list; ``feed`` / ``adjustment`` are
+        only attached to stock requests.
+        """
+        window = {
+            "symbol_or_symbols": symbols,
+            "timeframe": tf,
+            "start": start_utc,
+            "end": end_excl,
+        }
+        if kind == "crypto":
+            return client.get_crypto_bars(CryptoBarsRequest(**window))
+        if kind == "stock":
+            req = StockBarsRequest(**window, feed=self._feed, adjustment=self._adjustment)
+            return client.get_stock_bars(req)
+        return client.get_option_bars(OptionBarsRequest(**window))
+
+    def _fetch_and_transform_data(
+        self, symbol: str, start: str, end: str, frequency: str
+    ) -> pl.DataFrame:
+        """Fetch bars for one symbol and map them to the canonical OHLCV schema.
+
+        Raises:
+            DataValidationError: Unsupported frequency, or the request failed.
+            AuthenticationError: Stock/option symbol without API keys.
+            SymbolNotFoundError: No bars were returned for the range.
+        """
+        kind = _infer_kind(symbol)
+        tf = _timeframe_for(self.FREQUENCY_MAP.get(frequency.lower(), frequency))
+        client = self._client(kind)
+        start_utc, end_excl, end_inclusive = _utc_range(start, end)
+
+        self.logger.info(
+            "Fetching Alpaca bars",
+            symbol=symbol,
+            start=start,
+            end=end,
+            frequency=frequency,
+            kind=kind,
+        )
+
+        try:
+            bars = self._request_bars(client, kind, symbol, tf, start_utc, end_excl)
+        except Exception as e:
+            # Any request failure is surfaced as a provider validation error.
+            raise DataValidationError(
+                "alpaca",
+                f"Failed to fetch {symbol}: {e}",
+                details={"symbol": symbol, "error": str(e)},
+            ) from e
+
+        df_pd = bars.df
+        if df_pd.empty:
+            raise SymbolNotFoundError(
+                "alpaca",
+                symbol,
+                details={"start": start, "end": end, "frequency": frequency},
+            )
+
+        out = _bars_pandas_to_polars(df_pd, symbol, end_inclusive)
+        self.logger.info("Fetched Alpaca bars", symbol=symbol, rows=len(out))
+        return out
+
+    def fetch_batch_ohlcv(
+        self,
+        symbols: list[str],
+        start: str,
+        end: str,
+        frequency: str = "daily",
+        chunk_size: int = 50,
+        delay_seconds: float = 0.0,
+    ) -> pl.DataFrame:
+        """Fetch OHLCV for multiple symbols using Alpaca multi-symbol bar requests.
+
+        Symbols are grouped by asset class (crypto / stock / option). Each chunk
+        calls the API with ``symbol_or_symbols=[...]`` (same idea as Yahoo batch).
+        Chunks that fail or come back empty are logged and skipped, not raised.
+
+        Args:
+            symbols: Ticker list (e.g. ``[\"AAPL\", \"MSFT\"]`` or crypto ``BTC/USD``).
+            start: Start date ``YYYY-MM-DD``.
+            end: End date ``YYYY-MM-DD`` (inclusive).
+            frequency: Same names as :meth:`fetch_ohlcv`.
+            chunk_size: Max symbols per API request per asset class; must be >= 1.
+            delay_seconds: Pause between chunks (rate limiting).
+
+        Returns:
+            Long-format Polars frame, sorted by ``symbol``, ``timestamp``.
+
+        Raises:
+            DataValidationError: Unsupported frequency or non-positive ``chunk_size``.
+            AuthenticationError: Stock/option symbols requested without API keys.
+        """
+        if not symbols:
+            return self._create_empty_dataframe()
+        if chunk_size < 1:
+            # Previously 0 raised ZeroDivisionError and negatives fetched nothing.
+            raise DataValidationError(
+                "alpaca",
+                f"chunk_size must be a positive integer, got {chunk_size!r}",
+                field="chunk_size",
+                value=chunk_size,
+            )
+
+        tf = _timeframe_for(self.FREQUENCY_MAP.get(frequency.lower(), frequency))
+        start_utc, end_excl, end_inclusive = _utc_range(start, end)
+
+        by_kind: dict[Kind, list[str]] = defaultdict(list)
+        for s in symbols:
+            by_kind[_infer_kind(s)].append(s)
+
+        if ("stock" in by_kind or "option" in by_kind) and (
+            self._api_key is None or self._secret_key is None
+        ):
+            raise AuthenticationError(
+                "alpaca",
+                "api_key and secret_key are required when batch includes "
+                "stock or option symbols. Set ALPACA_API_KEY / ALPACA_SECRET_KEY "
+                "or pass keys to the constructor.",
+            )
+
+        self.logger.info(
+            "Starting Alpaca batch download",
+            total_symbols=len(symbols),
+            chunk_size=chunk_size,
+            start=start,
+            end=end,
+        )
+
+        all_parts: list[pl.DataFrame] = []
+        failed: list[str] = []
+
+        for kind in ("crypto", "stock", "option"):
+            syms = by_kind.get(kind)
+            if not syms:
+                continue
+            client = self._client(kind)
+            n_chunks = (len(syms) + chunk_size - 1) // chunk_size
+            for i, chunk in enumerate(_chunks(syms, chunk_size), 1):
+                try:
+                    bars = self._request_bars(client, kind, chunk, tf, start_utc, end_excl)
+                    df_pd = bars.df
+                    if df_pd.empty:
+                        self.logger.warning("Empty Alpaca batch chunk", chunk=i, symbols=chunk)
+                        failed.extend(chunk)
+                    else:
+                        all_parts.append(_bars_pandas_to_polars_batch(df_pd, end_inclusive))
+                except Exception as e:
+                    # Best effort: one bad chunk must not abort the whole batch.
+                    self.logger.error(
+                        "Alpaca batch chunk failed", chunk=i, symbols=chunk, error=str(e)
+                    )
+                    failed.extend(chunk)
+
+                if i < n_chunks and delay_seconds > 0:
+                    time.sleep(delay_seconds)
+
+        if not all_parts:
+            self.logger.error("No Alpaca batch data", failed_symbols=failed)
+            return self._create_empty_dataframe()
+
+        result = pl.concat(all_parts).sort(["symbol", "timestamp"])
+        self.logger.info(
+            "Alpaca batch download complete",
+            rows=len(result),
+            failed_symbols=len(failed),
+        )
+        if failed:
+            self.logger.warning("Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10])
+        return result
+
+    async def fetch_batch_ohlcv_async(
+        self,
+        symbols: list[str],
+        start: str,
+        end: str,
+        frequency: str = "daily",
+        chunk_size: int = 50,
+        delay_seconds: float = 0.0,
+    ) -> pl.DataFrame:
+        """Async batch fetch (sync Alpaca client via :func:`asyncio.to_thread`)."""
+        return await asyncio.to_thread(
+            self.fetch_batch_ohlcv,
+            symbols,
+            start,
+            end,
+            frequency,
+            chunk_size,
+            delay_seconds,
+        )
+
+    async def close_async(self) -> None:
+        """No network session to close (SDK owns HTTP)."""
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close_async()
+
+
+AlpacaHistoricalProvider = AlpacaProvider  # second public name bound to the same class
diff --git a/tests/test_alpaca_provider_unit.py b/tests/test_alpaca_provider_unit.py
new file mode 100644
index 0000000..b6b38ee
--- /dev/null
+++ b/tests/test_alpaca_provider_unit.py
@@ -0,0 +1,315 @@
+"""Tests for Alpaca OHLCV provider (requires ``alpaca-py`` / ``ml4t-data[alpaca]``)."""
+
+import asyncio
+from datetime import UTC, datetime
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import polars as pl
+import pytest
+
+pytest.importorskip("alpaca")
+
+from ml4t.data.core.exceptions import AuthenticationError, DataValidationError
+from ml4t.data.providers.alpaca import (
+ AlpacaHistoricalProvider,
+ AlpacaProvider,
+ _bars_pandas_to_polars,
+ _bars_pandas_to_polars_batch,
+ _infer_kind,
+ _timeframe_for,
+)
+
+
+class TestInferKind:
+    """Asset class is inferred from the shape of the symbol string."""
+
+    def test_crypto_slash(self):
+        for pair in ("BTC/USD", "ETH/USD"):
+            assert _infer_kind(pair) == "crypto"
+
+    def test_option_oc(self):
+        assert _infer_kind("AAPL241220C00300000") == "option"
+
+    def test_stock(self):
+        for ticker in ("AAPL", "MSFT"):
+            assert _infer_kind(ticker) == "stock"
+
+
+class TestFrequencyMap:
+    """FREQUENCY_MAP exposes the expected alias keys and normalized values."""
+
+    def test_keys_cover_aliases(self):
+        freq_map = AlpacaProvider().FREQUENCY_MAP
+        expected_aliases = [
+            "daily",
+            "1day",
+            "minute",
+            "5minute",
+            "hourly",
+            "monthly",
+        ]
+        missing = [alias for alias in expected_aliases if alias not in freq_map]
+        assert not missing
+
+    def test_map_normalizes_to_internal_names(self):
+        freq_map = AlpacaProvider().FREQUENCY_MAP
+        expected = {"daily": "1day", "minute": "1minute", "5minute": "5minute"}
+        assert {alias: freq_map[alias] for alias in expected} == expected
+
+
+class TestTimeframeFor:
+    """Frequency names resolve to the expected Alpaca ``TimeFrame``."""
+
+    def test_daily(self):
+        from alpaca.data.timeframe import TimeFrameUnit
+
+        daily = _timeframe_for("daily")
+        expected = (1, TimeFrameUnit.Day)
+        assert (daily.amount, daily.unit) == expected
+
+    def test_five_minute(self):
+        from alpaca.data.timeframe import TimeFrameUnit
+
+        five_min = _timeframe_for("5minute")
+        expected = (5, TimeFrameUnit.Minute)
+        assert (five_min.amount, five_min.unit) == expected
+
+    def test_invalid_frequency_raises_data_validation(self):
+        with pytest.raises(DataValidationError, match="Unsupported frequency"):
+            _timeframe_for("not_a_real_freq")
+
+
+class TestBarsPandasToPolars:
+    """Conversion of Alpaca-style pandas bar frames to Polars (single- and multi-symbol)."""
+
+    def test_multiindex_symbol_timestamp(self):
+        idx = pd.MultiIndex.from_tuples(
+            [("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC"))],
+            names=["symbol", "timestamp"],
+        )
+        df_pd = pd.DataFrame(
+            {"open": [1.0], "high": [2.0], "low": [0.5], "close": [1.5], "volume": [100.0]},
+            index=idx,
+        )
+        end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC)
+        out = _bars_pandas_to_polars(df_pd, "BTC/USD", end_inc)
+        assert len(out) == 1
+        assert out.columns == [
+            "timestamp",
+            "symbol",
+            "open",
+            "high",
+            "low",
+            "close",
+            "volume",
+        ]
+        assert out["symbol"][0] == "BTC/USD"
+        assert out["close"][0] == 1.5
+
+    def test_end_filter_excludes_bar_after_range(self):
+        idx = pd.MultiIndex.from_tuples(
+            [
+                ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")),
+                ("BTC/USD", pd.Timestamp("2024-01-05", tz="UTC")),
+            ],
+            names=["symbol", "timestamp"],
+        )
+        df_pd = pd.DataFrame(
+            {
+                "open": [1.0, 2.0],
+                "high": [2.0, 3.0],
+                "low": [0.5, 1.5],
+                "close": [1.5, 2.5],
+                "volume": [100.0, 200.0],
+            },
+            index=idx,
+        )
+        end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC)  # 2024-01-05 bar is later
+        out = _bars_pandas_to_polars(df_pd, "BTC/USD", end_inc)
+        assert len(out) == 1  # only the 2024-01-02 bar is on or before the cutoff
+
+    def test_batch_two_symbols(self):
+        idx = pd.MultiIndex.from_tuples(
+            [
+                ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")),
+                ("ETH/USD", pd.Timestamp("2024-01-02", tz="UTC")),
+            ],
+            names=["symbol", "timestamp"],
+        )
+        df_pd = pd.DataFrame(
+            {
+                "open": [1.0, 2.0],
+                "high": [2.0, 3.0],
+                "low": [0.5, 1.5],
+                "close": [1.5, 2.5],
+                "volume": [100.0, 200.0],
+            },
+            index=idx,
+        )
+        end_inc = datetime(2024, 1, 3, 23, 59, 59, 999999, tzinfo=UTC)
+        out = _bars_pandas_to_polars_batch(df_pd, end_inc)
+        assert len(out) == 2  # one row per symbol, both inside the range
+        assert set(out["symbol"].to_list()) == {"BTC/USD", "ETH/USD"}
+
+
+class TestCreateEmptyDataframe:
+    """The inherited empty-frame factory returns the canonical OHLCV columns."""
+
+    def test_empty_dataframe_columns(self):
+        expected = [
+            "timestamp",
+            "symbol",
+            "open",
+            "high",
+            "low",
+            "close",
+            "volume",
+        ]
+        empty = AlpacaProvider()._create_empty_dataframe()
+        assert list(empty.columns) == expected
+
+    def test_empty_dataframe_length(self):
+        empty = AlpacaProvider()._create_empty_dataframe()
+        assert len(empty) == 0
+
+
+class TestAuthentication:
+    """Stock/option requests need credentials; keys can come from the environment."""
+
+    def test_stock_without_keys_raises(self):
+        with patch.dict("os.environ", {}, clear=True):  # no ALPACA_* variables visible
+            p = AlpacaProvider()
+            with pytest.raises(AuthenticationError, match="api_key and secret_key"):
+                p.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-05", "daily")
+
+    def test_option_without_keys_raises(self):
+        with patch.dict("os.environ", {}, clear=True):  # no ALPACA_* variables visible
+            p = AlpacaProvider()
+            with pytest.raises(AuthenticationError, match="api_key and secret_key"):
+                p.fetch_ohlcv("AAPL241220C00300000", "2024-12-01", "2024-12-05", "daily")
+
+    def test_batch_stock_without_keys_raises(self):
+        with patch.dict("os.environ", {}, clear=True):  # no ALPACA_* variables visible
+            p = AlpacaProvider()
+            with pytest.raises(AuthenticationError, match="stock or option"):
+                p.fetch_batch_ohlcv(["AAPL", "MSFT"], "2024-01-01", "2024-01-05", "daily")
+
+    def test_init_reads_env_keys(self):
+        with patch.dict(
+            "os.environ",
+            {"ALPACA_API_KEY": "pk-x", "ALPACA_SECRET_KEY": "sec-y"},
+        ):
+            p = AlpacaProvider()
+        assert p._api_key == "pk-x"  # keys were captured while the patch was active
+        assert p._secret_key == "sec-y"
+
+
+class TestFetchCryptoMocked:
+    """Offline crypto fetch tests; the Alpaca crypto client class is replaced by a mock."""
+
+    def test_fetch_crypto_returns_canonical_schema(self):
+        idx = pd.MultiIndex.from_tuples(
+            [("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC"))],
+            names=["symbol", "timestamp"],
+        )
+        df_pd = pd.DataFrame(
+            {"open": [1.0], "high": [2.0], "low": [0.5], "close": [1.5], "volume": [100.0]},
+            index=idx,
+        )
+        barset = MagicMock()
+        barset.df = df_pd  # the provider only reads ``.df`` from the response
+
+        with patch("ml4t.data.providers.alpaca.CryptoHistoricalDataClient") as client_cls:
+            inst = MagicMock()
+            client_cls.return_value = inst
+            inst.get_crypto_bars.return_value = barset
+            p = AlpacaProvider()
+            out = p._fetch_and_transform_data("BTC/USD", "2024-01-01", "2024-01-03", "daily")
+        assert len(out) == 1
+        assert out["close"][0] == 1.5
+
+    def test_fetch_batch_crypto_two_symbols_mocked(self):
+        idx = pd.MultiIndex.from_tuples(
+            [
+                ("BTC/USD", pd.Timestamp("2024-01-02", tz="UTC")),
+                ("ETH/USD", pd.Timestamp("2024-01-02", tz="UTC")),
+            ],
+            names=["symbol", "timestamp"],
+        )
+        df_pd = pd.DataFrame(
+            {
+                "open": [1.0, 2.0],
+                "high": [2.0, 3.0],
+                "low": [0.5, 1.5],
+                "close": [1.5, 2.5],
+                "volume": [100.0, 200.0],
+            },
+            index=idx,
+        )
+        barset = MagicMock()
+        barset.df = df_pd  # the provider only reads ``.df`` from the response
+
+        with patch("ml4t.data.providers.alpaca.CryptoHistoricalDataClient") as client_cls:
+            inst = MagicMock()
+            client_cls.return_value = inst
+            inst.get_crypto_bars.return_value = barset
+            p = AlpacaProvider()
+            out = p.fetch_batch_ohlcv(
+                ["BTC/USD", "ETH/USD"], "2024-01-01", "2024-01-03", "daily", chunk_size=10
+            )
+        assert len(out) == 2
+        assert out.columns == [
+            "timestamp",
+            "symbol",
+            "open",
+            "high",
+            "low",
+            "close",
+            "volume",
+        ]
+
+    def test_fetch_batch_empty_symbols(self):
+        p = AlpacaProvider()
+        out = p.fetch_batch_ohlcv([], "2024-01-01", "2024-01-03")  # empty list: no client needed
+        assert len(out) == 0
+
+
+class TestAlpacaProviderInit:
+    """Provider ``name`` value and the ``AlpacaHistoricalProvider`` alias."""
+
+    def test_name(self):
+        assert AlpacaProvider().name == "alpaca"
+
+    def test_historical_alias(self):
+        assert AlpacaHistoricalProvider is AlpacaProvider  # same class object, not a subclass
+
+
+class TestInvalidFrequencyInFetch:
+    """An unsupported frequency raises DataValidationError, even when keys are missing."""
+
+    def test_fetch_bad_frequency(self):
+        p = AlpacaProvider()
+        with pytest.raises(DataValidationError, match="Unsupported frequency"):
+            p.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-03", "bad_freq_xyz")
+
+    def test_stock_bad_frequency_without_keys_is_validation_not_auth(self):
+        with patch.dict("os.environ", {}, clear=True):  # no ALPACA_* variables visible
+            p = AlpacaProvider()
+            with pytest.raises(DataValidationError, match="Unsupported frequency"):
+                p.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-03", "bad_freq_xyz")
+
+
+class TestFetchBatchAsync:
+    """The async batch wrapper delegates to the synchronous ``fetch_batch_ohlcv``."""
+
+    def test_fetch_batch_ohlcv_async_delegates(self):
+        p = AlpacaProvider()
+        with patch.object(p, "fetch_batch_ohlcv", return_value=pl.DataFrame()) as mock_fb:
+            out = asyncio.run(
+                p.fetch_batch_ohlcv_async(
+                    ["BTC/USD"], "2024-01-01", "2024-01-02", frequency="daily"
+                )
+            )
+        mock_fb.assert_called_once()  # the sync method was invoked exactly once
+        assert len(out) == 0
diff --git a/uv.lock b/uv.lock
index 77b62d1..cb13e31 100644
--- a/uv.lock
+++ b/uv.lock
@@ -139,6 +139,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
]
+[[package]]
+name = "alpaca-py"
+version = "0.43.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "msgpack" },
+ { name = "pandas" },
+ { name = "pydantic" },
+ { name = "pytz" },
+ { name = "requests" },
+ { name = "sseclient-py" },
+ { name = "websockets" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/2f/9d/3003f661c15b8003655c447c187aec10f0843647e5c98b391701b04ac3d8/alpaca_py-0.43.4.tar.gz", hash = "sha256:7d529b3654d4e817d9fd7ab461131c4f06a315c736b6a9e4a87d5406bb71114a", size = 97990, upload-time = "2026-04-29T08:41:48.775Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/3a/d5/1f57cc03e7b5925a927cb7f8e7ee5f873e22632633778d28d5d23681c871/alpaca_py-0.43.4-py3-none-any.whl", hash = "sha256:dd49ac30e0f2a8f38550ef1f27a58e7fd8f3f3875deaa4e757443cdbd033a1b4", size = 122534, upload-time = "2026-04-29T08:41:50.149Z" },
+]
+
[[package]]
name = "annotated-types"
version = "0.7.0"
@@ -1576,6 +1594,7 @@ dependencies = [
[package.optional-dependencies]
all = [
+ { name = "alpaca-py" },
{ name = "cot-reports" },
{ name = "databento" },
{ name = "hypothesis" },
@@ -1598,11 +1617,15 @@ all = [
{ name = "yfinance" },
]
all-providers = [
+ { name = "alpaca-py" },
{ name = "cot-reports" },
{ name = "databento" },
{ name = "oandapyv20" },
{ name = "yfinance" },
]
+alpaca = [
+ { name = "alpaca-py" },
+]
cot = [
{ name = "cot-reports" },
]
@@ -1639,6 +1662,7 @@ yahoo = [
[package.dev-dependencies]
dev = [
+ { name = "alpaca-py" },
{ name = "databento" },
{ name = "hypothesis" },
{ name = "oandapyv20" },
@@ -1658,6 +1682,9 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.0.0" },
+ { name = "alpaca-py", marker = "extra == 'all'", specifier = ">=0.30.0" },
+ { name = "alpaca-py", marker = "extra == 'all-providers'", specifier = ">=0.30.0" },
+ { name = "alpaca-py", marker = "extra == 'alpaca'", specifier = ">=0.30.0" },
{ name = "click", specifier = ">=8.0.0" },
{ name = "cot-reports", marker = "extra == 'all'", specifier = ">=0.1.0" },
{ name = "cot-reports", marker = "extra == 'all-providers'", specifier = ">=0.1.0" },
@@ -1723,10 +1750,11 @@ requires-dist = [
{ name = "yfinance", marker = "extra == 'all-providers'", specifier = ">=0.2.0" },
{ name = "yfinance", marker = "extra == 'yahoo'", specifier = ">=0.2.0" },
]
-provides-extras = ["all", "all-providers", "cot", "databento", "dev", "docs", "oanda", "yahoo"]
+provides-extras = ["all", "all-providers", "alpaca", "cot", "databento", "dev", "docs", "oanda", "yahoo"]
[package.metadata.requires-dev]
dev = [
+ { name = "alpaca-py", specifier = ">=0.30.0" },
{ name = "databento", specifier = ">=0.38.0" },
{ name = "hypothesis", specifier = ">=6.80.0" },
{ name = "oandapyv20", specifier = ">=0.7.0" },
@@ -1752,6 +1780,59 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" },
]
+[[package]]
+name = "msgpack"
+version = "1.1.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/4d/f2/bfb55a6236ed8725a96b0aa3acbd0ec17588e6a2c3b62a93eb513ed8783f/msgpack-1.1.2.tar.gz", hash = "sha256:3b60763c1373dd60f398488069bcdc703cd08a711477b5d480eecc9f9626f47e", size = 173581, upload-time = "2025-10-08T09:15:56.596Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/2c/97/560d11202bcd537abca693fd85d81cebe2107ba17301de42b01ac1677b69/msgpack-1.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2e86a607e558d22985d856948c12a3fa7b42efad264dca8a3ebbcfa2735d786c", size = 82271, upload-time = "2025-10-08T09:14:49.967Z" },
+ { url = "https://files.pythonhosted.org/packages/83/04/28a41024ccbd67467380b6fb440ae916c1e4f25e2cd4c63abe6835ac566e/msgpack-1.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:283ae72fc89da59aa004ba147e8fc2f766647b1251500182fac0350d8af299c0", size = 84914, upload-time = "2025-10-08T09:14:50.958Z" },
+ { url = "https://files.pythonhosted.org/packages/71/46/b817349db6886d79e57a966346cf0902a426375aadc1e8e7a86a75e22f19/msgpack-1.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:61c8aa3bd513d87c72ed0b37b53dd5c5a0f58f2ff9f26e1555d3bd7948fb7296", size = 416962, upload-time = "2025-10-08T09:14:51.997Z" },
+ { url = "https://files.pythonhosted.org/packages/da/e0/6cc2e852837cd6086fe7d8406af4294e66827a60a4cf60b86575a4a65ca8/msgpack-1.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:454e29e186285d2ebe65be34629fa0e8605202c60fbc7c4c650ccd41870896ef", size = 426183, upload-time = "2025-10-08T09:14:53.477Z" },
+ { url = "https://files.pythonhosted.org/packages/25/98/6a19f030b3d2ea906696cedd1eb251708e50a5891d0978b012cb6107234c/msgpack-1.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7bc8813f88417599564fafa59fd6f95be417179f76b40325b500b3c98409757c", size = 411454, upload-time = "2025-10-08T09:14:54.648Z" },
+ { url = "https://files.pythonhosted.org/packages/b7/cd/9098fcb6adb32187a70b7ecaabf6339da50553351558f37600e53a4a2a23/msgpack-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bafca952dc13907bdfdedfc6a5f579bf4f292bdd506fadb38389afa3ac5b208e", size = 422341, upload-time = "2025-10-08T09:14:56.328Z" },
+ { url = "https://files.pythonhosted.org/packages/e6/ae/270cecbcf36c1dc85ec086b33a51a4d7d08fc4f404bdbc15b582255d05ff/msgpack-1.1.2-cp311-cp311-win32.whl", hash = "sha256:602b6740e95ffc55bfb078172d279de3773d7b7db1f703b2f1323566b878b90e", size = 64747, upload-time = "2025-10-08T09:14:57.882Z" },
+ { url = "https://files.pythonhosted.org/packages/2a/79/309d0e637f6f37e83c711f547308b91af02b72d2326ddd860b966080ef29/msgpack-1.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:d198d275222dc54244bf3327eb8cbe00307d220241d9cec4d306d49a44e85f68", size = 71633, upload-time = "2025-10-08T09:14:59.177Z" },
+ { url = "https://files.pythonhosted.org/packages/73/4d/7c4e2b3d9b1106cd0aa6cb56cc57c6267f59fa8bfab7d91df5adc802c847/msgpack-1.1.2-cp311-cp311-win_arm64.whl", hash = "sha256:86f8136dfa5c116365a8a651a7d7484b65b13339731dd6faebb9a0242151c406", size = 64755, upload-time = "2025-10-08T09:15:00.48Z" },
+ { url = "https://files.pythonhosted.org/packages/ad/bd/8b0d01c756203fbab65d265859749860682ccd2a59594609aeec3a144efa/msgpack-1.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:70a0dff9d1f8da25179ffcf880e10cf1aad55fdb63cd59c9a49a1b82290062aa", size = 81939, upload-time = "2025-10-08T09:15:01.472Z" },
+ { url = "https://files.pythonhosted.org/packages/34/68/ba4f155f793a74c1483d4bdef136e1023f7bcba557f0db4ef3db3c665cf1/msgpack-1.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:446abdd8b94b55c800ac34b102dffd2f6aa0ce643c55dfc017ad89347db3dbdb", size = 85064, upload-time = "2025-10-08T09:15:03.764Z" },
+ { url = "https://files.pythonhosted.org/packages/f2/60/a064b0345fc36c4c3d2c743c82d9100c40388d77f0b48b2f04d6041dbec1/msgpack-1.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c63eea553c69ab05b6747901b97d620bb2a690633c77f23feb0c6a947a8a7b8f", size = 417131, upload-time = "2025-10-08T09:15:05.136Z" },
+ { url = "https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:372839311ccf6bdaf39b00b61288e0557916c3729529b301c52c2d88842add42", size = 427556, upload-time = "2025-10-08T09:15:06.837Z" },
+ { url = "https://files.pythonhosted.org/packages/f5/87/ffe21d1bf7d9991354ad93949286f643b2bb6ddbeab66373922b44c3b8cc/msgpack-1.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2929af52106ca73fcb28576218476ffbb531a036c2adbcf54a3664de124303e9", size = 404920, upload-time = "2025-10-08T09:15:08.179Z" },
+ { url = "https://files.pythonhosted.org/packages/ff/41/8543ed2b8604f7c0d89ce066f42007faac1eaa7d79a81555f206a5cdb889/msgpack-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:be52a8fc79e45b0364210eef5234a7cf8d330836d0a64dfbb878efa903d84620", size = 415013, upload-time = "2025-10-08T09:15:09.83Z" },
+ { url = "https://files.pythonhosted.org/packages/41/0d/2ddfaa8b7e1cee6c490d46cb0a39742b19e2481600a7a0e96537e9c22f43/msgpack-1.1.2-cp312-cp312-win32.whl", hash = "sha256:1fff3d825d7859ac888b0fbda39a42d59193543920eda9d9bea44d958a878029", size = 65096, upload-time = "2025-10-08T09:15:11.11Z" },
+ { url = "https://files.pythonhosted.org/packages/8c/ec/d431eb7941fb55a31dd6ca3404d41fbb52d99172df2e7707754488390910/msgpack-1.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:1de460f0403172cff81169a30b9a92b260cb809c4cb7e2fc79ae8d0510c78b6b", size = 72708, upload-time = "2025-10-08T09:15:12.554Z" },
+ { url = "https://files.pythonhosted.org/packages/c5/31/5b1a1f70eb0e87d1678e9624908f86317787b536060641d6798e3cf70ace/msgpack-1.1.2-cp312-cp312-win_arm64.whl", hash = "sha256:be5980f3ee0e6bd44f3a9e9dea01054f175b50c3e6cdb692bc9424c0bbb8bf69", size = 64119, upload-time = "2025-10-08T09:15:13.589Z" },
+ { url = "https://files.pythonhosted.org/packages/6b/31/b46518ecc604d7edf3a4f94cb3bf021fc62aa301f0cb849936968164ef23/msgpack-1.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4efd7b5979ccb539c221a4c4e16aac1a533efc97f3b759bb5a5ac9f6d10383bf", size = 81212, upload-time = "2025-10-08T09:15:14.552Z" },
+ { url = "https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:42eefe2c3e2af97ed470eec850facbe1b5ad1d6eacdbadc42ec98e7dcf68b4b7", size = 84315, upload-time = "2025-10-08T09:15:15.543Z" },
+ { url = "https://files.pythonhosted.org/packages/d3/68/93180dce57f684a61a88a45ed13047558ded2be46f03acb8dec6d7c513af/msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fdf7d83102bf09e7ce3357de96c59b627395352a4024f6e2458501f158bf999", size = 412721, upload-time = "2025-10-08T09:15:16.567Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fac4be746328f90caa3cd4bc67e6fe36ca2bf61d5c6eb6d895b6527e3f05071e", size = 424657, upload-time = "2025-10-08T09:15:17.825Z" },
+ { url = "https://files.pythonhosted.org/packages/38/f8/4398c46863b093252fe67368b44edc6c13b17f4e6b0e4929dbf0bdb13f23/msgpack-1.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:fffee09044073e69f2bad787071aeec727183e7580443dfeb8556cbf1978d162", size = 402668, upload-time = "2025-10-08T09:15:19.003Z" },
+ { url = "https://files.pythonhosted.org/packages/28/ce/698c1eff75626e4124b4d78e21cca0b4cc90043afb80a507626ea354ab52/msgpack-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5928604de9b032bc17f5099496417f113c45bc6bc21b5c6920caf34b3c428794", size = 419040, upload-time = "2025-10-08T09:15:20.183Z" },
+ { url = "https://files.pythonhosted.org/packages/67/32/f3cd1667028424fa7001d82e10ee35386eea1408b93d399b09fb0aa7875f/msgpack-1.1.2-cp313-cp313-win32.whl", hash = "sha256:a7787d353595c7c7e145e2331abf8b7ff1e6673a6b974ded96e6d4ec09f00c8c", size = 65037, upload-time = "2025-10-08T09:15:21.416Z" },
+ { url = "https://files.pythonhosted.org/packages/74/07/1ed8277f8653c40ebc65985180b007879f6a836c525b3885dcc6448ae6cb/msgpack-1.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:a465f0dceb8e13a487e54c07d04ae3ba131c7c5b95e2612596eafde1dccf64a9", size = 72631, upload-time = "2025-10-08T09:15:22.431Z" },
+ { url = "https://files.pythonhosted.org/packages/e5/db/0314e4e2db56ebcf450f277904ffd84a7988b9e5da8d0d61ab2d057df2b6/msgpack-1.1.2-cp313-cp313-win_arm64.whl", hash = "sha256:e69b39f8c0aa5ec24b57737ebee40be647035158f14ed4b40e6f150077e21a84", size = 64118, upload-time = "2025-10-08T09:15:23.402Z" },
+ { url = "https://files.pythonhosted.org/packages/22/71/201105712d0a2ff07b7873ed3c220292fb2ea5120603c00c4b634bcdafb3/msgpack-1.1.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e23ce8d5f7aa6ea6d2a2b326b4ba46c985dbb204523759984430db7114f8aa00", size = 81127, upload-time = "2025-10-08T09:15:24.408Z" },
+ { url = "https://files.pythonhosted.org/packages/1b/9f/38ff9e57a2eade7bf9dfee5eae17f39fc0e998658050279cbb14d97d36d9/msgpack-1.1.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:6c15b7d74c939ebe620dd8e559384be806204d73b4f9356320632d783d1f7939", size = 84981, upload-time = "2025-10-08T09:15:25.812Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/a9/3536e385167b88c2cc8f4424c49e28d49a6fc35206d4a8060f136e71f94c/msgpack-1.1.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e2cb7b9031568a2a5c73aa077180f93dd2e95b4f8d3b8e14a73ae94a9e667e", size = 411885, upload-time = "2025-10-08T09:15:27.22Z" },
+ { url = "https://files.pythonhosted.org/packages/2f/40/dc34d1a8d5f1e51fc64640b62b191684da52ca469da9cd74e84936ffa4a6/msgpack-1.1.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:180759d89a057eab503cf62eeec0aa61c4ea1200dee709f3a8e9397dbb3b6931", size = 419658, upload-time = "2025-10-08T09:15:28.4Z" },
+ { url = "https://files.pythonhosted.org/packages/3b/ef/2b92e286366500a09a67e03496ee8b8ba00562797a52f3c117aa2b29514b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:04fb995247a6e83830b62f0b07bf36540c213f6eac8e851166d8d86d83cbd014", size = 403290, upload-time = "2025-10-08T09:15:29.764Z" },
+ { url = "https://files.pythonhosted.org/packages/78/90/e0ea7990abea5764e4655b8177aa7c63cdfa89945b6e7641055800f6c16b/msgpack-1.1.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8e22ab046fa7ede9e36eeb4cfad44d46450f37bb05d5ec482b02868f451c95e2", size = 415234, upload-time = "2025-10-08T09:15:31.022Z" },
+ { url = "https://files.pythonhosted.org/packages/72/4e/9390aed5db983a2310818cd7d3ec0aecad45e1f7007e0cda79c79507bb0d/msgpack-1.1.2-cp314-cp314-win32.whl", hash = "sha256:80a0ff7d4abf5fecb995fcf235d4064b9a9a8a40a3ab80999e6ac1e30b702717", size = 66391, upload-time = "2025-10-08T09:15:32.265Z" },
+ { url = "https://files.pythonhosted.org/packages/6e/f1/abd09c2ae91228c5f3998dbd7f41353def9eac64253de3c8105efa2082f7/msgpack-1.1.2-cp314-cp314-win_amd64.whl", hash = "sha256:9ade919fac6a3e7260b7f64cea89df6bec59104987cbea34d34a2fa15d74310b", size = 73787, upload-time = "2025-10-08T09:15:33.219Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/b0/9d9f667ab48b16ad4115c1935d94023b82b3198064cb84a123e97f7466c1/msgpack-1.1.2-cp314-cp314-win_arm64.whl", hash = "sha256:59415c6076b1e30e563eb732e23b994a61c159cec44deaf584e5cc1dd662f2af", size = 66453, upload-time = "2025-10-08T09:15:34.225Z" },
+ { url = "https://files.pythonhosted.org/packages/16/67/93f80545eb1792b61a217fa7f06d5e5cb9e0055bed867f43e2b8e012e137/msgpack-1.1.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:897c478140877e5307760b0ea66e0932738879e7aa68144d9b78ea4c8302a84a", size = 85264, upload-time = "2025-10-08T09:15:35.61Z" },
+ { url = "https://files.pythonhosted.org/packages/87/1c/33c8a24959cf193966ef11a6f6a2995a65eb066bd681fd085afd519a57ce/msgpack-1.1.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a668204fa43e6d02f89dbe79a30b0d67238d9ec4c5bd8a940fc3a004a47b721b", size = 89076, upload-time = "2025-10-08T09:15:36.619Z" },
+ { url = "https://files.pythonhosted.org/packages/fc/6b/62e85ff7193663fbea5c0254ef32f0c77134b4059f8da89b958beb7696f3/msgpack-1.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5559d03930d3aa0f3aacb4c42c776af1a2ace2611871c84a75afe436695e6245", size = 435242, upload-time = "2025-10-08T09:15:37.647Z" },
+ { url = "https://files.pythonhosted.org/packages/c1/47/5c74ecb4cc277cf09f64e913947871682ffa82b3b93c8dad68083112f412/msgpack-1.1.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70c5a7a9fea7f036b716191c29047374c10721c389c21e9ffafad04df8c52c90", size = 432509, upload-time = "2025-10-08T09:15:38.794Z" },
+ { url = "https://files.pythonhosted.org/packages/24/a4/e98ccdb56dc4e98c929a3f150de1799831c0a800583cde9fa022fa90602d/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f2cb069d8b981abc72b41aea1c580ce92d57c673ec61af4c500153a626cb9e20", size = 415957, upload-time = "2025-10-08T09:15:40.238Z" },
+ { url = "https://files.pythonhosted.org/packages/da/28/6951f7fb67bc0a4e184a6b38ab71a92d9ba58080b27a77d3e2fb0be5998f/msgpack-1.1.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d62ce1f483f355f61adb5433ebfd8868c5f078d1a52d042b0a998682b4fa8c27", size = 422910, upload-time = "2025-10-08T09:15:41.505Z" },
+ { url = "https://files.pythonhosted.org/packages/f0/03/42106dcded51f0a0b5284d3ce30a671e7bd3f7318d122b2ead66ad289fed/msgpack-1.1.2-cp314-cp314t-win32.whl", hash = "sha256:1d1418482b1ee984625d88aa9585db570180c286d942da463533b238b98b812b", size = 75197, upload-time = "2025-10-08T09:15:42.954Z" },
+ { url = "https://files.pythonhosted.org/packages/15/86/d0071e94987f8db59d4eeb386ddc64d0bb9b10820a8d82bcd3e53eeb2da6/msgpack-1.1.2-cp314-cp314t-win_amd64.whl", hash = "sha256:5a46bf7e831d09470ad92dff02b8b1ac92175ca36b087f904a0519857c6be3ff", size = 85772, upload-time = "2025-10-08T09:15:43.954Z" },
+ { url = "https://files.pythonhosted.org/packages/81/f2/08ace4142eb281c12701fc3b93a10795e4d4dc7f753911d836675050f886/msgpack-1.1.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d99ef64f349d5ec3293688e91486c5fdb925ed03807f64d98d205d2713c60b46", size = 70868, upload-time = "2025-10-08T09:15:44.959Z" },
+]
+
[[package]]
name = "multidict"
version = "6.7.0"
@@ -2946,6 +3027,14 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/48/f3/b67d6ea49ca9154453b6d70b34ea22f3996b9fa55da105a79d8732227adc/soupsieve-2.8.1-py3-none-any.whl", hash = "sha256:a11fe2a6f3d76ab3cf2de04eb339c1be5b506a8a47f2ceb6d139803177f85434", size = 36710, upload-time = "2025-12-18T13:50:33.267Z" },
]
+[[package]]
+name = "sseclient-py"
+version = "1.9.0"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4d/2e/59920f7d66b7f9932a3d83dd0ec53fab001be1e058bf582606fe414a5198/sseclient_py-1.9.0-py3-none-any.whl", hash = "sha256:340062b1587fc2880892811e2ab5b176d98ef3eee98b3672ff3a3ba1e8ed0f6f", size = 8351, upload-time = "2026-01-02T23:39:30.995Z" },
+]
+
[[package]]
name = "stack-data"
version = "0.6.3"
From c5dc657ed1754477829c5f845cb503fc5856a301 Mon Sep 17 00:00:00 2001
From: Alessandro
Date: Thu, 30 Apr 2026 18:25:32 +0200
Subject: [PATCH 2/3] alpaca: normalize date input, read stock feed from env, add integration tests
---
src/ml4t/data/providers/alpaca.py | 31 ++++++++-----
tests/integration/test_alpaca.py | 70 ++++++++++++++++++++++++++++++
tests/test_alpaca_provider_unit.py | 47 ++++++++++++++++++++
3 files changed, 138 insertions(+), 10 deletions(-)
create mode 100644 tests/integration/test_alpaca.py
diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py
index 2a26c2b..785b39e 100644
--- a/src/ml4t/data/providers/alpaca.py
+++ b/src/ml4t/data/providers/alpaca.py
@@ -123,10 +123,20 @@ def _timeframe_for(frequency: str) -> TimeFrame:
def _utc_range(start: str, end: str) -> tuple[datetime, datetime, datetime]:
"""Parse YYYY-MM-DD range to Alpaca UTC window (end exclusive for API, inclusive filter)."""
- start_utc = datetime.combine(
- datetime.strptime(start, "%Y-%m-%d").date(), dt_time.min, tzinfo=UTC
- )
- end_day = datetime.strptime(end, "%Y-%m-%d").date()
+ start_n = _normalize_iso_date(start)
+ end_n = _normalize_iso_date(end)
+ try:
+ start_utc = datetime.combine(
+ datetime.strptime(start_n, "%Y-%m-%d").date(), dt_time.min, tzinfo=UTC
+ )
+ end_day = datetime.strptime(end_n, "%Y-%m-%d").date()
+ except ValueError as e:
+ raise DataValidationError(
+ "alpaca",
+ f"Invalid start/end date (expected YYYY-MM-DD): {e}",
+ field="start/end",
+ value=f"{start!r} / {end!r}",
+ ) from e
end_excl = datetime.combine(end_day + timedelta(days=1), dt_time.min, tzinfo=UTC)
end_inclusive = datetime.combine(end_day, dt_time(23, 59, 59, 999999), tzinfo=UTC)
return start_utc, end_excl, end_inclusive
@@ -183,9 +193,7 @@ def _bars_pandas_to_polars_batch(df_pd: pd.DataFrame, end_inclusive: datetime) -
"Expected multi-symbol Alpaca bars to include a symbol column after reset_index",
field="bars.df",
)
- out = pl.from_pandas(
- df_pd[["timestamp", "symbol", "open", "high", "low", "close", "volume"]]
- )
+ out = pl.from_pandas(df_pd[["timestamp", "symbol", "open", "high", "low", "close", "volume"]])
return (
out.with_columns(
[
@@ -239,13 +247,14 @@ def __init__(
Args:
api_key: Alpaca API key; defaults to ``ALPACA_API_KEY``.
secret_key: Alpaca secret key; defaults to ``ALPACA_SECRET_KEY``.
- feed: Stock bar feed (e.g. SIP); optional.
+ feed: Stock bar feed (e.g. SIP); optional. If omitted, uses ``ALPACA_STOCK_FEED``
+ when set, otherwise lets the SDK default apply.
adjustment: Corporate-action adjustment for stock bars; optional.
"""
super().__init__(rate_limit=None)
self._api_key = api_key or os.getenv("ALPACA_API_KEY")
self._secret_key = secret_key or os.getenv("ALPACA_SECRET_KEY")
- self._feed = feed
+ self._feed = feed if feed is not None else _optional_stock_feed_from_env()
self._adjustment = adjustment
self._crypto: CryptoHistoricalDataClient | None = None
self._stock: StockHistoricalDataClient | None = None
@@ -456,7 +465,9 @@ def fetch_batch_ohlcv(
failed_symbols=len(failed),
)
if failed:
- self.logger.warning("Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10])
+ self.logger.warning(
+ "Some Alpaca batch symbols failed", count=len(failed), sample=failed[:10]
+ )
return result
async def fetch_batch_ohlcv_async(
diff --git a/tests/integration/test_alpaca.py b/tests/integration/test_alpaca.py
new file mode 100644
index 0000000..e28384e
--- /dev/null
+++ b/tests/integration/test_alpaca.py
@@ -0,0 +1,70 @@
+"""Live Alpaca API checks when ``ALPACA_API_KEY`` and ``ALPACA_SECRET_KEY`` are set."""
+
+from __future__ import annotations
+
+import os
+
+import polars as pl
+import pytest
+
+pytest.importorskip("alpaca")
+
+from ml4t.data.providers.alpaca import AlpacaProvider
+
+pytestmark = pytest.mark.integration
+
+
+@pytest.fixture
+def alpaca_keys():
+ """Skip unless both Alpaca credentials are present."""
+ key = os.getenv("ALPACA_API_KEY")
+ secret = os.getenv("ALPACA_SECRET_KEY")
+ if not key or not secret:
+ pytest.skip("ALPACA_API_KEY and ALPACA_SECRET_KEY not both set")
+ return key, secret
+
+
+class TestAlpacaLiveOhlcv:
+ """Minimal real calls to validate batch and single-symbol paths."""
+
+ def test_fetch_batch_ohlcv_stocks_5minute(self, alpaca_keys):
+ """Short window + two symbols to limit payload and rate impact."""
+ p = AlpacaProvider()
+ df = p.fetch_batch_ohlcv(
+ ["AAPL", "MSFT"],
+ start="2024-01-02",
+ end="2024-01-04",
+ frequency="5minute",
+ )
+ assert isinstance(df, pl.DataFrame)
+ assert not df.is_empty()
+ assert set(df.columns) == {
+ "timestamp",
+ "symbol",
+ "open",
+ "high",
+ "low",
+ "close",
+ "volume",
+ }
+ syms = set(df["symbol"].unique().to_list())
+ assert syms <= {"AAPL", "MSFT"}
+ assert len(syms) >= 1
+
+ def test_fetch_ohlcv_single_daily(self, alpaca_keys):
+ p = AlpacaProvider()
+ df = p.fetch_ohlcv("AAPL", "2024-01-02", "2024-01-10", frequency="daily")
+ assert isinstance(df, pl.DataFrame)
+ assert not df.is_empty()
+ assert df["symbol"].unique().to_list() == ["AAPL"]
+
+ def test_batch_tolerates_whitespace_in_end_date(self, alpaca_keys):
+ """Regression: ``2024-12- 31`` style typos should not raise ValueError."""
+ p = AlpacaProvider()
+ df = p.fetch_batch_ohlcv(
+ ["AAPL"],
+ start="2024-12-30",
+ end="2024-12- 31",
+ frequency="daily",
+ )
+ assert len(df) >= 1
diff --git a/tests/test_alpaca_provider_unit.py b/tests/test_alpaca_provider_unit.py
index b6b38ee..d83c2ca 100644
--- a/tests/test_alpaca_provider_unit.py
+++ b/tests/test_alpaca_provider_unit.py
@@ -17,7 +17,10 @@
_bars_pandas_to_polars,
_bars_pandas_to_polars_batch,
_infer_kind,
+ _normalize_iso_date,
+ _optional_stock_feed_from_env,
_timeframe_for,
+ _utc_range,
)
@@ -58,6 +61,50 @@ def test_map_normalizes_to_internal_names(self):
assert p.FREQUENCY_MAP["5minute"] == "5minute"
+class TestNormalizeIsoDate:
+ """Whitespace-tolerant YYYY-MM-DD inputs (common copy-paste typo)."""
+
+ def test_collapses_internal_space(self):
+ assert _normalize_iso_date("2024-12- 31") == "2024-12-31"
+
+ def test_strips_edges(self):
+ assert _normalize_iso_date(" 2024-01-01 ") == "2024-01-01"
+
+
+class TestUtcRange:
+ """UTC window parsing for Alpaca requests."""
+
+ def test_accepts_date_with_stray_whitespace(self):
+ start_utc, end_excl, end_inc = _utc_range("2024-01-01", "2024-12- 31")
+ assert start_utc.year == 2024 and start_utc.month == 1
+ assert end_inc.month == 12 and end_inc.day == 31
+
+ def test_invalid_date_raises_data_validation(self):
+ with pytest.raises(DataValidationError, match="Invalid start/end date"):
+ _utc_range("2024-13-01", "2024-12-31")
+
+
+class TestStockFeedFromEnv:
+ """ALPACA_STOCK_FEED mirrors metafreq / Alpaca tier defaults."""
+
+ def test_unset_returns_none(self):
+ with patch.dict("os.environ", {}, clear=True):
+ assert _optional_stock_feed_from_env() is None
+
+ def test_iex_parsed(self):
+ with patch.dict("os.environ", {"ALPACA_STOCK_FEED": "iex"}):
+ from alpaca.data.enums import DataFeed
+
+ assert _optional_stock_feed_from_env() == DataFeed.IEX
+
+ def test_init_picks_feed_from_env_when_param_omitted(self):
+ with patch.dict("os.environ", {"ALPACA_STOCK_FEED": "iex"}):
+ from alpaca.data.enums import DataFeed
+
+ p = AlpacaProvider()
+ assert p._feed == DataFeed.IEX
+
+
class TestTimeframeFor:
"""Tests for Alpaca TimeFrame mapping."""
From 7a48023f69a21e364a2f8518bfd7752d58f9f5c5 Mon Sep 17 00:00:00 2001
From: Alessandro
Date: Sat, 2 May 2026 21:06:57 +0200
Subject: [PATCH 3/3] register alpaca in provider manager and update registration tests
---
src/ml4t/data/managers/provider_manager.py | 10 ++++++++++
tests/test_provider_registration.py | 11 +++++++----
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/src/ml4t/data/managers/provider_manager.py b/src/ml4t/data/managers/provider_manager.py
index 1178d5a..e24183d 100644
--- a/src/ml4t/data/managers/provider_manager.py
+++ b/src/ml4t/data/managers/provider_manager.py
@@ -215,6 +215,13 @@ def _get_provider_classes(cls) -> dict[str, type]:
except ImportError:
pass
+ try:
+ from ml4t.data.providers.alpaca import AlpacaProvider
+
+ provider_classes["alpaca"] = AlpacaProvider
+ except ImportError:
+ pass
+
cls._PROVIDER_CLASSES = provider_classes
return provider_classes
@@ -248,6 +255,9 @@ def _detect_available_providers(self) -> None:
if free_provider not in self._available_providers:
self._available_providers.append(free_provider)
+ if "alpaca" in self._provider_classes and "alpaca" not in self._available_providers:
+ self._available_providers.append("alpaca")
+
@property
def available_providers(self) -> list[str]:
"""Get list of available provider names."""
diff --git a/tests/test_provider_registration.py b/tests/test_provider_registration.py
index d76b0c8..34ae456 100644
--- a/tests/test_provider_registration.py
+++ b/tests/test_provider_registration.py
@@ -1,6 +1,7 @@
"""Test provider registration in DataManager."""
from ml4t.data.data_manager import DataManager
+from ml4t.data.providers.alpaca import AlpacaProvider
from ml4t.data.providers.binance import BinanceProvider
from ml4t.data.providers.binance_public import BinancePublicProvider
from ml4t.data.providers.cryptocompare import CryptoCompareProvider
@@ -13,7 +14,7 @@
def test_all_providers_registered():
- """Test that all 9 OHLCV providers are registered in DataManager.
+ """Test that all 10 OHLCV providers are registered in DataManager.
Note: Specialized providers (factor, prediction market) are not in DataManager.
- Factor providers: aqr, fama_french (standalone, different API)
@@ -30,10 +31,11 @@ def test_all_providers_registered():
"okx": OKXProvider,
"synthetic": SyntheticProvider,
"yahoo": YahooFinanceProvider,
+ "alpaca": AlpacaProvider,
}
assert expected_providers == DataManager.PROVIDER_CLASSES
- assert len(DataManager.PROVIDER_CLASSES) == 9
+ assert len(DataManager.PROVIDER_CLASSES) == len(expected_providers)
def test_provider_imports_work():
@@ -48,6 +50,7 @@ def test_provider_imports_work():
OandaProvider,
SyntheticProvider,
YahooFinanceProvider,
+ AlpacaProvider,
]
for provider_class in providers:
@@ -98,11 +101,10 @@ def test_provider_count():
- Prediction markets: kalshi, polymarket
- Historical: wiki_prices
"""
- assert len(DataManager.PROVIDER_CLASSES) == 9, (
+ assert len(DataManager.PROVIDER_CLASSES) == 10, (
- f"Expected 9 providers, got {len(DataManager.PROVIDER_CLASSES)}"
+ f"Expected 10 providers, got {len(DataManager.PROVIDER_CLASSES)}"
)
- # List all provider names for clarity
provider_names = list(DataManager.PROVIDER_CLASSES.keys())
assert "binance" in provider_names
assert "binance_public" in provider_names
@@ -113,6 +115,7 @@ def test_provider_count():
assert "okx" in provider_names
assert "synthetic" in provider_names
assert "yahoo" in provider_names
+ assert "alpaca" in provider_names
if __name__ == "__main__":