From 9e9e8b589cafb615dbe22fb64b7a85baf5de4028 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 14:05:04 +0000 Subject: [PATCH 01/13] feat: scaffold Alpaca data provider (SG-1) - Add AlpacaDataProvider with two-header auth on sync and async sessions - Map daily/hourly/minute frequencies and aliases to Alpaca timeframes - Declare crypto/intraday capabilities and conform to OHLCVProvider --- src/ml4t/data/providers/alpaca.py | 197 ++++++++++++++++++++++++++++++ tests/test_alpaca_provider.py | 176 ++++++++++++++++++++++++++ 2 files changed, 373 insertions(+) create mode 100644 src/ml4t/data/providers/alpaca.py create mode 100644 tests/test_alpaca_provider.py diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py new file mode 100644 index 0000000..e870383 --- /dev/null +++ b/src/ml4t/data/providers/alpaca.py @@ -0,0 +1,197 @@ +"""Alpaca data provider. + +Alpaca provides long-history, high-frequency market data across multiple asset +classes including equities and crypto, served over a REST API. + +API Documentation: https://docs.alpaca.markets/ + +Authentication: +- Two credentials are required: an API key id and an API secret, sent as the + ``APCA-API-KEY-ID`` / ``APCA-API-SECRET-KEY`` request headers. +- Set ``ALPACA_API_KEY`` / ``ALPACA_API_SECRET`` (primary, project convention) + or pass ``api_key`` / ``api_secret`` to the constructor. Alpaca's own SDK/CLI + names ``APCA_API_KEY_ID`` / ``APCA_API_SECRET_KEY`` are accepted as a fallback. + +Feed selection: +- ``feed`` defaults to ``"iex"`` (free tier, ~15-min delayed, thinner coverage). + Paid subscribers pass ``feed="sip"`` once at construction. + +Rate limiting: +- ``DEFAULT_RATE_LIMIT`` is a conservative client-side throttle of 200 calls per + minute reflecting the commonly-cited Basic (free) plan figure. 
The API itself + enforces limits via HTTP 429 and rate-limit response headers rather than a + fixed documented number, so this default is tentative and overridable via the + ``rate_limit`` constructor argument. + +Example: + >>> from ml4t.data.providers.alpaca import AlpacaDataProvider + >>> provider = AlpacaDataProvider(api_key="k", api_secret="s") + >>> provider.close() +""" + +from __future__ import annotations + +import os +from typing import Any, ClassVar + +import structlog + +from ml4t.data.core.exceptions import AuthenticationError, DataValidationError +from ml4t.data.providers.base import BaseProvider +from ml4t.data.providers.mixins import AsyncSessionMixin +from ml4t.data.providers.protocols import ProviderCapabilities + +logger = structlog.get_logger() + + +class AlpacaDataProvider(AsyncSessionMixin, BaseProvider): + """Alpaca market data provider. + + Supports equities and crypto with daily, hourly, and minute OHLCV bars over + Alpaca's historical REST API. Authentication uses two header credentials that + are wired onto both the sync and async HTTP sessions. + + Supports both sync and async operations: + # Sync + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + # Async + async with AlpacaDataProvider(api_key="k", api_secret="s") as provider: + ... + """ + + # 200 requests/min — Alpaca Basic (free) plan, per docs.alpaca.markets + # "About Market Data API" and Alpaca support (usage-limit-api-calls). + # Tentative: the API enforces via HTTP 429 + rate-limit headers, not a fixed + # documented number — verify from response headers and adjust if needed. 
+ DEFAULT_RATE_LIMIT: ClassVar[tuple[int, float]] = (200, 60.0) + + # Map canonical frequency keys and aliases to Alpaca timeframe strings + FREQUENCY_MAP: ClassVar[dict[str, str]] = { + "daily": "1Day", + "day": "1Day", + "1d": "1Day", + "1day": "1Day", + "hourly": "1Hour", + "hour": "1Hour", + "1h": "1Hour", + "1hour": "1Hour", + "minute": "1Min", + "1m": "1Min", + "1minute": "1Min", + } + + def __init__( + self, + api_key: str | None = None, + api_secret: str | None = None, + feed: str = "iex", + rate_limit: tuple[int, float] | None = None, + **kwargs: Any, + ) -> None: + """Initialize the Alpaca data provider. + + Args: + api_key: Alpaca API key id (or set ALPACA_API_KEY / APCA_API_KEY_ID). + api_secret: Alpaca API secret (or set ALPACA_API_SECRET / + APCA_API_SECRET_KEY). + feed: Data feed, "iex" (free) or "sip" (paid). + rate_limit: Optional custom (calls, period_seconds) override. + **kwargs: Additional arguments forwarded to BaseProvider. + + Raises: + AuthenticationError: If either credential is missing. + """ + self.api_key = api_key or os.getenv("ALPACA_API_KEY") or os.getenv("APCA_API_KEY_ID") + self.api_secret = ( + api_secret or os.getenv("ALPACA_API_SECRET") or os.getenv("APCA_API_SECRET_KEY") + ) + if not self.api_key or not self.api_secret: + raise AuthenticationError( + provider="alpaca", + message="Alpaca API key and secret required. Set ALPACA_API_KEY " + "and ALPACA_API_SECRET environment variables or pass api_key and " + "api_secret parameters. Get a free key at: https://alpaca.markets/", + ) + + self.feed = feed + self.base_url = "https://data.alpaca.markets" + self.crypto_base_url = "https://data.alpaca.markets" + + # Built once, reused for both the sync and async sessions so the two + # credentials accompany every request on either transport. 
+ self._auth_headers = { + "APCA-API-KEY-ID": self.api_key, + "APCA-API-SECRET-KEY": self.api_secret, + } + + super().__init__( + rate_limit=rate_limit or self.DEFAULT_RATE_LIMIT, + session_config={"headers": self._auth_headers}, + **kwargs, + ) + + self.logger.info( + "Initialized Alpaca provider", + feed=feed, + rate_limit=rate_limit or self.DEFAULT_RATE_LIMIT, + ) + + @property + def name(self) -> str: + """Return provider name.""" + return "alpaca" + + async def init_async_session( + self, + headers: dict[str, str] | None = None, + **kwargs: Any, + ) -> None: + """Initialize the async session, defaulting to the auth headers. + + The async session must carry the same two credentials as the sync + session, so the auth headers are applied unless a caller supplies its own. + + Args: + headers: Default headers for all requests; falls back to the auth + headers when not provided. + **kwargs: Additional arguments forwarded to AsyncSessionMixin. + """ + await super().init_async_session(headers=headers or self._auth_headers, **kwargs) + + def capabilities(self) -> ProviderCapabilities: + """Return provider capabilities. + + Returns: + Capabilities advertising crypto, intraday, required authentication, + and the client-side rate limit. + """ + return ProviderCapabilities( + supports_intraday=True, + supports_crypto=True, + requires_api_key=True, + rate_limit=self.DEFAULT_RATE_LIMIT, + ) + + def _map_frequency(self, frequency: str) -> str: + """Map a canonical frequency to an Alpaca timeframe string. + + Args: + frequency: Canonical frequency key or alias (e.g. "daily", "1h"). + + Returns: + The Alpaca timeframe string (e.g. "1Day", "1Hour", "1Min"). + + Raises: + DataValidationError: If the frequency is not supported. + """ + try: + return self.FREQUENCY_MAP[frequency.lower()] + except KeyError as err: + raise DataValidationError( + provider="alpaca", + message=f"Unsupported frequency '{frequency}'. 
" + f"Supported: {list(self.FREQUENCY_MAP.keys())}", + field="frequency", + value=frequency, + ) from err diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py new file mode 100644 index 0000000..512ebd8 --- /dev/null +++ b/tests/test_alpaca_provider.py @@ -0,0 +1,176 @@ +"""Tests for Alpaca data provider module.""" + +from unittest.mock import patch + +import pytest + +from ml4t.data.core.exceptions import AuthenticationError, DataValidationError +from ml4t.data.providers.alpaca import AlpacaDataProvider +from ml4t.data.providers.protocols import OHLCVProvider + + +class TestAlpacaProviderInit: + """Tests for provider construction and authentication.""" + + def test_init_with_explicit_keys(self): + """Explicit api_key/api_secret are stored on the provider.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + assert provider.api_key == "k" + assert provider.api_secret == "s" + assert provider.feed == "iex" + + def test_init_with_env_keys(self): + """Credentials are read from ALPACA_* environment variables.""" + with patch.dict( + "os.environ", + {"ALPACA_API_KEY": "k", "ALPACA_API_SECRET": "s"}, + ): + provider = AlpacaDataProvider() + + assert provider.api_key == "k" + assert provider.api_secret == "s" + + def test_init_with_apca_env_fallback(self): + """Alpaca SDK/CLI env names are honored as a fallback.""" + with patch.dict( + "os.environ", + {"APCA_API_KEY_ID": "k", "APCA_API_SECRET_KEY": "s"}, + clear=True, + ): + provider = AlpacaDataProvider() + + assert provider.api_key == "k" + assert provider.api_secret == "s" + + def test_init_missing_key_raises(self): + """A missing API key raises AuthenticationError.""" + with patch.dict("os.environ", {}, clear=True): + with pytest.raises(AuthenticationError): + AlpacaDataProvider(api_secret="s") + + def test_init_missing_secret_raises(self): + """A missing API secret raises AuthenticationError.""" + with patch.dict("os.environ", {}, clear=True): + with 
pytest.raises(AuthenticationError): + AlpacaDataProvider(api_key="k") + + def test_default_feed_is_iex(self): + """Feed defaults to IEX (free tier).""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + assert provider.feed == "iex" + + def test_feed_sip_honored(self): + """An explicit SIP feed is honored.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s", feed="sip") + + assert provider.feed == "sip" + + +class TestAuthHeaders: + """Tests for auth header wiring on both sessions.""" + + def test_auth_headers_on_sync_session(self): + """Both auth headers are present on the sync httpx.Client.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + assert provider.session.headers["APCA-API-KEY-ID"] == "k" + assert provider.session.headers["APCA-API-SECRET-KEY"] == "s" + + @pytest.mark.asyncio + async def test_auth_headers_on_async_session(self): + """Both auth headers are present on the async httpx.AsyncClient.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + await provider.init_async_session() + try: + assert provider.async_session is not None + assert provider.async_session.headers["APCA-API-KEY-ID"] == "k" + assert provider.async_session.headers["APCA-API-SECRET-KEY"] == "s" + finally: + await provider.close_async_session() + + +class TestNameProperty: + """Tests for the name property.""" + + def test_name_property(self): + """Name property returns 'alpaca'.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + assert provider.name == "alpaca" + + +class TestFrequencyMapping: + """Tests for frequency mapping to Alpaca timeframes.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + @pytest.mark.parametrize( + ("frequency", "expected"), + [ + ("daily", "1Day"), + ("day", "1Day"), + ("1d", "1Day"), + ("1day", "1Day"), + ("hourly", "1Hour"), + ("hour", "1Hour"), + ("1h", "1Hour"), + ("1hour", "1Hour"), + 
("minute", "1Min"), + ("1m", "1Min"), + ("1minute", "1Min"), + ], + ) + def test_frequency_maps_to_timeframe(self, provider, frequency, expected): + """Canonical keys and aliases map to the right Alpaca timeframe.""" + assert provider._map_frequency(frequency) == expected + + def test_frequency_mapping_is_case_insensitive(self, provider): + """Frequency lookup is case-insensitive.""" + assert provider._map_frequency("DAILY") == "1Day" + + def test_unsupported_frequency_raises(self, provider): + """Unsupported frequency raises DataValidationError listing supported keys.""" + with pytest.raises(DataValidationError) as exc_info: + provider._map_frequency("yearly") + + message = str(exc_info.value) + assert "daily" in message + assert "minute" in message + + +class TestCapabilities: + """Tests for the capabilities declaration.""" + + def test_capabilities(self): + """Capabilities report crypto, intraday, api-key, and a rate limit.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + caps = provider.capabilities() + + assert caps.supports_crypto is True + assert caps.supports_intraday is True + assert caps.requires_api_key is True + assert caps.rate_limit == AlpacaDataProvider.DEFAULT_RATE_LIMIT + + +class TestProtocolConformance: + """Tests for OHLCVProvider protocol conformance.""" + + def test_is_ohlcv_provider(self): + """An instance satisfies the OHLCVProvider protocol.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s") + + assert isinstance(provider, OHLCVProvider) + + +class TestDefaultRateLimit: + """Tests for the default rate-limit class variable.""" + + def test_default_rate_limit(self): + """DEFAULT_RATE_LIMIT reflects the documented Basic-plan figure.""" + assert AlpacaDataProvider.DEFAULT_RATE_LIMIT == (200, 60.0) From bb115789bbbee67b0052232d2555da0914e022d9 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 14:13:37 +0000 Subject: [PATCH 02/13] feat: fetch and transform Alpaca stock bars (SG-2) - Add sync and async 
single-page stock bar fetch with per-status error mapping - Inspect status codes directly so 401/403/429/404/500 surface as typed errors - Transform bars into the canonical OHLCV schema, tolerant of list and dict shapes - Forward the configured data feed in request params --- src/ml4t/data/providers/alpaca.py | 243 +++++++++++++++++++++++++++++- tests/test_alpaca_provider.py | 188 ++++++++++++++++++++++- 2 files changed, 428 insertions(+), 3 deletions(-) diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index e870383..a8b6b9a 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -34,9 +34,17 @@ import os from typing import Any, ClassVar +import polars as pl import structlog -from ml4t.data.core.exceptions import AuthenticationError, DataValidationError +from ml4t.data.core.exceptions import ( + AuthenticationError, + DataNotAvailableError, + DataValidationError, + NetworkError, + ProviderError, + RateLimitError, +) from ml4t.data.providers.base import BaseProvider from ml4t.data.providers.mixins import AsyncSessionMixin from ml4t.data.providers.protocols import ProviderCapabilities @@ -195,3 +203,236 @@ def _map_frequency(self, frequency: str) -> str: field="frequency", value=frequency, ) from err + + def _stock_bars_params(self, frequency: str, start: str, end: str) -> dict[str, Any]: + """Build the query parameters for a stock bars request. + + Args: + frequency: Canonical frequency key or alias. + start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. + end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. + + Returns: + The query parameter mapping, including the configured data feed. + """ + return { + "timeframe": self._map_frequency(frequency), + "start": start, + "end": end, + "limit": 10000, + "feed": self.feed, + } + + def _check_response_status(self, status_code: int, symbol: str, response_text: str) -> None: + """Map an HTTP status code to a typed provider exception. 
+ + The status code is inspected directly rather than delegating to + ``raise_for_status``, so that each documented failure mode surfaces as a + specific provider error instead of a generic transport error. + + Args: + status_code: The HTTP status code from the response. + symbol: The requested symbol, for error context. + response_text: The response body, included in network-error messages. + + Raises: + RateLimitError: On HTTP 429. + AuthenticationError: On HTTP 401/403. + DataNotAvailableError: On HTTP 404. + NetworkError: On any other non-200 status. + """ + if status_code == 429: + raise RateLimitError(provider="alpaca", retry_after=60.0) + if status_code in (401, 403): + raise AuthenticationError( + provider="alpaca", message="Invalid API key or unauthorized access" + ) + if status_code == 404: + raise DataNotAvailableError(provider="alpaca", symbol=symbol) + if status_code != 200: + raise NetworkError(provider="alpaca", message=f"HTTP {status_code}: {response_text}") + + def _parse_bars_response(self, response: Any) -> dict[str, Any]: + """Parse a validated bars response into a JSON dict. + + Args: + response: An HTTP response whose status has already been checked. + + Returns: + The parsed JSON payload. + + Raises: + NetworkError: If the body cannot be decoded as JSON. + """ + try: + return response.json() + except Exception as err: + raise NetworkError(provider="alpaca", message="Failed to parse JSON response") from err + + def _fetch_raw_data( + self, + symbol: str, + start: str, + end: str, + frequency: str = "daily", + asset_class: str | None = None, # noqa: ARG002 + ) -> dict[str, Any]: + """Fetch a single page of stock bars from Alpaca. + + Args: + symbol: The equity symbol to fetch (case-insensitive). + start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. + end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. + frequency: Canonical frequency key or alias. 
+ asset_class: Reserved for routing to other asset classes; the stock + branch is the default. + + Returns: + The parsed JSON payload containing the ``bars`` data. + + Raises: + RateLimitError, AuthenticationError, DataNotAvailableError, + NetworkError: Per the HTTP status of the response. + """ + endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" + params = self._stock_bars_params(frequency, start, end) + + try: + self.rate_limiter.acquire(blocking=True) + response = self.session.get(endpoint, params=params) + self._check_response_status(response.status_code, symbol, response.text) + return self._parse_bars_response(response) + except ( + AuthenticationError, + RateLimitError, + NetworkError, + DataNotAvailableError, + ProviderError, + ): + raise + except Exception as err: + raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + + async def _fetch_raw_data_async( + self, + symbol: str, + start: str, + end: str, + frequency: str = "daily", + asset_class: str | None = None, # noqa: ARG002 + ) -> dict[str, Any]: + """Asynchronously fetch a single page of stock bars from Alpaca. + + Mirrors :meth:`_fetch_raw_data` over the async transport; the async + session carries the same two-header credentials as the sync session. + + Args: + symbol: The equity symbol to fetch (case-insensitive). + start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. + end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. + frequency: Canonical frequency key or alias. + asset_class: Reserved for routing to other asset classes; the stock + branch is the default. + + Returns: + The parsed JSON payload containing the ``bars`` data. + + Raises: + RateLimitError, AuthenticationError, DataNotAvailableError, + NetworkError: Per the HTTP status of the response. 
+ """ + endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" + params = self._stock_bars_params(frequency, start, end) + + try: + self.rate_limiter.acquire(blocking=True) + response = await self._aget(endpoint, params=params) + self._check_response_status(response.status_code, symbol, response.text) + return self._parse_bars_response(response) + except ( + AuthenticationError, + RateLimitError, + NetworkError, + DataNotAvailableError, + ProviderError, + ): + raise + except Exception as err: + raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + + def _extract_bars(self, raw_data: dict[str, Any], symbol: str) -> list[dict[str, Any]]: + """Extract the bar list from either response shape. + + Alpaca's single-symbol endpoint returns ``{"bars": [...]}`` while the + multi-symbol endpoint returns ``{"bars": {"SYMBOL": [...]}}``. Both are + accepted so the endpoint choice does not ripple into the transform. + + Args: + raw_data: The parsed JSON payload. + symbol: The requested symbol, used to key into a dict payload. + + Returns: + The list of bar dicts, or an empty list when none are present. + """ + bars = raw_data.get("bars") + if isinstance(bars, dict): + return bars.get(symbol) or bars.get(symbol.upper()) or [] + return bars or [] + + def _transform_data( + self, + raw_data: dict[str, Any], + symbol: str, + asset_class: str | None = None, + ) -> pl.DataFrame: + """Transform a raw bars payload into the canonical OHLCV schema. + + Args: + raw_data: The parsed JSON payload from a bars request. + symbol: The requested symbol; the literal is added as a column, + uppercased for stocks and used verbatim for crypto. + asset_class: When ``"crypto"``, the symbol literal is preserved + verbatim; otherwise it is uppercased. + + Returns: + A DataFrame with columns + ``[timestamp, symbol, open, high, low, close, volume]`` sorted by + timestamp, or the canonical empty DataFrame when there are no bars. 
+ + Raises: + DataValidationError: If the bars cannot be transformed. + """ + bars = self._extract_bars(raw_data, symbol) + if not bars: + return self._create_empty_dataframe() + + symbol_literal = symbol if asset_class == "crypto" else symbol.upper() + + try: + df = pl.DataFrame(bars) + df = df.rename( + { + "t": "timestamp", + "o": "open", + "h": "high", + "l": "low", + "c": "close", + "v": "volume", + } + ) + # Alpaca timestamps are RFC-3339 with a UTC ("Z") offset; parse them + # tz-aware then drop the zone to match the canonical naive schema. + df = df.with_columns( + pl.col("timestamp") + .str.to_datetime(format="%Y-%m-%dT%H:%M:%S%.f%#z") + .dt.replace_time_zone(None) + ) + for col in ["open", "high", "low", "close", "volume"]: + df = df.with_columns(pl.col(col).cast(pl.Float64)) + df = df.with_columns(pl.lit(symbol_literal).alias("symbol")) + df = df.sort("timestamp") + return df.select(["timestamp", "symbol", "open", "high", "low", "close", "volume"]) + except Exception as err: + raise DataValidationError( + provider="alpaca", message=f"Failed to transform data for {symbol}" + ) from err diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index 512ebd8..38f94b9 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -1,13 +1,25 @@ """Tests for Alpaca data provider module.""" -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch +import polars as pl import pytest -from ml4t.data.core.exceptions import AuthenticationError, DataValidationError +from ml4t.data.core.exceptions import ( + AuthenticationError, + DataNotAvailableError, + DataValidationError, + NetworkError, + RateLimitError, +) from ml4t.data.providers.alpaca import AlpacaDataProvider from ml4t.data.providers.protocols import OHLCVProvider +STOCK_BARS = [ + {"t": "2024-01-03T05:00:00Z", "o": 2.0, "h": 3.0, "l": 1.5, "c": 2.5, "v": 2000}, + {"t": "2024-01-02T05:00:00Z", "o": 1.0, "h": 2.0, "l": 0.5, "c": 1.5, "v": 
1000}, +] + class TestAlpacaProviderInit: """Tests for provider construction and authentication.""" @@ -174,3 +186,175 @@ class TestDefaultRateLimit: def test_default_rate_limit(self): """DEFAULT_RATE_LIMIT reflects the documented Basic-plan figure.""" assert AlpacaDataProvider.DEFAULT_RATE_LIMIT == (200, 60.0) + + +class TestFetchRawDataStock: + """Tests for single-page stock bar fetching and error mapping.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_fetch_raw_data_success(self, provider): + """A 200 response returns the parsed bars structure.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": STOCK_BARS, "next_page_token": None} + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert data["bars"] == STOCK_BARS + assert data["next_page_token"] is None + + def test_feed_param_sent_for_stock(self): + """The configured feed is forwarded in the request params.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s", feed="sip") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": [], "next_page_token": None} + + with ( + patch.object(provider.session, "get", return_value=mock_response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert mock_get.call_args.kwargs["params"]["feed"] == "sip" + + def test_fetch_429_raises_rate_limit(self, provider): + """A 429 maps to RateLimitError.""" + mock_response = MagicMock() + mock_response.status_code = 429 + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): 
+ with pytest.raises(RateLimitError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + def test_fetch_401_raises_auth(self, provider): + """A 401 maps to AuthenticationError.""" + mock_response = MagicMock() + mock_response.status_code = 401 + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(AuthenticationError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + def test_fetch_404_raises_data_not_available(self, provider): + """A 404 maps to DataNotAvailableError.""" + mock_response = MagicMock() + mock_response.status_code = 404 + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(DataNotAvailableError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + def test_fetch_500_raises_network(self, provider): + """A 500 maps to NetworkError.""" + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(NetworkError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + def test_json_parse_error_raises_network(self, provider): + """A JSON parse failure maps to NetworkError.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = ValueError("Invalid JSON") + + with ( + patch.object(provider.session, "get", return_value=mock_response), + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(NetworkError, match="Failed to parse JSON"): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + @pytest.mark.asyncio + async def test_fetch_raw_data_async_stock(self, provider): + """The 
async path returns the same parsed structure as the sync path.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": STOCK_BARS, "next_page_token": None} + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=mock_response)), + patch.object(provider.rate_limiter, "acquire"), + ): + data = await provider._fetch_raw_data_async("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert data["bars"] == STOCK_BARS + + +class TestTransformDataStock: + """Tests for transforming stock bars into the standard schema.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_transform_data_stock(self, provider): + """Bars transform to the canonical OHLCV schema, sorted, with invariants.""" + raw = {"bars": STOCK_BARS, "next_page_token": None} + + df = provider._transform_data(raw, "AAPL") + + assert df.columns == [ + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + ] + assert df.schema["timestamp"] == pl.Datetime + for col in ["open", "high", "low", "close", "volume"]: + assert df.schema[col] == pl.Float64 + assert df["symbol"].to_list() == ["AAPL", "AAPL"] + # Rows sorted ascending by timestamp. + timestamps = df["timestamp"].to_list() + assert timestamps == sorted(timestamps) + # OHLC invariants hold for every row. 
+ assert (df["high"] >= df["low"]).all() + assert (df["high"] >= df["open"]).all() + assert (df["high"] >= df["close"]).all() + + def test_transform_data_stock_dict_shape(self, provider): + """A multi-symbol dict-keyed bars payload is also accepted.""" + raw = {"bars": {"AAPL": STOCK_BARS}, "next_page_token": None} + + df = provider._transform_data(raw, "AAPL") + + assert df["symbol"].to_list() == ["AAPL", "AAPL"] + assert df.height == 2 + + def test_empty_bars_returns_empty_dataframe(self, provider): + """An empty bars list yields the canonical empty DataFrame.""" + raw = {"bars": [], "next_page_token": None} + + df = provider._transform_data(raw, "AAPL") + expected = provider._create_empty_dataframe() + + assert df.columns == expected.columns + assert df.schema == expected.schema + assert df.height == 0 From c08a21da9c3555e843306b6334d30fd5cff510c2 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 14:35:11 +0000 Subject: [PATCH 03/13] feat: route and fetch Alpaca crypto bars (SG-3) - Resolve asset class from a BASE/QUOTE slash or an explicit override - Route crypto symbols to the v1beta3 crypto bars endpoint, omitting feed - Preserve the BASE/QUOTE symbol verbatim in params and the symbol column - Fully override fetch_ohlcv/fetch_ohlcv_async to thread asset_class while keeping input validation, circuit breaker, and OHLCV validation - Share one bars-to-DataFrame helper across the stock and crypto branches --- src/ml4t/data/providers/alpaca.py | 258 ++++++++++++++++++++++++++++-- tests/test_alpaca_provider.py | 167 +++++++++++++++++++ 2 files changed, 410 insertions(+), 15 deletions(-) diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index a8b6b9a..bb09218 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -36,6 +36,7 @@ import polars as pl import structlog +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from ml4t.data.core.exceptions 
import ( AuthenticationError, @@ -204,6 +205,24 @@ def _map_frequency(self, frequency: str) -> str: value=frequency, ) from err + def _resolve_asset_class(self, symbol: str, asset_class: str | None) -> str: + """Resolve the asset class for a request. + + An explicit ``asset_class`` always wins. Otherwise a ``BASE/QUOTE`` + symbol (e.g. ``"BTC/USD"``) is treated as crypto and everything else as + a stock. + + Args: + symbol: The requested symbol. + asset_class: An explicit asset class, or ``None`` to infer it. + + Returns: + Either ``"crypto"`` or ``"stock"``. + """ + if asset_class is not None: + return asset_class + return "crypto" if "/" in symbol else "stock" + def _stock_bars_params(self, frequency: str, start: str, end: str) -> dict[str, Any]: """Build the query parameters for a stock bars request. @@ -223,6 +242,60 @@ def _stock_bars_params(self, frequency: str, start: str, end: str) -> dict[str, "feed": self.feed, } + def _crypto_bars_params( + self, symbol: str, frequency: str, start: str, end: str + ) -> dict[str, Any]: + """Build the query parameters for a crypto bars request. + + The crypto bars endpoint is multi-symbol: the symbol travels in the + ``symbols`` parameter (preserving the ``BASE/QUOTE`` form verbatim) and + no ``feed`` is sent, since crypto has a single consolidated feed. + + Args: + symbol: The crypto symbol in ``BASE/QUOTE`` form (e.g. "BTC/USD"). + frequency: Canonical frequency key or alias. + start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. + end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. + + Returns: + The query parameter mapping for the crypto bars endpoint. + """ + return { + "symbols": symbol, + "timeframe": self._map_frequency(frequency), + "start": start, + "end": end, + "limit": 10000, + } + + def _bars_request( + self, symbol: str, start: str, end: str, frequency: str, asset_class: str + ) -> tuple[str, dict[str, Any]]: + """Resolve the endpoint and query params for a bars request. 
+ + Branches on the resolved asset class so the sync and async fetchers share + one routing decision. The stock branch uppercases the symbol into the + path; the crypto branch hits the multi-symbol endpoint and preserves the + ``BASE/QUOTE`` symbol verbatim. + + Args: + symbol: The requested symbol. + start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. + end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. + frequency: Canonical frequency key or alias. + asset_class: The resolved asset class, ``"crypto"`` or ``"stock"``. + + Returns: + A ``(endpoint, params)`` tuple. + """ + if asset_class == "crypto": + # The {loc} path segment selects the venue group; "us" is the + # default consolidated US crypto feed. + endpoint = f"{self.crypto_base_url}/v1beta3/crypto/us/bars" + return endpoint, self._crypto_bars_params(symbol, frequency, start, end) + endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" + return endpoint, self._stock_bars_params(frequency, start, end) + def _check_response_status(self, status_code: int, symbol: str, response_text: str) -> None: """Map an HTTP status code to a typed provider exception. @@ -275,17 +348,22 @@ def _fetch_raw_data( start: str, end: str, frequency: str = "daily", - asset_class: str | None = None, # noqa: ARG002 + asset_class: str | None = None, ) -> dict[str, Any]: - """Fetch a single page of stock bars from Alpaca. + """Fetch a single page of bars from Alpaca. + + Routes to the stock or crypto bars endpoint based on the resolved asset + class. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") routes to crypto unless + an explicit ``asset_class`` overrides the inference. Args: - symbol: The equity symbol to fetch (case-insensitive). + symbol: The symbol to fetch (stocks are case-insensitive; crypto + symbols are preserved verbatim). start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. frequency: Canonical frequency key or alias. 
- asset_class: Reserved for routing to other asset classes; the stock - branch is the default. + asset_class: Explicit asset class; inferred from the symbol when + ``None``. Returns: The parsed JSON payload containing the ``bars`` data. @@ -294,8 +372,8 @@ def _fetch_raw_data( RateLimitError, AuthenticationError, DataNotAvailableError, NetworkError: Per the HTTP status of the response. """ - endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" - params = self._stock_bars_params(frequency, start, end) + resolved = self._resolve_asset_class(symbol, asset_class) + endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) try: self.rate_limiter.acquire(blocking=True) @@ -319,20 +397,22 @@ async def _fetch_raw_data_async( start: str, end: str, frequency: str = "daily", - asset_class: str | None = None, # noqa: ARG002 + asset_class: str | None = None, ) -> dict[str, Any]: - """Asynchronously fetch a single page of stock bars from Alpaca. + """Asynchronously fetch a single page of bars from Alpaca. Mirrors :meth:`_fetch_raw_data` over the async transport; the async - session carries the same two-header credentials as the sync session. + session carries the same two-header credentials as the sync session and + applies the same stock/crypto routing. Args: - symbol: The equity symbol to fetch (case-insensitive). + symbol: The symbol to fetch (stocks are case-insensitive; crypto + symbols are preserved verbatim). start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. frequency: Canonical frequency key or alias. - asset_class: Reserved for routing to other asset classes; the stock - branch is the default. + asset_class: Explicit asset class; inferred from the symbol when + ``None``. Returns: The parsed JSON payload containing the ``bars`` data. 
@@ -341,8 +421,8 @@ async def _fetch_raw_data_async( RateLimitError, AuthenticationError, DataNotAvailableError, NetworkError: Per the HTTP status of the response. """ - endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" - params = self._stock_bars_params(frequency, start, end) + resolved = self._resolve_asset_class(symbol, asset_class) + endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) try: self.rate_limiter.acquire(blocking=True) @@ -406,8 +486,33 @@ def _transform_data( if not bars: return self._create_empty_dataframe() + # Crypto symbols keep their BASE/QUOTE form verbatim; stocks uppercase. symbol_literal = symbol if asset_class == "crypto" else symbol.upper() + return self._bars_to_dataframe(bars, symbol_literal, symbol) + + def _bars_to_dataframe( + self, bars: list[dict[str, Any]], symbol_literal: str, symbol: str + ) -> pl.DataFrame: + """Convert a list of bar records into the canonical OHLCV DataFrame. + + Stock and crypto bars share the same ``o/h/l/c/v/t`` field names, so a + single conversion serves both branches; only the bars-extraction and the + symbol literal differ upstream. Crypto bars are 24/7, so no calendar or + session filtering is applied here. + + Args: + bars: Non-empty list of bar records with ``o/h/l/c/v/t`` keys. + symbol_literal: The value to write into the ``symbol`` column. + symbol: The requested symbol, used only for error context. + Returns: + A DataFrame with columns + ``[timestamp, symbol, open, high, low, close, volume]`` sorted by + timestamp. + + Raises: + DataValidationError: If the bars cannot be transformed. 
+ """ try: df = pl.DataFrame(bars) df = df.rename( @@ -436,3 +541,126 @@ def _transform_data( raise DataValidationError( provider="alpaca", message=f"Failed to transform data for {symbol}" ) from err + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((NetworkError, RateLimitError)), + reraise=True, + ) + def fetch_ohlcv( + self, + symbol: str, + start: str, + end: str, + frequency: str = "daily", + asset_class: str | None = None, + ) -> pl.DataFrame: + """Fetch OHLCV bars for a stock or crypto symbol. + + This fully overrides the base template rather than delegating to it: the + base signature cannot thread ``asset_class`` through, and it acquires a + rate-limit token up front whereas this provider rate-limits per page + inside the fetch. The base contract is reproduced here directly -- + input validation, circuit-breaker-wrapped fetch/transform/validate, and + the same info-logging. + + Args: + symbol: The symbol to fetch. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") + is treated as crypto unless ``asset_class`` overrides it. + start: Start date in YYYY-MM-DD format (inclusive). + end: End date in YYYY-MM-DD format (inclusive). + frequency: Canonical frequency key or alias. + asset_class: Explicit asset class ("stock" or "crypto"); inferred + from the symbol when ``None``. + + Returns: + A DataFrame in the canonical OHLCV schema + ``[timestamp, symbol, open, high, low, close, volume]``. 
+ """ + self.logger.info( + "Fetching OHLCV data", + symbol=symbol, + start=start, + end=end, + frequency=frequency, + provider=self.name, + ) + + self._validate_inputs(symbol, start, end, frequency) + resolved = self._resolve_asset_class(symbol, asset_class) + + def _fetch_and_process() -> pl.DataFrame: + raw_data = self._fetch_raw_data(symbol, start, end, frequency, asset_class=resolved) + df = self._transform_data(raw_data, symbol, asset_class=resolved) + return self._validate_ohlcv(df, self.name) + + validated_data = self._with_circuit_breaker(_fetch_and_process) + + self.logger.info( + "Successfully fetched OHLCV data", + symbol=symbol, + rows=len(validated_data), + provider=self.name, + ) + + return validated_data + + async def fetch_ohlcv_async( + self, + symbol: str, + start: str, + end: str, + frequency: str = "daily", + asset_class: str | None = None, + ) -> pl.DataFrame: + """Asynchronously fetch OHLCV bars for a stock or crypto symbol. + + Mirrors :meth:`fetch_ohlcv` over the async transport. The circuit + breaker wraps a synchronous callable, so the coroutine is awaited first + and only the transform/validate step runs inside the breaker, preserving + failure accounting without blocking the event loop on the fetch. + + Args: + symbol: The symbol to fetch. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") + is treated as crypto unless ``asset_class`` overrides it. + start: Start date in YYYY-MM-DD format (inclusive). + end: End date in YYYY-MM-DD format (inclusive). + frequency: Canonical frequency key or alias. + asset_class: Explicit asset class ("stock" or "crypto"); inferred + from the symbol when ``None``. + + Returns: + A DataFrame in the canonical OHLCV schema + ``[timestamp, symbol, open, high, low, close, volume]``. 
+ """ + self.logger.info( + "Fetching OHLCV data (async)", + symbol=symbol, + start=start, + end=end, + frequency=frequency, + provider=self.name, + ) + + self._validate_inputs(symbol, start, end, frequency) + resolved = self._resolve_asset_class(symbol, asset_class) + + raw_data = await self._fetch_raw_data_async( + symbol, start, end, frequency, asset_class=resolved + ) + + def _transform_and_validate() -> pl.DataFrame: + df = self._transform_data(raw_data, symbol, asset_class=resolved) + return self._validate_ohlcv(df, self.name) + + validated_data = self._with_circuit_breaker(_transform_and_validate) + + self.logger.info( + "Successfully fetched OHLCV data (async)", + symbol=symbol, + rows=len(validated_data), + provider=self.name, + ) + + return validated_data diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index 38f94b9..d16b232 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -20,6 +20,13 @@ {"t": "2024-01-02T05:00:00Z", "o": 1.0, "h": 2.0, "l": 0.5, "c": 1.5, "v": 1000}, ] +# Crypto trades 24/7; 2024-01-06 is a Saturday, included to confirm no +# weekday-only assumption filters weekend bars out. 
+CRYPTO_BARS = [ + {"t": "2024-01-06T00:00:00Z", "o": 44.0, "h": 46.0, "l": 43.0, "c": 45.0, "v": 12.0}, + {"t": "2024-01-05T00:00:00Z", "o": 42.0, "h": 44.0, "l": 41.0, "c": 43.0, "v": 10.0}, +] + class TestAlpacaProviderInit: """Tests for provider construction and authentication.""" @@ -358,3 +365,163 @@ def test_empty_bars_returns_empty_dataframe(self, provider): assert df.columns == expected.columns assert df.schema == expected.schema assert df.height == 0 + + +class TestAssetRouting: + """Tests for routing symbols to the stock vs crypto endpoint.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_resolve_asset_class_by_slash(self, provider): + """A slash in the symbol resolves to crypto; otherwise stock.""" + assert provider._resolve_asset_class("BTC/USD", None) == "crypto" + assert provider._resolve_asset_class("AAPL", None) == "stock" + + def test_resolve_asset_class_override(self, provider): + """An explicit asset_class kwarg wins over the slash heuristic.""" + assert provider._resolve_asset_class("AAPL", "crypto") == "crypto" + assert provider._resolve_asset_class("BTC/USD", "stock") == "stock" + + def test_routes_crypto_by_slash(self, provider): + """A BTC/USD symbol hits the crypto bars endpoint; AAPL hits the stock one.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": {}, "next_page_token": None} + + with ( + patch.object(provider.session, "get", return_value=mock_response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + provider.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-07", "daily") + crypto_url = mock_get.call_args.args[0] + + provider.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-07", "daily") + stock_url = mock_get.call_args.args[0] + + assert "/v1beta3/crypto/" in crypto_url + assert crypto_url.endswith("/bars") + assert "/v2/stocks/" in stock_url + + def 
test_asset_class_override(self, provider): + """An explicit asset_class forces crypto routing regardless of the slash.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": {}, "next_page_token": None} + + with ( + patch.object(provider.session, "get", return_value=mock_response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + provider.fetch_ohlcv("X", "2024-01-01", "2024-01-07", "daily", asset_class="crypto") + url = mock_get.call_args.args[0] + + assert "/v1beta3/crypto/" in url + + def test_feed_not_sent_for_crypto(self, provider): + """Crypto request params must not include the stock-only feed parameter.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"bars": {}, "next_page_token": None} + + with ( + patch.object(provider.session, "get", return_value=mock_response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + provider.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-07", "daily") + params = mock_get.call_args.kwargs["params"] + + assert "feed" not in params + assert params["symbols"] == "BTC/USD" + + +class TestFetchRawDataCrypto: + """Tests for single-page crypto bar fetching.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_fetch_raw_data_crypto_success(self, provider): + """A 200 crypto response returns the parsed symbol-keyed bars structure.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "bars": {"BTC/USD": CRYPTO_BARS}, + "next_page_token": None, + } + + with ( + patch.object(provider.session, "get", return_value=mock_response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data( + "BTC/USD", "2024-01-01", "2024-01-07", "daily", asset_class="crypto" + ) + + assert data["bars"]["BTC/USD"] == 
CRYPTO_BARS + # Symbol is forwarded verbatim, slash preserved, via the symbols param. + assert mock_get.call_args.kwargs["params"]["symbols"] == "BTC/USD" + assert "/v1beta3/crypto/" in mock_get.call_args.args[0] + + @pytest.mark.asyncio + async def test_crypto_async(self, provider): + """The async crypto path returns the same parsed structure as the sync path.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "bars": {"BTC/USD": CRYPTO_BARS}, + "next_page_token": None, + } + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=mock_response)) as mock_aget, + patch.object(provider.rate_limiter, "acquire"), + ): + data = await provider._fetch_raw_data_async( + "BTC/USD", "2024-01-01", "2024-01-07", "daily", asset_class="crypto" + ) + + assert data["bars"]["BTC/USD"] == CRYPTO_BARS + assert "/v1beta3/crypto/" in mock_aget.call_args.args[0] + assert "feed" not in mock_aget.call_args.kwargs["params"] + + +class TestTransformDataCrypto: + """Tests for transforming crypto bars into the standard schema.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_transform_data_crypto(self, provider): + """Crypto bars transform with the slash-preserved symbol and weekend bars kept.""" + raw = {"bars": {"BTC/USD": CRYPTO_BARS}, "next_page_token": None} + + df = provider._transform_data(raw, "BTC/USD", asset_class="crypto") + + assert df.columns == [ + "timestamp", + "symbol", + "open", + "high", + "low", + "close", + "volume", + ] + assert df.schema["timestamp"] == pl.Datetime + for col in ["open", "high", "low", "close", "volume"]: + assert df.schema[col] == pl.Float64 + # Slash preserved, not uppercased-stripped. + assert df["symbol"].to_list() == ["BTC/USD", "BTC/USD"] + # Every bar is retained, including the Saturday timestamp. 
+ assert df.height == 2 + weekdays = df["timestamp"].dt.weekday().to_list() + # Polars weekday: Saturday == 6. The weekend bar must survive transform. + assert 6 in weekdays + timestamps = df["timestamp"].to_list() + assert timestamps == sorted(timestamps) From 6e75ba0337952add0238e8b5ba18e9a1117ef0ec Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 14:48:51 +0000 Subject: [PATCH 04/13] feat: paginate Alpaca bars via next_page_token (SG-4) - Loop each fetch until next_page_token is null, merging per-page bars - Extend the stock bar list and concatenate crypto bars per symbol - Send a fresh params dict per page so the token never rewrites earlier requests - Acquire one rate-limit token per page, preserving the once-per-fetch guarantee --- src/ml4t/data/providers/alpaca.py | 80 +++++++++++++++++++++---- tests/test_alpaca_provider.py | 98 ++++++++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 13 deletions(-) diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index bb09218..844ab71 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -296,6 +296,31 @@ def _bars_request( endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" return endpoint, self._stock_bars_params(frequency, start, end) + def _merge_bars(self, accumulated: Any, page_bars: Any) -> Any: + """Merge one page's bars into the accumulator across both response shapes. + + The single-symbol endpoint returns ``bars`` as a list, which is simply + extended. The multi-symbol crypto endpoint returns ``bars`` as a dict of + per-symbol lists, whose entries are concatenated by symbol. The seed + accumulator may be ``None`` on the first page, in which case the page's + own container type is adopted. + + Args: + accumulated: The bars merged from prior pages, or ``None`` on page 1. + page_bars: The ``bars`` value from the current page. + + Returns: + The accumulator with the current page's bars appended. 
+ """ + if isinstance(page_bars, dict): + merged: dict[str, list[Any]] = accumulated if isinstance(accumulated, dict) else {} + for sym, sym_bars in page_bars.items(): + merged.setdefault(sym, []).extend(sym_bars or []) + return merged + merged_list: list[Any] = accumulated if isinstance(accumulated, list) else [] + merged_list.extend(page_bars or []) + return merged_list + def _check_response_status(self, status_code: int, symbol: str, response_text: str) -> None: """Map an HTTP status code to a typed provider exception. @@ -350,11 +375,14 @@ def _fetch_raw_data( frequency: str = "daily", asset_class: str | None = None, ) -> dict[str, Any]: - """Fetch a single page of bars from Alpaca. + """Fetch a full bars range from Alpaca, following pagination to the end. Routes to the stock or crypto bars endpoint based on the resolved asset class. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") routes to crypto unless - an explicit ``asset_class`` overrides the inference. + an explicit ``asset_class`` overrides the inference. Requests are repeated + with each response's ``next_page_token`` until it is null/absent, and the + per-page bars are merged into the single shape ``_transform_data`` + expects (a list for stocks, a per-symbol dict for crypto). Args: symbol: The symbol to fetch (stocks are case-insensitive; crypto @@ -375,11 +403,25 @@ def _fetch_raw_data( resolved = self._resolve_asset_class(symbol, asset_class) endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) + accumulated: Any = None + token: str | None = None try: - self.rate_limiter.acquire(blocking=True) - response = self.session.get(endpoint, params=params) - self._check_response_status(response.status_code, symbol, response.text) - return self._parse_bars_response(response) + while True: + # A fresh dict per page keeps each request's params independent; + # mutating one shared dict would otherwise rewrite the token on + # earlier requests that already went out. 
+ page_params = {**params, "page_token": token} if token else params + # One acquisition per page is the only rate-limit gate, since + # fetch_ohlcv is fully overridden and the base never acquires. + self.rate_limiter.acquire(blocking=True) + response = self.session.get(endpoint, params=page_params) + self._check_response_status(response.status_code, symbol, response.text) + payload = self._parse_bars_response(response) + accumulated = self._merge_bars(accumulated, payload.get("bars")) + token = payload.get("next_page_token") + if not token: + break + return {"bars": accumulated} except ( AuthenticationError, RateLimitError, @@ -399,11 +441,11 @@ async def _fetch_raw_data_async( frequency: str = "daily", asset_class: str | None = None, ) -> dict[str, Any]: - """Asynchronously fetch a single page of bars from Alpaca. + """Asynchronously fetch a full bars range, following pagination to the end. Mirrors :meth:`_fetch_raw_data` over the async transport; the async session carries the same two-header credentials as the sync session and - applies the same stock/crypto routing. + applies the same stock/crypto routing and ``next_page_token`` loop. Args: symbol: The symbol to fetch (stocks are case-insensitive; crypto @@ -424,11 +466,25 @@ async def _fetch_raw_data_async( resolved = self._resolve_asset_class(symbol, asset_class) endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) + accumulated: Any = None + token: str | None = None try: - self.rate_limiter.acquire(blocking=True) - response = await self._aget(endpoint, params=params) - self._check_response_status(response.status_code, symbol, response.text) - return self._parse_bars_response(response) + while True: + # A fresh dict per page keeps each request's params independent; + # mutating one shared dict would otherwise rewrite the token on + # earlier requests that already went out. 
+ page_params = {**params, "page_token": token} if token else params + # One acquisition per page is the only rate-limit gate, since + # fetch_ohlcv is fully overridden and the base never acquires. + self.rate_limiter.acquire(blocking=True) + response = await self._aget(endpoint, params=page_params) + self._check_response_status(response.status_code, symbol, response.text) + payload = self._parse_bars_response(response) + accumulated = self._merge_bars(accumulated, payload.get("bars")) + token = payload.get("next_page_token") + if not token: + break + return {"bars": accumulated} except ( AuthenticationError, RateLimitError, diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index d16b232..0a64ac8 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -215,8 +215,9 @@ def test_fetch_raw_data_success(self, provider): ): data = provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + # The fetcher returns the merged-across-pages shape carrying only bars; + # the per-page next_page_token is consumed internally, not surfaced. 
assert data["bars"] == STOCK_BARS - assert data["next_page_token"] is None def test_feed_param_sent_for_stock(self): """The configured feed is forwarded in the request params.""" @@ -525,3 +526,98 @@ def test_transform_data_crypto(self, provider): assert 6 in weekdays timestamps = df["timestamp"].to_list() assert timestamps == sorted(timestamps) + + +def _page_response(bars, token): + """Build a mock bars response carrying a given next_page_token.""" + response = MagicMock() + response.status_code = 200 + response.text = "" + response.json.return_value = {"bars": bars, "next_page_token": token} + return response + + +class TestPagination: + """Tests for following next_page_token across multiple pages.""" + + @pytest.fixture + def provider(self): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + def test_follows_next_page_token(self, provider): + """Two stock pages are merged and page 2 is requested with the token.""" + page1 = _page_response([STOCK_BARS[0]], "abc") + page2 = _page_response([STOCK_BARS[1]], None) + + with ( + patch.object(provider.session, "get", side_effect=[page1, page2]) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-07", "daily") + + assert data["bars"] == [STOCK_BARS[0], STOCK_BARS[1]] + assert mock_get.call_count == 2 + # The second request carries the page token returned by page 1. + assert mock_get.call_args_list[1].kwargs["params"]["page_token"] == "abc" + # The first request does not send a page token. 
+ assert "page_token" not in mock_get.call_args_list[0].kwargs["params"] + + def test_single_page_no_extra_request(self, provider): + """A null token on page 1 yields exactly one request.""" + page1 = _page_response(STOCK_BARS, None) + + with ( + patch.object(provider.session, "get", side_effect=[page1]) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-07", "daily") + + assert data["bars"] == STOCK_BARS + assert mock_get.call_count == 1 + + def test_pagination_crypto(self, provider): + """Two crypto pages merge per-symbol bar lists across pages.""" + page1 = _page_response({"BTC/USD": [CRYPTO_BARS[0]]}, "abc") + page2 = _page_response({"BTC/USD": [CRYPTO_BARS[1]]}, None) + + with ( + patch.object(provider.session, "get", side_effect=[page1, page2]) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data( + "BTC/USD", "2024-01-01", "2024-01-07", "daily", asset_class="crypto" + ) + + assert data["bars"]["BTC/USD"] == [CRYPTO_BARS[0], CRYPTO_BARS[1]] + assert mock_get.call_count == 2 + assert mock_get.call_args_list[1].kwargs["params"]["page_token"] == "abc" + + def test_rate_limit_acquired_once_per_page(self, provider): + """acquire is called exactly once per page, not doubled per request.""" + page1 = _page_response([STOCK_BARS[0]], "abc") + page2 = _page_response([STOCK_BARS[1]], None) + + with ( + patch.object(provider.session, "get", side_effect=[page1, page2]), + patch.object(provider.rate_limiter, "acquire") as mock_acquire, + ): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-07", "daily") + + assert mock_acquire.call_count == 2 + + @pytest.mark.asyncio + async def test_follows_next_page_token_async(self, provider): + """The async path follows the token and merges both pages.""" + page1 = _page_response([STOCK_BARS[0]], "abc") + page2 = _page_response([STOCK_BARS[1]], None) + + with ( + patch.object(provider, "_aget", 
new=AsyncMock(side_effect=[page1, page2])) as mock_aget, + patch.object(provider.rate_limiter, "acquire"), + ): + data = await provider._fetch_raw_data_async("AAPL", "2024-01-01", "2024-01-07", "daily") + + assert data["bars"] == [STOCK_BARS[0], STOCK_BARS[1]] + assert mock_aget.call_count == 2 + assert mock_aget.call_args_list[1].kwargs["params"]["page_token"] == "abc" From 544cdfbc2a8343e7721b69be95697c9a0f9fc243 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 15:08:03 +0000 Subject: [PATCH 05/13] feat: register Alpaca provider in registry and catalog (SG-5) Wire AlpacaDataProvider into the provider registry, env auto-detection, config env mapping, public exports, the config-validation warning list, and the provider catalog tables so provider="alpaca" resolves end-to-end. --- src/ml4t/data/config/models.py | 1 + src/ml4t/data/config/validator.py | 5 +- src/ml4t/data/managers/config_manager.py | 2 + src/ml4t/data/managers/provider_manager.py | 4 ++ src/ml4t/data/providers/AGENT.md | 1 + src/ml4t/data/providers/AGENTS.md | 1 + src/ml4t/data/providers/__init__.py | 7 ++ tests/test_alpaca_provider.py | 82 ++++++++++++++++++++++ 8 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/ml4t/data/config/models.py b/src/ml4t/data/config/models.py index be792a8..6f96cdb 100644 --- a/src/ml4t/data/config/models.py +++ b/src/ml4t/data/config/models.py @@ -53,6 +53,7 @@ class ProviderType(str, Enum): """Provider type enumeration.""" YAHOO = "yahoo" + ALPACA = "alpaca" BINANCE = "binance" CRYPTOCOMPARE = "cryptocompare" DATABENTO = "databento" diff --git a/src/ml4t/data/config/validator.py b/src/ml4t/data/config/validator.py index 171c1dc..b95f968 100644 --- a/src/ml4t/data/config/validator.py +++ b/src/ml4t/data/config/validator.py @@ -65,7 +65,10 @@ def _validate_providers(self) -> None: provider_names.add(provider.name) # Check API keys for providers that need them - if provider.type in ["massive", "polygon", "cryptocompare"] and not 
provider.api_key: + if ( + provider.type in ["massive", "polygon", "cryptocompare", "alpaca"] + and not provider.api_key + ): self.warnings.append( f"Provider {provider.name} ({provider.type}) may require an API key" ) diff --git a/src/ml4t/data/managers/config_manager.py b/src/ml4t/data/managers/config_manager.py index 12daa7d..0b5024e 100644 --- a/src/ml4t/data/managers/config_manager.py +++ b/src/ml4t/data/managers/config_manager.py @@ -42,6 +42,8 @@ class ConfigManager: # Environment variable to provider config mapping ENV_MAPPING = { + "ALPACA_API_KEY": ("alpaca", "api_key"), + "ALPACA_API_SECRET": ("alpaca", "api_secret"), "CRYPTOCOMPARE_API_KEY": ("cryptocompare", "api_key"), "DATABENTO_API_KEY": ("databento", "api_key"), "POLYGON_API_KEY": ("massive", "api_key"), diff --git a/src/ml4t/data/managers/provider_manager.py b/src/ml4t/data/managers/provider_manager.py index 79e6e89..611d73a 100644 --- a/src/ml4t/data/managers/provider_manager.py +++ b/src/ml4t/data/managers/provider_manager.py @@ -149,6 +149,7 @@ class ProviderManager: # Providers that require API keys KEYED_PROVIDERS = frozenset( { + "alpaca", "databento", "massive", "oanda", @@ -184,6 +185,7 @@ def _get_provider_classes(cls) -> dict[str, type]: return cls._PROVIDER_CLASSES # Import core providers + from ml4t.data.providers.alpaca import AlpacaDataProvider from ml4t.data.providers.binance import BinanceProvider from ml4t.data.providers.binance_public import BinancePublicProvider from ml4t.data.providers.cryptocompare import CryptoCompareProvider @@ -193,6 +195,7 @@ def _get_provider_classes(cls) -> dict[str, type]: from ml4t.data.providers.yahoo import YahooFinanceProvider provider_classes: dict[str, type] = { + "alpaca": AlpacaDataProvider, "binance": BinanceProvider, "binance_public": BinancePublicProvider, "cryptocompare": CryptoCompareProvider, @@ -244,6 +247,7 @@ def _detect_available_providers(self) -> None: # Check environment for API keys not in config env_to_provider = { + 
"ALPACA_API_KEY": "alpaca", "CRYPTOCOMPARE_API_KEY": "cryptocompare", "DATABENTO_API_KEY": "databento", "MASSIVE_API_KEY": "massive", diff --git a/src/ml4t/data/providers/AGENT.md b/src/ml4t/data/providers/AGENT.md index 60c1c93..244cc30 100644 --- a/src/ml4t/data/providers/AGENT.md +++ b/src/ml4t/data/providers/AGENT.md @@ -13,6 +13,7 @@ | File | Lines | Purpose | |------|-------|---------| | yahoo.py | 603 | Yahoo Finance (free) | +| alpaca.py | 722 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | | binance.py | 410 | Binance authenticated | | binance_public.py | 1430 | Binance public API | | eodhd.py | 464 | EOD Historical Data | diff --git a/src/ml4t/data/providers/AGENTS.md b/src/ml4t/data/providers/AGENTS.md index f0e3bc2..87395e1 100644 --- a/src/ml4t/data/providers/AGENTS.md +++ b/src/ml4t/data/providers/AGENTS.md @@ -13,6 +13,7 @@ | File | Lines | Purpose | |------|-------|---------| | yahoo.py | 603 | Yahoo Finance (free) | +| alpaca.py | 722 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | | binance_api.py | 410 | Binance REST API | | binance_bulk.py | 1430 | Binance bulk historical archive | | eodhd.py | 464 | EOD Historical Data | diff --git a/src/ml4t/data/providers/__init__.py b/src/ml4t/data/providers/__init__.py index fdf7e95..3de34a5 100644 --- a/src/ml4t/data/providers/__init__.py +++ b/src/ml4t/data/providers/__init__.py @@ -5,6 +5,7 @@ Available Providers (20 live + 3 synthetic/testing): - BaseProvider: Abstract base class for all providers - YahooFinanceProvider: Yahoo Finance (free, no API key) + - AlpacaDataProvider: Alpaca US stocks and crypto (free IEX feed, two-credential auth) - TiingoProvider: Tiingo stocks (free tier: 1000 req/day, 500 symbols/month) - FinnhubProvider: Finnhub multi-asset data (free tier: 60 req/min) - EODHDProvider: EODHD global equities (free tier: 500 req/day, 1 year depth) @@ -49,6 +50,11 @@ except ImportError: YahooFinanceProvider = None # type: ignore +try: + from 
ml4t.data.providers.alpaca import AlpacaDataProvider +except ImportError: + AlpacaDataProvider = None # type: ignore + try: from ml4t.data.providers.tiingo import TiingoProvider except ImportError: @@ -149,6 +155,7 @@ "Provider", # Equity providers "YahooFinanceProvider", + "AlpacaDataProvider", "TiingoProvider", "FinnhubProvider", "EODHDProvider", diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index 0a64ac8..3bd7126 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -621,3 +621,85 @@ async def test_follows_next_page_token_async(self, provider): assert data["bars"] == [STOCK_BARS[0], STOCK_BARS[1]] assert mock_aget.call_count == 2 assert mock_aget.call_args_list[1].kwargs["params"]["page_token"] == "abc" + + +class TestAlpacaRegistration: + """Tests that wire Alpaca into the registry, config, and exports.""" + + def test_registry_resolves_alpaca(self): + """The ProviderManager registry maps 'alpaca' to AlpacaDataProvider.""" + from ml4t.data.managers.provider_manager import ProviderManager + + classes = ProviderManager._get_provider_classes() + + assert "alpaca" in classes + assert classes["alpaca"] is AlpacaDataProvider + + def test_alpaca_in_keyed_providers(self): + """Alpaca is registered as an API-key-requiring provider.""" + from ml4t.data.managers.provider_manager import ProviderManager + + assert "alpaca" in ProviderManager.KEYED_PROVIDERS + + def test_alpaca_exported(self): + """AlpacaDataProvider is importable from the package and in __all__.""" + import ml4t.data.providers as providers + from ml4t.data.providers import AlpacaDataProvider as Exported + + assert Exported is AlpacaDataProvider + assert "AlpacaDataProvider" in providers.__all__ + + def test_env_autodetect_alpaca(self): + """ALPACA_API_KEY auto-detection resolves the alpaca provider.""" + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = {k: v for k, v in os.environ.items() if not 
k.startswith("ALPACA_")} + env["ALPACA_API_KEY"] = "k" + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager(config={"providers": {}}) + assert "alpaca" in manager.available_providers + + def test_provider_type_enum_accepts_alpaca(self): + """ProviderType accepts 'alpaca' and ProviderConfig validates with it.""" + from ml4t.data.config.models import ProviderConfig, ProviderType + + assert ProviderType("alpaca") == ProviderType.ALPACA + + config = ProviderConfig(name="x", type="alpaca") + assert config.type == ProviderType.ALPACA + + def test_config_manager_injects_both_credentials(self): + """Both Alpaca env credentials land in the resolved provider config.""" + import os + + from ml4t.data.managers.config_manager import ConfigManager + + env = {k: v for k, v in os.environ.items() if not k.startswith("ALPACA_")} + env["ALPACA_API_KEY"] = "k" + env["ALPACA_API_SECRET"] = "s" + with patch.dict(os.environ, env, clear=True): + manager = ConfigManager() + provider_config = manager.get_provider_config("alpaca") + + assert provider_config.get("api_key") == "k" + assert provider_config.get("api_secret") == "s" + + def test_alpaca_missing_secret_fails_clearly(self): + """A config with api_key but no api_secret fails with a clear error. + + Availability is keyed on api_key alone, so an alpaca entry with only a + key is marked available but fails loudly at construction rather than + silently producing a broken provider. 
+ """ + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = {k: v for k, v in os.environ.items() if not k.startswith("ALPACA_")} + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager(config={"providers": {"alpaca": {"api_key": "k"}}}) + + with pytest.raises(ValueError, match="secret"): + manager.get_provider("alpaca") From c54ff7701c97b143c19eca959e6efce9c3fd27b6 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 15:30:30 +0000 Subject: [PATCH 06/13] test: document Alpaca provider and add integration coverage (SG-6) - Expand module docstring with API URL, feed/tier, env vars, rate limit, sync and async examples - Add skipped-by-default integration test fetching AAPL daily and BTC/USD minute - Update provider registration test to include the registered Alpaca provider Task: SG-6 --- src/ml4t/data/providers/alpaca.py | 45 ++++++++++---- tests/integration/test_alpaca.py | 95 +++++++++++++++++++++++++++++ tests/test_provider_registration.py | 24 ++++---- 3 files changed, 140 insertions(+), 24 deletions(-) create mode 100644 tests/integration/test_alpaca.py diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index 844ab71..00f39cb 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -1,32 +1,53 @@ """Alpaca data provider. Alpaca provides long-history, high-frequency market data across multiple asset -classes including equities and crypto, served over a REST API. +classes including US equities and crypto, served over a historical REST API. +Daily, hourly, and minute OHLCV bars are supported for both asset classes from a +single symbol-routed provider. -API Documentation: https://docs.alpaca.markets/ +API Documentation: https://docs.alpaca.markets/us/docs/about-market-data-api + +Symbol Format: +- Stocks use a plain ticker (e.g. "AAPL"); the symbol is uppercased and routed to + the stock bars endpoint. 
+- Crypto uses ``BASE/QUOTE`` (e.g. "BTC/USD"); the slash routes the request to the + crypto bars endpoint, and the symbol is preserved verbatim. Authentication: - Two credentials are required: an API key id and an API secret, sent as the - ``APCA-API-KEY-ID`` / ``APCA-API-SECRET-KEY`` request headers. + ``APCA-API-KEY-ID`` / ``APCA-API-SECRET-KEY`` request headers on both the sync + and async sessions. - Set ``ALPACA_API_KEY`` / ``ALPACA_API_SECRET`` (primary, project convention) or pass ``api_key`` / ``api_secret`` to the constructor. Alpaca's own SDK/CLI names ``APCA_API_KEY_ID`` / ``APCA_API_SECRET_KEY`` are accepted as a fallback. +- Get a free key at: https://alpaca.markets/ -Feed selection: -- ``feed`` defaults to ``"iex"`` (free tier, ~15-min delayed, thinner coverage). - Paid subscribers pass ``feed="sip"`` once at construction. +Feed selection (stocks): +- ``feed`` defaults to ``"iex"``: the free tier, served from a single exchange + with ~15-min-delayed, thinner coverage. +- Paid subscribers pass ``feed="sip"`` once at construction to unlock the + consolidated tape (100% of US market volume across all exchanges). +- Crypto has a single consolidated feed, so ``feed`` does not apply to it. Rate limiting: - ``DEFAULT_RATE_LIMIT`` is a conservative client-side throttle of 200 calls per - minute reflecting the commonly-cited Basic (free) plan figure. The API itself - enforces limits via HTTP 429 and rate-limit response headers rather than a - fixed documented number, so this default is tentative and overridable via the - ``rate_limit`` constructor argument. + minute, reflecting the Basic (free) plan figure documented at + docs.alpaca.markets "About Market Data API" and Alpaca support + (usage-limit-api-calls). Alpaca enforces the limit server-side via HTTP 429 and + rate-limit response headers; this client-side default front-runs that and is + overridable via the ``rate_limit`` constructor argument. 
-Example: +Sync Example: >>> from ml4t.data.providers.alpaca import AlpacaDataProvider - >>> provider = AlpacaDataProvider(api_key="k", api_secret="s") + >>> provider = AlpacaDataProvider(api_key="key", api_secret="secret") + >>> df = provider.fetch_ohlcv("AAPL", "2024-01-01", "2024-01-31") + >>> crypto = provider.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-02", + ... frequency="minute") >>> provider.close() + +Async Example: + >>> async with AlpacaDataProvider(api_key="key", api_secret="secret") as provider: + ... df = await provider.fetch_ohlcv_async("AAPL", "2024-01-01", "2024-01-31") """ from __future__ import annotations diff --git a/tests/integration/test_alpaca.py b/tests/integration/test_alpaca.py new file mode 100644 index 0000000..2982dbe --- /dev/null +++ b/tests/integration/test_alpaca.py @@ -0,0 +1,95 @@ +"""Integration tests for Alpaca provider (real API calls). + +These tests verify the Alpaca provider works correctly with actual API calls. + +Requirements: + - ALPACA_API_KEY and ALPACA_API_SECRET environment variables must be set + - Free tier uses the IEX feed (~15-min delayed); ~200 requests/min + - API key from: https://alpaca.markets/ + +Test Coverage: + - Stock daily OHLCV data (AAPL) + - Crypto minute OHLCV data (BTC/USD) + +IMPORTANT: + These tests are excluded from the default suite (the `integration` marker is + deselected in pyproject) and are skipped unless credentials are set. They are + smoke/documentation coverage, not part of the green gate. 
+""" + +import os + +import polars as pl +import pytest + +from ml4t.data.providers.alpaca import AlpacaDataProvider + +ALPACA_API_KEY = os.getenv("ALPACA_API_KEY") +ALPACA_API_SECRET = os.getenv("ALPACA_API_SECRET") + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif( + not (ALPACA_API_KEY and ALPACA_API_SECRET), + reason="ALPACA_API_KEY/ALPACA_API_SECRET not set - get a free key at " + "https://alpaca.markets/", + ), +] + +REQUIRED_COLS = ["timestamp", "symbol", "open", "high", "low", "close", "volume"] + + +@pytest.fixture +def provider(): + """Create an Alpaca provider with credentials from the environment.""" + provider = AlpacaDataProvider(api_key=ALPACA_API_KEY, api_secret=ALPACA_API_SECRET) + yield provider + provider.close() + + +class TestAlpacaProvider: + """Test the Alpaca provider against the real Market Data API.""" + + def test_provider_initialization(self): + """Provider initializes from credentials and reports its name. + + This test makes no API calls. + """ + provider = AlpacaDataProvider(api_key=ALPACA_API_KEY, api_secret=ALPACA_API_SECRET) + assert provider.name == "alpaca" + assert provider.feed == "iex" + provider.close() + + def test_fetch_stock_daily(self, provider): + """Fetch daily stock bars for AAPL with a real API call.""" + df = provider.fetch_ohlcv( + symbol="AAPL", + start="2024-01-01", + end="2024-01-31", + frequency="daily", + ) + + assert isinstance(df, pl.DataFrame) + assert not df.is_empty(), "Should fetch some daily data for AAPL" + assert all(col in df.columns for col in REQUIRED_COLS) + assert df["timestamp"].dtype == pl.Datetime + assert df["symbol"].dtype == pl.String + assert df["close"].dtype == pl.Float64 + assert (df["high"] >= df["low"]).all(), "High should be >= Low" + assert (df["symbol"] == "AAPL").all() + + def test_fetch_crypto_minute(self, provider): + """Fetch minute crypto bars for BTC/USD with a real API call.""" + df = provider.fetch_ohlcv( + symbol="BTC/USD", + start="2024-01-01T00:00:00Z", + 
end="2024-01-01T01:00:00Z", + frequency="minute", + ) + + assert isinstance(df, pl.DataFrame) + assert not df.is_empty(), "Should fetch some minute data for BTC/USD" + assert all(col in df.columns for col in REQUIRED_COLS) + assert (df["high"] >= df["low"]).all(), "High should be >= Low" + # Crypto symbols keep their BASE/QUOTE form verbatim. + assert (df["symbol"] == "BTC/USD").all() diff --git a/tests/test_provider_registration.py b/tests/test_provider_registration.py index 2244c31..d70e31b 100644 --- a/tests/test_provider_registration.py +++ b/tests/test_provider_registration.py @@ -1,6 +1,7 @@ """Test provider registration in DataManager.""" from ml4t.data.data_manager import DataManager +from ml4t.data.providers.alpaca import AlpacaDataProvider from ml4t.data.providers.binance import BinanceProvider from ml4t.data.providers.binance_public import BinancePublicProvider from ml4t.data.providers.cryptocompare import CryptoCompareProvider @@ -22,6 +23,7 @@ def test_all_providers_registered(): - Historical: wiki_prices (standalone, local file only) """ expected_providers = { + "alpaca": AlpacaDataProvider, "binance": BinanceProvider, "binance_public": BinancePublicProvider, "cryptocompare": CryptoCompareProvider, @@ -36,13 +38,14 @@ def test_all_providers_registered(): } assert expected_providers == DataManager.PROVIDER_CLASSES - assert len(DataManager.PROVIDER_CLASSES) == 11 + assert len(DataManager.PROVIDER_CLASSES) == 12 def test_provider_imports_work(): """Test that all provider imports are functional.""" # This test ensures the imports at the top of data_manager.py work providers = [ + AlpacaDataProvider, BinanceProvider, BinancePublicProvider, CryptoCompareProvider, @@ -80,19 +83,15 @@ def test_provider_instantiation(): for provider_name, _provider_class in DataManager.PROVIDER_CLASSES.items(): # Try to get provider instance try: - if provider_name in ["databento", "massive", "oanda", "polygon"]: + keyed = ["alpaca", "databento", "massive", "oanda", "polygon"] 
+ if provider_name in keyed: # These require API keys, so we expect None or skip continue provider = dm._get_provider(provider_name) - assert provider is not None or provider_name in [ - "databento", - "massive", - "oanda", - "polygon", - ] + assert provider is not None or provider_name in keyed except Exception as e: # Only keyed providers should fail without API keys - assert provider_name in ["databento", "massive", "oanda", "polygon"], ( + assert provider_name in ["alpaca", "databento", "massive", "oanda", "polygon"], ( f"Unexpected error for {provider_name}: {e}" ) @@ -101,18 +100,19 @@ def test_provider_count(): """Test that we have the expected OHLCV providers registered in DataManager. Provider categories: - - In DataManager (11): General OHLCV providers with unified fetch_ohlcv() interface + - In DataManager (12): General OHLCV providers with unified fetch_ohlcv() interface - Standalone (5): Specialized providers with unique APIs - Factor data: aqr, fama_french - Prediction markets: kalshi, polymarket - Historical: wiki_prices """ - assert len(DataManager.PROVIDER_CLASSES) == 11, ( - f"Expected 11 providers, got {len(DataManager.PROVIDER_CLASSES)}" + assert len(DataManager.PROVIDER_CLASSES) == 12, ( + f"Expected 12 providers, got {len(DataManager.PROVIDER_CLASSES)}" ) # List all provider names for clarity provider_names = list(DataManager.PROVIDER_CLASSES.keys()) + assert "alpaca" in provider_names assert "binance" in provider_names assert "binance_public" in provider_names assert "cryptocompare" in provider_names From dedcdfbecf17c6277572e631758950a3d0771e0a Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 16:56:59 +0000 Subject: [PATCH 07/13] fix: harden Alpaca fetch validation, retry, and async resilience - Accept RFC-3339 datetime bounds and validate explicit asset_class - Retry per pagination page, honoring 429 rate-limit headers - Run async fetches inside the circuit breaker via new call_async - Uppercase crypto symbols and expose a 
stock adjustment option --- src/ml4t/data/providers/alpaca.py | 469 +++++++++++++----- .../data/providers/mixins/circuit_breaker.py | 66 ++- 2 files changed, 412 insertions(+), 123 deletions(-) diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index 00f39cb..70dc68e 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -11,7 +11,13 @@ - Stocks use a plain ticker (e.g. "AAPL"); the symbol is uppercased and routed to the stock bars endpoint. - Crypto uses ``BASE/QUOTE`` (e.g. "BTC/USD"); the slash routes the request to the - crypto bars endpoint, and the symbol is preserved verbatim. + crypto bars endpoint, and the symbol is uppercased with the slash preserved so + the canonical uppercase-symbol output contract holds for both asset classes. + +Date bounds: +- ``start``/``end`` accept either a date (``YYYY-MM-DD``) or an RFC-3339 datetime + (e.g. ``2024-01-01T00:00:00Z``), both inclusive. Datetime bounds matter for + minute/hour frequencies, where a sub-day window avoids paginating a full day. Authentication: - Two credentials are required: an API key id and an API secret, sent as the @@ -23,12 +29,20 @@ - Get a free key at: https://alpaca.markets/ Feed selection (stocks): -- ``feed`` defaults to ``"iex"``: the free tier, served from a single exchange - with ~15-min-delayed, thinner coverage. +- ``feed`` defaults to ``"iex"``: the free tier, served in real time from a + single exchange (IEX, roughly 2-3% of US market volume), so coverage is + thinner than the consolidated tape. The Basic plan additionally cannot query + the most recent 15 minutes of SIP data. - Paid subscribers pass ``feed="sip"`` once at construction to unlock the consolidated tape (100% of US market volume across all exchanges). - Crypto has a single consolidated feed, so ``feed`` does not apply to it. 
+Price adjustment (stocks): +- ``adjustment`` defaults to ``"raw"`` (Alpaca's own default): stock bars are + NOT adjusted for splits or dividends. Pass ``adjustment="split"``, + ``"dividend"``, or ``"all"`` at construction for adjusted bars. Crypto bars + have no adjustment concept. + Rate limiting: - ``DEFAULT_RATE_LIMIT`` is a conservative client-side throttle of 200 calls per minute, reflecting the Basic (free) plan figure documented at @@ -52,19 +66,27 @@ from __future__ import annotations +import asyncio import os +import time +from datetime import UTC, datetime from typing import Any, ClassVar import polars as pl import structlog -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from ml4t.data.core.exceptions import ( AuthenticationError, DataNotAvailableError, DataValidationError, NetworkError, - ProviderError, RateLimitError, ) from ml4t.data.providers.base import BaseProvider @@ -73,6 +95,37 @@ logger = structlog.get_logger() +_EXPONENTIAL_WAIT = wait_exponential(multiplier=1, min=4, max=10) + + +def _retry_wait(retry_state: RetryCallState) -> float: + """Honor a server-provided retry delay; otherwise back off exponentially. + + A 429 carries the server's own ``retry_after`` (derived from its rate-limit + headers), which beats a generic exponential guess. The delay is capped so a + pathological header cannot stall a fetch for hours. + + Args: + retry_state: The tenacity retry state for the failed attempt. + + Returns: + Seconds to wait before the next attempt. 
+ """ + exception = retry_state.outcome.exception() if retry_state.outcome else None + if isinstance(exception, RateLimitError) and exception.retry_after is not None: + return min(float(exception.retry_after), 60.0) + return _EXPONENTIAL_WAIT(retry_state) + + +# Retries transient failures at the page level so a failure on page N does not +# refetch pages 1..N-1, and a multi-page fetch cannot amplify API load. +_PAGE_RETRY = retry( + stop=stop_after_attempt(3), + wait=_retry_wait, + retry=retry_if_exception_type((NetworkError, RateLimitError)), + reraise=True, +) + class AlpacaDataProvider(AsyncSessionMixin, BaseProvider): """Alpaca market data provider. @@ -91,11 +144,17 @@ class AlpacaDataProvider(AsyncSessionMixin, BaseProvider): """ # 200 requests/min — Alpaca Basic (free) plan, per docs.alpaca.markets - # "About Market Data API" and Alpaca support (usage-limit-api-calls). - # Tentative: the API enforces via HTTP 429 + rate-limit headers, not a fixed - # documented number — verify from response headers and adjust if needed. + # "About Market Data API" and Alpaca support (usage-limit-api-calls). The + # API also enforces this server-side via HTTP 429 + rate-limit headers. DEFAULT_RATE_LIMIT: ClassVar[tuple[int, float]] = (200, 60.0) + # Maximum bars per page accepted by both bars endpoints; pagination follows + # next_page_token beyond this. + PAGE_LIMIT: ClassVar[int] = 10000 + + # Asset classes a request can be routed to. + ASSET_CLASSES: ClassVar[frozenset[str]] = frozenset({"stock", "crypto"}) + # Map canonical frequency keys and aliases to Alpaca timeframe strings FREQUENCY_MAP: ClassVar[dict[str, str]] = { "daily": "1Day", @@ -116,6 +175,7 @@ def __init__( api_key: str | None = None, api_secret: str | None = None, feed: str = "iex", + adjustment: str = "raw", rate_limit: tuple[int, float] | None = None, **kwargs: Any, ) -> None: @@ -126,6 +186,9 @@ def __init__( api_secret: Alpaca API secret (or set ALPACA_API_SECRET / APCA_API_SECRET_KEY). 
feed: Data feed, "iex" (free) or "sip" (paid). + adjustment: Stock price adjustment, one of "raw" (default, matching + Alpaca's own default: no split/dividend adjustment), "split", + "dividend", or "all". Ignored for crypto. rate_limit: Optional custom (calls, period_seconds) override. **kwargs: Additional arguments forwarded to BaseProvider. @@ -145,8 +208,8 @@ def __init__( ) self.feed = feed + self.adjustment = adjustment self.base_url = "https://data.alpaca.markets" - self.crypto_base_url = "https://data.alpaca.markets" # Built once, reused for both the sync and async sessions so the two # credentials accompany every request on either transport. @@ -203,6 +266,61 @@ def capabilities(self) -> ProviderCapabilities: rate_limit=self.DEFAULT_RATE_LIMIT, ) + def _validate_inputs( + self, + symbol: str, + start: str, + end: str, + frequency: str, # noqa: ARG002 + ) -> None: + """Validate input parameters, accepting dates or RFC-3339 datetimes. + + Overrides the base date-only validation because Alpaca's bars endpoints + accept full RFC-3339 datetime bounds, which matter for minute/hour + frequencies where a sub-day window is the natural request shape. + + Args: + symbol: Symbol to fetch. + start: Inclusive start, ``YYYY-MM-DD`` or RFC-3339 datetime. + end: Inclusive end, ``YYYY-MM-DD`` or RFC-3339 datetime. + frequency: Data frequency (validated later against FREQUENCY_MAP). + + Raises: + ValueError: If the symbol is empty, a bound cannot be parsed, or + start is after end. 
+ """ + if not symbol or not symbol.strip(): + raise ValueError("Symbol cannot be empty") + + try: + start_dt = self._parse_time_bound(start) + end_dt = self._parse_time_bound(end) + except ValueError as e: + raise ValueError( + f"Invalid date format (expected YYYY-MM-DD or RFC-3339 datetime): {e}" + ) from e + + if start_dt > end_dt: + raise ValueError("Start date must be before or equal to end date") + + @staticmethod + def _parse_time_bound(value: str) -> datetime: + """Parse a date or RFC-3339 datetime bound into a UTC-aware datetime. + + Args: + value: ``YYYY-MM-DD`` or RFC-3339 datetime string (a trailing ``Z`` + is accepted). + + Returns: + A timezone-aware datetime; naive inputs are assumed UTC so mixed + date/datetime bounds stay comparable. + + Raises: + ValueError: If the value is not a valid ISO-8601 date or datetime. + """ + parsed = datetime.fromisoformat(value) + return parsed if parsed.tzinfo else parsed.replace(tzinfo=UTC) + def _map_frequency(self, frequency: str) -> str: """Map a canonical frequency to an Alpaca timeframe string. @@ -227,11 +345,12 @@ def _map_frequency(self, frequency: str) -> str: ) from err def _resolve_asset_class(self, symbol: str, asset_class: str | None) -> str: - """Resolve the asset class for a request. + """Resolve and validate the asset class for a request. - An explicit ``asset_class`` always wins. Otherwise a ``BASE/QUOTE`` - symbol (e.g. ``"BTC/USD"``) is treated as crypto and everything else as - a stock. + An explicit ``asset_class`` always wins but must name a real asset + class — a typo would otherwise silently route to the wrong endpoint and + surface as a misleading 404. Otherwise a ``BASE/QUOTE`` symbol (e.g. + ``"BTC/USD"``) is treated as crypto and everything else as a stock. Args: symbol: The requested symbol. @@ -239,9 +358,22 @@ def _resolve_asset_class(self, symbol: str, asset_class: str | None) -> str: Returns: Either ``"crypto"`` or ``"stock"``. 
+ + Raises: + DataValidationError: If an explicit ``asset_class`` is not one of + the supported asset classes (case-insensitively). """ if asset_class is not None: - return asset_class + normalized = asset_class.lower() + if normalized not in self.ASSET_CLASSES: + raise DataValidationError( + provider="alpaca", + message=f"Invalid asset_class '{asset_class}'. " + f"Supported: {sorted(self.ASSET_CLASSES)}", + field="asset_class", + value=asset_class, + ) + return normalized return "crypto" if "/" in symbol else "stock" def _stock_bars_params(self, frequency: str, start: str, end: str) -> dict[str, Any]: @@ -253,14 +385,16 @@ def _stock_bars_params(self, frequency: str, start: str, end: str) -> dict[str, end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. Returns: - The query parameter mapping, including the configured data feed. + The query parameter mapping, including the configured data feed and + price adjustment. """ return { "timeframe": self._map_frequency(frequency), "start": start, "end": end, - "limit": 10000, + "limit": self.PAGE_LIMIT, "feed": self.feed, + "adjustment": self.adjustment, } def _crypto_bars_params( @@ -269,8 +403,9 @@ def _crypto_bars_params( """Build the query parameters for a crypto bars request. The crypto bars endpoint is multi-symbol: the symbol travels in the - ``symbols`` parameter (preserving the ``BASE/QUOTE`` form verbatim) and - no ``feed`` is sent, since crypto has a single consolidated feed. + ``symbols`` parameter (uppercased ``BASE/QUOTE`` form, the slash + preserved) and no ``feed`` is sent, since crypto has a single + consolidated feed. Args: symbol: The crypto symbol in ``BASE/QUOTE`` form (e.g. "BTC/USD"). @@ -282,11 +417,11 @@ def _crypto_bars_params( The query parameter mapping for the crypto bars endpoint. 
""" return { - "symbols": symbol, + "symbols": symbol.upper(), "timeframe": self._map_frequency(frequency), "start": start, "end": end, - "limit": 10000, + "limit": self.PAGE_LIMIT, } def _bars_request( @@ -296,8 +431,8 @@ def _bars_request( Branches on the resolved asset class so the sync and async fetchers share one routing decision. The stock branch uppercases the symbol into the - path; the crypto branch hits the multi-symbol endpoint and preserves the - ``BASE/QUOTE`` symbol verbatim. + path; the crypto branch hits the multi-symbol endpoint with the + uppercased ``BASE/QUOTE`` symbol (slash preserved). Args: symbol: The requested symbol. @@ -312,7 +447,7 @@ def _bars_request( if asset_class == "crypto": # The {loc} path segment selects the venue group; "us" is the # default consolidated US crypto feed. - endpoint = f"{self.crypto_base_url}/v1beta3/crypto/us/bars" + endpoint = f"{self.base_url}/v1beta3/crypto/us/bars" return endpoint, self._crypto_bars_params(symbol, frequency, start, end) endpoint = f"{self.base_url}/v2/stocks/{symbol.upper()}/bars" return endpoint, self._stock_bars_params(frequency, start, end) @@ -342,26 +477,28 @@ def _merge_bars(self, accumulated: Any, page_bars: Any) -> Any: merged_list.extend(page_bars or []) return merged_list - def _check_response_status(self, status_code: int, symbol: str, response_text: str) -> None: - """Map an HTTP status code to a typed provider exception. + def _check_response_status(self, response: Any, symbol: str) -> None: + """Map an HTTP error response to a typed provider exception. The status code is inspected directly rather than delegating to ``raise_for_status``, so that each documented failure mode surfaces as a specific provider error instead of a generic transport error. Args: - status_code: The HTTP status code from the response. + response: The HTTP response (status code, headers, and body text + are consulted). symbol: The requested symbol, for error context. 
- response_text: The response body, included in network-error messages. Raises: - RateLimitError: On HTTP 429. + RateLimitError: On HTTP 429, with ``retry_after`` derived from the + rate-limit response headers when present. AuthenticationError: On HTTP 401/403. DataNotAvailableError: On HTTP 404. NetworkError: On any other non-200 status. """ + status_code = response.status_code if status_code == 429: - raise RateLimitError(provider="alpaca", retry_after=60.0) + raise RateLimitError(provider="alpaca", retry_after=self._retry_after_seconds(response)) if status_code in (401, 403): raise AuthenticationError( provider="alpaca", message="Invalid API key or unauthorized access" @@ -369,7 +506,36 @@ def _check_response_status(self, status_code: int, symbol: str, response_text: s if status_code == 404: raise DataNotAvailableError(provider="alpaca", symbol=symbol) if status_code != 200: - raise NetworkError(provider="alpaca", message=f"HTTP {status_code}: {response_text}") + raise NetworkError(provider="alpaca", message=f"HTTP {status_code}: {response.text}") + + @staticmethod + def _retry_after_seconds(response: Any) -> float: + """Derive a retry delay from a 429 response's rate-limit headers. + + Prefers an explicit ``Retry-After`` (delay in seconds), then + ``X-RateLimit-Reset`` (epoch seconds of the window reset), and falls + back to one rate-limit period when neither header is usable. + + Args: + response: The 429 HTTP response. + + Returns: + Seconds to wait before retrying (never negative). 
+ """ + headers = getattr(response, "headers", None) or {} + retry_after = headers.get("Retry-After") + if retry_after is not None: + try: + return max(float(retry_after), 0.0) + except (TypeError, ValueError): + pass + reset = headers.get("X-RateLimit-Reset") + if reset is not None: + try: + return max(float(reset) - time.time(), 0.0) + except (TypeError, ValueError): + pass + return 60.0 def _parse_bars_response(self, response: Any) -> dict[str, Any]: """Parse a validated bars response into a JSON dict. @@ -388,6 +554,73 @@ def _parse_bars_response(self, response: Any) -> dict[str, Any]: except Exception as err: raise NetworkError(provider="alpaca", message="Failed to parse JSON response") from err + @_PAGE_RETRY + def _get_page(self, endpoint: str, params: dict[str, Any], symbol: str) -> dict[str, Any]: + """Fetch and parse one bars page, retrying transient failures. + + Retry lives at the page level so a failure on a later page never + refetches earlier pages, and a 429's server-provided delay is honored + per attempt. Each attempt acquires its own rate-limit token, since each + attempt is a real request. + + Args: + endpoint: The bars endpoint URL. + params: The query parameters for this page. + symbol: The requested symbol, for error context. + + Returns: + The parsed JSON payload for this page. + + Raises: + RateLimitError: On HTTP 429 (after retries are exhausted). + AuthenticationError: On HTTP 401/403. + DataNotAvailableError: On HTTP 404. + NetworkError: On other HTTP errors, transport failures, or a body + that cannot be decoded as JSON (after retries are exhausted). + """ + # One acquisition per attempt is the only rate-limit gate, since + # fetch_ohlcv is fully overridden and the base never acquires. 
+ self.rate_limiter.acquire(blocking=True) + try: + response = self.session.get(endpoint, params=params) + except Exception as err: + raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + self._check_response_status(response, symbol) + return self._parse_bars_response(response) + + @_PAGE_RETRY + async def _get_page_async( + self, endpoint: str, params: dict[str, Any], symbol: str + ) -> dict[str, Any]: + """Asynchronously fetch and parse one bars page, retrying transient failures. + + Mirrors :meth:`_get_page` over the async transport. The rate-limit + acquisition is pushed to a worker thread so a throttled fetch never + blocks the event loop. + + Args: + endpoint: The bars endpoint URL. + params: The query parameters for this page. + symbol: The requested symbol, for error context. + + Returns: + The parsed JSON payload for this page. + + Raises: + RateLimitError: On HTTP 429 (after retries are exhausted). + AuthenticationError: On HTTP 401/403. + DataNotAvailableError: On HTTP 404. + NetworkError: On other HTTP errors, transport failures, or a body + that cannot be decoded as JSON (after retries are exhausted). + """ + await asyncio.to_thread(self.rate_limiter.acquire, True) + try: + response = await self._aget(endpoint, params=params) + except Exception as err: + raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + self._check_response_status(response, symbol) + return self._parse_bars_response(response) + def _fetch_raw_data( self, symbol: str, @@ -406,8 +639,8 @@ def _fetch_raw_data( expects (a list for stocks, a per-symbol dict for crypto). Args: - symbol: The symbol to fetch (stocks are case-insensitive; crypto - symbols are preserved verbatim). + symbol: The symbol to fetch (case-insensitive; uppercased into the + request for both asset classes). start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. 
frequency: Canonical frequency key or alias. @@ -418,41 +651,27 @@ def _fetch_raw_data( The parsed JSON payload containing the ``bars`` data. Raises: - RateLimitError, AuthenticationError, DataNotAvailableError, - NetworkError: Per the HTTP status of the response. + RateLimitError: On HTTP 429 (after per-page retries). + AuthenticationError: On HTTP 401/403. + DataNotAvailableError: On HTTP 404. + DataValidationError: On an unsupported frequency or asset class. + NetworkError: On other HTTP, transport, or JSON-decoding failures. """ resolved = self._resolve_asset_class(symbol, asset_class) endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) accumulated: Any = None token: str | None = None - try: - while True: - # A fresh dict per page keeps each request's params independent; - # mutating one shared dict would otherwise rewrite the token on - # earlier requests that already went out. - page_params = {**params, "page_token": token} if token else params - # One acquisition per page is the only rate-limit gate, since - # fetch_ohlcv is fully overridden and the base never acquires. - self.rate_limiter.acquire(blocking=True) - response = self.session.get(endpoint, params=page_params) - self._check_response_status(response.status_code, symbol, response.text) - payload = self._parse_bars_response(response) - accumulated = self._merge_bars(accumulated, payload.get("bars")) - token = payload.get("next_page_token") - if not token: - break - return {"bars": accumulated} - except ( - AuthenticationError, - RateLimitError, - NetworkError, - DataNotAvailableError, - ProviderError, - ): - raise - except Exception as err: - raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + while True: + # A fresh dict per page keeps each request's params independent; + # mutating one shared dict would otherwise rewrite the token on + # earlier requests that already went out. 
+ page_params = {**params, "page_token": token} if token else params + payload = self._get_page(endpoint, page_params, symbol) + accumulated = self._merge_bars(accumulated, payload.get("bars")) + token = payload.get("next_page_token") + if not token: + return {"bars": accumulated} async def _fetch_raw_data_async( self, @@ -469,8 +688,8 @@ async def _fetch_raw_data_async( applies the same stock/crypto routing and ``next_page_token`` loop. Args: - symbol: The symbol to fetch (stocks are case-insensitive; crypto - symbols are preserved verbatim). + symbol: The symbol to fetch (case-insensitive; uppercased into the + request for both asset classes). start: Inclusive start date/datetime in ISO-8601 (RFC-3339) form. end: Inclusive end date/datetime in ISO-8601 (RFC-3339) form. frequency: Canonical frequency key or alias. @@ -481,41 +700,27 @@ async def _fetch_raw_data_async( The parsed JSON payload containing the ``bars`` data. Raises: - RateLimitError, AuthenticationError, DataNotAvailableError, - NetworkError: Per the HTTP status of the response. + RateLimitError: On HTTP 429 (after per-page retries). + AuthenticationError: On HTTP 401/403. + DataNotAvailableError: On HTTP 404. + DataValidationError: On an unsupported frequency or asset class. + NetworkError: On other HTTP, transport, or JSON-decoding failures. """ resolved = self._resolve_asset_class(symbol, asset_class) endpoint, params = self._bars_request(symbol, start, end, frequency, resolved) accumulated: Any = None token: str | None = None - try: - while True: - # A fresh dict per page keeps each request's params independent; - # mutating one shared dict would otherwise rewrite the token on - # earlier requests that already went out. - page_params = {**params, "page_token": token} if token else params - # One acquisition per page is the only rate-limit gate, since - # fetch_ohlcv is fully overridden and the base never acquires. 
- self.rate_limiter.acquire(blocking=True) - response = await self._aget(endpoint, params=page_params) - self._check_response_status(response.status_code, symbol, response.text) - payload = self._parse_bars_response(response) - accumulated = self._merge_bars(accumulated, payload.get("bars")) - token = payload.get("next_page_token") - if not token: - break - return {"bars": accumulated} - except ( - AuthenticationError, - RateLimitError, - NetworkError, - DataNotAvailableError, - ProviderError, - ): - raise - except Exception as err: - raise NetworkError(provider="alpaca", message=f"Request failed: {endpoint}") from err + while True: + # A fresh dict per page keeps each request's params independent; + # mutating one shared dict would otherwise rewrite the token on + # earlier requests that already went out. + page_params = {**params, "page_token": token} if token else params + payload = await self._get_page_async(endpoint, page_params, symbol) + accumulated = self._merge_bars(accumulated, payload.get("bars")) + token = payload.get("next_page_token") + if not token: + return {"bars": accumulated} def _extract_bars(self, raw_data: dict[str, Any], symbol: str) -> list[dict[str, Any]]: """Extract the bar list from either response shape. @@ -533,23 +738,24 @@ def _extract_bars(self, raw_data: dict[str, Any], symbol: str) -> list[dict[str, """ bars = raw_data.get("bars") if isinstance(bars, dict): - return bars.get(symbol) or bars.get(symbol.upper()) or [] + return bars.get(symbol.upper()) or bars.get(symbol) or [] return bars or [] def _transform_data( self, raw_data: dict[str, Any], symbol: str, - asset_class: str | None = None, + asset_class: str | None = None, # noqa: ARG002 ) -> pl.DataFrame: """Transform a raw bars payload into the canonical OHLCV schema. Args: raw_data: The parsed JSON payload from a bars request. - symbol: The requested symbol; the literal is added as a column, - uppercased for stocks and used verbatim for crypto. 
- asset_class: When ``"crypto"``, the symbol literal is preserved - verbatim; otherwise it is uppercased. + symbol: The requested symbol; its uppercased form is written into + the ``symbol`` column (the canonical contract), with a crypto + slash preserved (e.g. ``"BTC/USD"``). + asset_class: Unused; kept so both fetch paths can thread the + resolved asset class through a single transform signature. Returns: A DataFrame with columns @@ -563,9 +769,9 @@ def _transform_data( if not bars: return self._create_empty_dataframe() - # Crypto symbols keep their BASE/QUOTE form verbatim; stocks uppercase. - symbol_literal = symbol if asset_class == "crypto" else symbol.upper() - return self._bars_to_dataframe(bars, symbol_literal, symbol) + # Canonical contract: the symbol column is uppercase for both asset + # classes; a crypto BASE/QUOTE slash is preserved. + return self._bars_to_dataframe(bars, symbol.upper(), symbol) def _bars_to_dataframe( self, bars: list[dict[str, Any]], symbol_literal: str, symbol: str @@ -619,12 +825,6 @@ def _bars_to_dataframe( provider="alpaca", message=f"Failed to transform data for {symbol}" ) from err - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((NetworkError, RateLimitError)), - reraise=True, - ) def fetch_ohlcv( self, symbol: str, @@ -640,13 +840,14 @@ def fetch_ohlcv( rate-limit token up front whereas this provider rate-limits per page inside the fetch. The base contract is reproduced here directly -- input validation, circuit-breaker-wrapped fetch/transform/validate, and - the same info-logging. + the same info-logging. Transient-failure retry happens per page inside + the fetch (see ``_get_page``) rather than around the whole fetch. Args: symbol: The symbol to fetch. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") is treated as crypto unless ``asset_class`` overrides it. - start: Start date in YYYY-MM-DD format (inclusive). 
- end: End date in YYYY-MM-DD format (inclusive). + start: Inclusive start, ``YYYY-MM-DD`` or RFC-3339 datetime. + end: Inclusive end, ``YYYY-MM-DD`` or RFC-3339 datetime. frequency: Canonical frequency key or alias. asset_class: Explicit asset class ("stock" or "crypto"); inferred from the symbol when ``None``. @@ -654,6 +855,18 @@ def fetch_ohlcv( Returns: A DataFrame in the canonical OHLCV schema ``[timestamp, symbol, open, high, low, close, volume]``. + + Raises: + ValueError: If the symbol is empty, a date bound is malformed, or + start is after end. + DataValidationError: If the frequency or asset class is + unsupported, or the response cannot be transformed/validated. + AuthenticationError: If the credentials are rejected (HTTP 401/403). + RateLimitError: If HTTP 429 persists past the per-page retries. + DataNotAvailableError: If the endpoint reports no data (HTTP 404). + NetworkError: On other HTTP, transport, or JSON-decoding failures. + CircuitBreakerOpenError: If repeated failures have opened the + circuit breaker. """ self.logger.info( "Fetching OHLCV data", @@ -693,16 +906,17 @@ async def fetch_ohlcv_async( ) -> pl.DataFrame: """Asynchronously fetch OHLCV bars for a stock or crypto symbol. - Mirrors :meth:`fetch_ohlcv` over the async transport. The circuit - breaker wraps a synchronous callable, so the coroutine is awaited first - and only the transform/validate step runs inside the breaker, preserving - failure accounting without blocking the event loop on the fetch. + Mirrors :meth:`fetch_ohlcv` over the async transport with the same + resilience semantics: the circuit breaker wraps the full + fetch/transform/validate pipeline (an open breaker refuses before any + request goes out, and fetch failures count toward opening it), and + transient failures retry per page inside the fetch. Args: symbol: The symbol to fetch. A ``BASE/QUOTE`` symbol (e.g. "BTC/USD") is treated as crypto unless ``asset_class`` overrides it. 
- start: Start date in YYYY-MM-DD format (inclusive). - end: End date in YYYY-MM-DD format (inclusive). + start: Inclusive start, ``YYYY-MM-DD`` or RFC-3339 datetime. + end: Inclusive end, ``YYYY-MM-DD`` or RFC-3339 datetime. frequency: Canonical frequency key or alias. asset_class: Explicit asset class ("stock" or "crypto"); inferred from the symbol when ``None``. @@ -710,6 +924,18 @@ async def fetch_ohlcv_async( Returns: A DataFrame in the canonical OHLCV schema ``[timestamp, symbol, open, high, low, close, volume]``. + + Raises: + ValueError: If the symbol is empty, a date bound is malformed, or + start is after end. + DataValidationError: If the frequency or asset class is + unsupported, or the response cannot be transformed/validated. + AuthenticationError: If the credentials are rejected (HTTP 401/403). + RateLimitError: If HTTP 429 persists past the per-page retries. + DataNotAvailableError: If the endpoint reports no data (HTTP 404). + NetworkError: On other HTTP, transport, or JSON-decoding failures. + CircuitBreakerOpenError: If repeated failures have opened the + circuit breaker. 
""" self.logger.info( "Fetching OHLCV data (async)", @@ -723,15 +949,14 @@ async def fetch_ohlcv_async( self._validate_inputs(symbol, start, end, frequency) resolved = self._resolve_asset_class(symbol, asset_class) - raw_data = await self._fetch_raw_data_async( - symbol, start, end, frequency, asset_class=resolved - ) - - def _transform_and_validate() -> pl.DataFrame: + async def _fetch_and_process() -> pl.DataFrame: + raw_data = await self._fetch_raw_data_async( + symbol, start, end, frequency, asset_class=resolved + ) df = self._transform_data(raw_data, symbol, asset_class=resolved) return self._validate_ohlcv(df, self.name) - validated_data = self._with_circuit_breaker(_transform_and_validate) + validated_data = await self._with_circuit_breaker_async(_fetch_and_process) self.logger.info( "Successfully fetched OHLCV data (async)", diff --git a/src/ml4t/data/providers/mixins/circuit_breaker.py b/src/ml4t/data/providers/mixins/circuit_breaker.py index 07a68b7..ba0d64d 100644 --- a/src/ml4t/data/providers/mixins/circuit_breaker.py +++ b/src/ml4t/data/providers/mixins/circuit_breaker.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Coroutine from datetime import datetime from typing import Any, ClassVar, TypeVar @@ -82,6 +82,46 @@ def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: self._on_failure() raise + async def call_async( + self, + func: Callable[..., Coroutine[Any, Any, T]], + *args: Any, + **kwargs: Any, + ) -> T: + """Execute an async function with circuit breaker protection. + + Mirrors :meth:`call` for awaited callables, so async fetch paths get + the same state handling and failure accounting as sync ones. 
+ + Args: + func: Async function to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Function result + + Raises: + CircuitBreakerOpenError: If circuit is open + Original exception: If function fails + """ + if self.state == "OPEN": + if self._should_attempt_reset(): + self.state = "HALF_OPEN" + logger.info("Circuit breaker entering HALF_OPEN state") + else: + raise CircuitBreakerOpenError( + f"Circuit breaker is OPEN. Failures: {self.failure_count}" + ) + + try: + result = await func(*args, **kwargs) + self._on_success() + return result + except self.expected_exception: + self._on_failure() + raise + def _should_attempt_reset(self) -> bool: """Check if enough time passed to attempt reset.""" if self.last_failure_time is None: @@ -189,6 +229,30 @@ def _with_circuit_breaker( return self.circuit_breaker.call(func, *args, **kwargs) + async def _with_circuit_breaker_async( + self, + func: Callable[..., Coroutine[Any, Any, T]], + *args: Any, + **kwargs: Any, + ) -> T: + """Execute an async function with circuit breaker protection. + + Args: + func: Async function to execute + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Function result + + Raises: + CircuitBreakerOpenError: If circuit is open + """ + if not hasattr(self, "circuit_breaker"): + self.init_circuit_breaker() + + return await self.circuit_breaker.call_async(func, *args, **kwargs) + def _get_circuit_status(self) -> dict[str, Any]: """Get circuit breaker status. 
From 2725a6a1591e5daf2ff9d9a1c036582d2041a0d1 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 16:57:07 +0000 Subject: [PATCH 08/13] fix: redact Alpaca secret and require both credentials - Strip api_secret from the sanitized provider info config - Mark alpaca available only when both key and secret resolve - Accept APCA_* SDK env aliases in config mapping and autodetect - Warn when an alpaca config lacks api_secret --- src/ml4t/data/config/validator.py | 7 +++++ src/ml4t/data/managers/config_manager.py | 6 +++- src/ml4t/data/managers/provider_manager.py | 36 +++++++++++++++++----- 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/ml4t/data/config/validator.py b/src/ml4t/data/config/validator.py index b95f968..d91fec1 100644 --- a/src/ml4t/data/config/validator.py +++ b/src/ml4t/data/config/validator.py @@ -73,6 +73,13 @@ def _validate_providers(self) -> None: f"Provider {provider.name} ({provider.type}) may require an API key" ) + # Alpaca authenticates with a key/secret pair, so a missing secret + # is just as fatal as a missing key. + if provider.type == "alpaca" and not provider.api_secret: + self.warnings.append( + f"Provider {provider.name} (alpaca) requires api_secret as well as api_key" + ) + # Validate rate limits if provider.rate_limit and provider.rate_limit.requests_per_second <= 0: self.errors.append( diff --git a/src/ml4t/data/managers/config_manager.py b/src/ml4t/data/managers/config_manager.py index 0b5024e..7dba835 100644 --- a/src/ml4t/data/managers/config_manager.py +++ b/src/ml4t/data/managers/config_manager.py @@ -40,8 +40,12 @@ class ConfigManager: >>> print(config_mgr.config["providers"]["yahoo"]) """ - # Environment variable to provider config mapping + # Environment variable to provider config mapping. Later entries win when + # two variables map to the same field, so the APCA_* aliases (Alpaca's own + # SDK/CLI names) come before the ALPACA_* project-convention names. 
ENV_MAPPING = { + "APCA_API_KEY_ID": ("alpaca", "api_key"), + "APCA_API_SECRET_KEY": ("alpaca", "api_secret"), "ALPACA_API_KEY": ("alpaca", "api_key"), "ALPACA_API_SECRET": ("alpaca", "api_secret"), "CRYPTOCOMPARE_API_KEY": ("cryptocompare", "api_key"), diff --git a/src/ml4t/data/managers/provider_manager.py b/src/ml4t/data/managers/provider_manager.py index 611d73a..fb6aa29 100644 --- a/src/ml4t/data/managers/provider_manager.py +++ b/src/ml4t/data/managers/provider_manager.py @@ -157,6 +157,9 @@ class ProviderManager: } ) + # Credential fields that must never leave get_provider_info's sanitized view + SECRET_FIELDS = frozenset({"api_key", "api_secret"}) + def __init__(self, config: dict[str, Any]) -> None: """Initialize ProviderManager. @@ -238,16 +241,18 @@ def _detect_available_providers(self) -> None: # Check configured providers for provider_name, provider_config in providers_config.items(): - if ( - provider_name in self.FREE_PROVIDERS - or provider_name in self.KEYED_PROVIDERS - and provider_config.get("api_key") - ): + available = provider_name in self.FREE_PROVIDERS or ( + provider_name in self.KEYED_PROVIDERS and provider_config.get("api_key") + ) + # Alpaca is two-credential: a key without a secret is not + # constructable, so it must not be reported as available. 
+ if provider_name == "alpaca" and available: + available = bool(provider_config.get("api_secret") or self._alpaca_env_secret()) + if available: self._available_providers.append(provider_name) # Check environment for API keys not in config env_to_provider = { - "ALPACA_API_KEY": "alpaca", "CRYPTOCOMPARE_API_KEY": "cryptocompare", "DATABENTO_API_KEY": "databento", "MASSIVE_API_KEY": "massive", @@ -264,11 +269,28 @@ def _detect_available_providers(self) -> None: if "POLYGON_API_KEY" in os.environ and "polygon" not in self._available_providers: self._available_providers.append("polygon") + # Alpaca needs both credentials; either the project (ALPACA_*) or the + # Alpaca SDK (APCA_*) naming scheme satisfies each one. + if "alpaca" not in self._available_providers and ( + self._alpaca_env_key() and self._alpaca_env_secret() + ): + self._available_providers.append("alpaca") + # Add free providers if not already detected for free_provider in self.FREE_PROVIDERS: if free_provider not in self._available_providers: self._available_providers.append(free_provider) + @staticmethod + def _alpaca_env_key() -> str | None: + """Return the Alpaca API key from the environment, if set.""" + return os.environ.get("ALPACA_API_KEY") or os.environ.get("APCA_API_KEY_ID") + + @staticmethod + def _alpaca_env_secret() -> str | None: + """Return the Alpaca API secret from the environment, if set.""" + return os.environ.get("ALPACA_API_SECRET") or os.environ.get("APCA_API_SECRET_KEY") + @property def available_providers(self) -> list[str]: """Get list of available provider names.""" @@ -348,7 +370,7 @@ def get_provider_info(self, provider_name: str) -> dict[str, Any]: "configured": provider_name in self.config.get("providers", {}), "has_api_key": bool(config.get("api_key")), "is_free": provider_name in self.FREE_PROVIDERS, - "config": {k: v for k, v in config.items() if k != "api_key"}, + "config": {k: v for k, v in config.items() if k not in self.SECRET_FIELDS}, } def close_all(self) -> None: 
From 8895258257ed5e445021cc0219ab16dd2047001c Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 16:57:15 +0000 Subject: [PATCH 09/13] test: cover Alpaca async path, error branches, and registration - Add fetch_ohlcv_async, per-page retry, and rate-limit header tests - Cover circuit breaker call_async, lazy init, and registry caching - Pin two-credential availability and secret redaction behavior - Dedupe the provider fixture and mock page-response builder --- tests/integration/test_alpaca.py | 2 +- tests/test_alpaca_provider.py | 562 ++++++++++++++++++++++++--- tests/test_base_provider_enhanced.py | 49 +++ 3 files changed, 558 insertions(+), 55 deletions(-) diff --git a/tests/integration/test_alpaca.py b/tests/integration/test_alpaca.py index 2982dbe..fd051fe 100644 --- a/tests/integration/test_alpaca.py +++ b/tests/integration/test_alpaca.py @@ -4,7 +4,7 @@ Requirements: - ALPACA_API_KEY and ALPACA_API_SECRET environment variables must be set - - Free tier uses the IEX feed (~15-min delayed); ~200 requests/min + - Free tier uses the real-time single-exchange IEX feed; ~200 requests/min - API key from: https://alpaca.markets/ Test Coverage: diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index 3bd7126..b2e60e7 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -1,12 +1,15 @@ """Tests for Alpaca data provider module.""" +import time from unittest.mock import AsyncMock, MagicMock, patch +import httpx import polars as pl import pytest from ml4t.data.core.exceptions import ( AuthenticationError, + CircuitBreakerOpenError, DataNotAvailableError, DataValidationError, NetworkError, @@ -28,6 +31,32 @@ ] +def _page_response(bars, token=None, status=200, headers=None): + """Build a mock bars response with a given status, token, and headers.""" + response = MagicMock() + response.status_code = status + response.text = "" + response.headers = headers or {} + response.json.return_value = {"bars": bars, 
"next_page_token": token} + return response + + +@pytest.fixture +def provider(): + """Create a provider instance.""" + return AlpacaDataProvider(api_key="k", api_secret="s") + + +@pytest.fixture(autouse=True) +def _no_retry_sleep(): + """Null out tenacity's sleep so retry-path tests never wait for real.""" + with ( + patch.object(AlpacaDataProvider._get_page.retry, "sleep", MagicMock()), + patch.object(AlpacaDataProvider._get_page_async.retry, "sleep", AsyncMock()), + ): + yield + + class TestAlpacaProviderInit: """Tests for provider construction and authentication.""" @@ -124,11 +153,6 @@ def test_name_property(self): class TestFrequencyMapping: """Tests for frequency mapping to Alpaca timeframes.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - @pytest.mark.parametrize( ("frequency", "expected"), [ @@ -198,11 +222,6 @@ def test_default_rate_limit(self): class TestFetchRawDataStock: """Tests for single-page stock bar fetching and error mapping.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_fetch_raw_data_success(self, provider): """A 200 response returns the parsed bars structure.""" mock_response = MagicMock() @@ -315,11 +334,6 @@ async def test_fetch_raw_data_async_stock(self, provider): class TestTransformDataStock: """Tests for transforming stock bars into the standard schema.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_transform_data_stock(self, provider): """Bars transform to the canonical OHLCV schema, sorted, with invariants.""" raw = {"bars": STOCK_BARS, "next_page_token": None} @@ -371,11 +385,6 @@ def test_empty_bars_returns_empty_dataframe(self, provider): class TestAssetRouting: """Tests for routing symbols to the stock vs crypto endpoint.""" - @pytest.fixture - 
def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_resolve_asset_class_by_slash(self, provider): """A slash in the symbol resolves to crypto; otherwise stock.""" assert provider._resolve_asset_class("BTC/USD", None) == "crypto" @@ -441,11 +450,6 @@ def test_feed_not_sent_for_crypto(self, provider): class TestFetchRawDataCrypto: """Tests for single-page crypto bar fetching.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_fetch_raw_data_crypto_success(self, provider): """A 200 crypto response returns the parsed symbol-keyed bars structure.""" mock_response = MagicMock() @@ -494,11 +498,6 @@ async def test_crypto_async(self, provider): class TestTransformDataCrypto: """Tests for transforming crypto bars into the standard schema.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_transform_data_crypto(self, provider): """Crypto bars transform with the slash-preserved symbol and weekend bars kept.""" raw = {"bars": {"BTC/USD": CRYPTO_BARS}, "next_page_token": None} @@ -528,23 +527,9 @@ def test_transform_data_crypto(self, provider): assert timestamps == sorted(timestamps) -def _page_response(bars, token): - """Build a mock bars response carrying a given next_page_token.""" - response = MagicMock() - response.status_code = 200 - response.text = "" - response.json.return_value = {"bars": bars, "next_page_token": token} - return response - - class TestPagination: """Tests for following next_page_token across multiple pages.""" - @pytest.fixture - def provider(self): - """Create a provider instance.""" - return AlpacaDataProvider(api_key="k", api_secret="s") - def test_follows_next_page_token(self, provider): """Two stock pages are merged and page 2 is requested with the token.""" page1 = 
_page_response([STOCK_BARS[0]], "abc") @@ -649,14 +634,47 @@ def test_alpaca_exported(self): assert Exported is AlpacaDataProvider assert "AlpacaDataProvider" in providers.__all__ + @staticmethod + def _env_without_alpaca(): + """Return a copy of the environment with all Alpaca credentials removed.""" + import os + + return {k: v for k, v in os.environ.items() if not k.startswith(("ALPACA_", "APCA_"))} + def test_env_autodetect_alpaca(self): - """ALPACA_API_KEY auto-detection resolves the alpaca provider.""" + """Both ALPACA_* env credentials together mark alpaca available.""" + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = self._env_without_alpaca() + env["ALPACA_API_KEY"] = "k" + env["ALPACA_API_SECRET"] = "s" + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager(config={"providers": {}}) + assert "alpaca" in manager.available_providers + + def test_env_autodetect_requires_both_credentials(self): + """A key without a secret must not mark alpaca available.""" import os from ml4t.data.managers.provider_manager import ProviderManager - env = {k: v for k, v in os.environ.items() if not k.startswith("ALPACA_")} + env = self._env_without_alpaca() env["ALPACA_API_KEY"] = "k" + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager(config={"providers": {}}) + assert "alpaca" not in manager.available_providers + + def test_env_autodetect_apca_aliases(self): + """The Alpaca SDK's APCA_* env names also satisfy auto-detection.""" + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = self._env_without_alpaca() + env["APCA_API_KEY_ID"] = "k" + env["APCA_API_SECRET_KEY"] = "s" with patch.dict(os.environ, env, clear=True): manager = ProviderManager(config={"providers": {}}) assert "alpaca" in manager.available_providers @@ -676,7 +694,7 @@ def test_config_manager_injects_both_credentials(self): from ml4t.data.managers.config_manager import ConfigManager - env = 
{k: v for k, v in os.environ.items() if not k.startswith("ALPACA_")} + env = self._env_without_alpaca() env["ALPACA_API_KEY"] = "k" env["ALPACA_API_SECRET"] = "s" with patch.dict(os.environ, env, clear=True): @@ -686,20 +704,456 @@ def test_config_manager_injects_both_credentials(self): assert provider_config.get("api_key") == "k" assert provider_config.get("api_secret") == "s" - def test_alpaca_missing_secret_fails_clearly(self): - """A config with api_key but no api_secret fails with a clear error. + def test_config_manager_accepts_apca_aliases(self): + """APCA_* env names inject credentials, and ALPACA_* names win over them.""" + import os + + from ml4t.data.managers.config_manager import ConfigManager + + env = self._env_without_alpaca() + env["APCA_API_KEY_ID"] = "apca-key" + env["APCA_API_SECRET_KEY"] = "apca-secret" + with patch.dict(os.environ, env, clear=True): + provider_config = ConfigManager().get_provider_config("alpaca") + assert provider_config.get("api_key") == "apca-key" + assert provider_config.get("api_secret") == "apca-secret" + + env["ALPACA_API_KEY"] = "alpaca-key" + with patch.dict(os.environ, env, clear=True): + provider_config = ConfigManager().get_provider_config("alpaca") + assert provider_config.get("api_key") == "alpaca-key" + + def test_alpaca_config_key_without_secret_not_available(self): + """A config with api_key but no api_secret leaves alpaca unavailable. - Availability is keyed on api_key alone, so an alpaca entry with only a - key is marked available but fails loudly at construction rather than - silently producing a broken provider. + A key without a secret is not constructable, so availability must not + be reported; get_provider then fails with the not-available error + instead of a late construction failure. 
""" import os from ml4t.data.managers.provider_manager import ProviderManager - env = {k: v for k, v in os.environ.items() if not k.startswith("ALPACA_")} + env = self._env_without_alpaca() with patch.dict(os.environ, env, clear=True): manager = ProviderManager(config={"providers": {"alpaca": {"api_key": "k"}}}) - with pytest.raises(ValueError, match="secret"): + assert "alpaca" not in manager.available_providers + with pytest.raises(ValueError, match="not available"): manager.get_provider("alpaca") + + def test_alpaca_config_key_with_env_secret_available(self): + """A configured key plus an env secret is constructable, so available.""" + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = self._env_without_alpaca() + env["ALPACA_API_SECRET"] = "s" + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager(config={"providers": {"alpaca": {"api_key": "k"}}}) + + assert "alpaca" in manager.available_providers + + def test_provider_info_redacts_secrets(self): + """get_provider_info must strip api_key and api_secret from the config.""" + import os + + from ml4t.data.managers.provider_manager import ProviderManager + + env = self._env_without_alpaca() + with patch.dict(os.environ, env, clear=True): + manager = ProviderManager( + config={"providers": {"alpaca": {"api_key": "k", "api_secret": "s", "feed": "iex"}}} + ) + info = manager.get_provider_info("alpaca") + + assert info["has_api_key"] is True + assert "api_key" not in info["config"] + assert "api_secret" not in info["config"] + assert info["config"]["feed"] == "iex" + + def test_validator_warns_on_missing_secret(self): + """The config validator warns when an alpaca provider lacks a secret.""" + from ml4t.data.config.models import DataConfig, ProviderConfig + from ml4t.data.config.validator import ConfigValidator + + config = DataConfig(providers=[ProviderConfig(name="alpaca", type="alpaca", api_key="k")]) + validator = ConfigValidator(config) + 
validator.validate() + + assert any("api_secret" in warning for warning in validator.warnings) + + +class TestValidateInputs: + """Tests for the date/datetime input validation override.""" + + def test_accepts_date_bounds(self, provider): + """Plain YYYY-MM-DD bounds validate.""" + provider._validate_inputs("AAPL", "2024-01-01", "2024-01-31", "daily") + + def test_accepts_rfc3339_bounds(self, provider): + """RFC-3339 datetime bounds with a Z offset validate.""" + provider._validate_inputs( + "BTC/USD", "2024-01-01T00:00:00Z", "2024-01-01T01:00:00Z", "minute" + ) + + def test_accepts_mixed_date_and_datetime_bounds(self, provider): + """A naive date start and an aware datetime end stay comparable.""" + provider._validate_inputs("AAPL", "2024-01-01", "2024-01-02T01:00:00Z", "hourly") + + def test_rejects_malformed_bound(self, provider): + """A non-ISO bound raises ValueError mentioning both accepted forms.""" + with pytest.raises(ValueError, match="YYYY-MM-DD or RFC-3339"): + provider._validate_inputs("AAPL", "01/02/2024", "2024-01-31", "daily") + + def test_rejects_start_after_end(self, provider): + """A start bound after the end bound raises ValueError.""" + with pytest.raises(ValueError, match="before or equal"): + provider._validate_inputs("AAPL", "2024-02-01", "2024-01-01", "daily") + + def test_rejects_empty_symbol(self, provider): + """An empty symbol raises ValueError.""" + with pytest.raises(ValueError, match="Symbol cannot be empty"): + provider._validate_inputs(" ", "2024-01-01", "2024-01-31", "daily") + + def test_fetch_ohlcv_accepts_datetime_bounds(self, provider): + """The public sync path accepts RFC-3339 bounds end to end.""" + response = _page_response({"BTC/USD": CRYPTO_BARS}) + + with ( + patch.object(provider.session, "get", return_value=response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + df = provider.fetch_ohlcv( + "BTC/USD", "2024-01-01T00:00:00Z", "2024-01-06T01:00:00Z", "minute" + ) + + assert df.height == 2 + 
assert mock_get.call_args.kwargs["params"]["start"] == "2024-01-01T00:00:00Z" + + +class TestAssetClassValidation: + """Tests for explicit asset_class validation and normalization.""" + + def test_invalid_asset_class_raises(self, provider): + """A typo'd asset_class raises DataValidationError, not a silent 404.""" + with pytest.raises(DataValidationError, match="asset_class"): + provider._resolve_asset_class("AAPL", "equity") + + def test_asset_class_is_case_insensitive(self, provider): + """Mixed-case asset_class values normalize to lowercase.""" + assert provider._resolve_asset_class("AAPL", "Crypto") == "crypto" + assert provider._resolve_asset_class("BTC/USD", "STOCK") == "stock" + + +class TestSymbolNormalization: + """Tests for the canonical uppercase-symbol contract.""" + + def test_lowercase_crypto_symbol_uppercased_in_request(self, provider): + """A lowercase BASE/QUOTE symbol is uppercased into the symbols param.""" + _, params = provider._bars_request("btc/usd", "2024-01-01", "2024-01-07", "daily", "crypto") + + assert params["symbols"] == "BTC/USD" + + def test_lowercase_crypto_symbol_uppercased_in_output(self, provider): + """The symbol column is uppercase even for lowercase crypto input.""" + raw = {"bars": {"BTC/USD": CRYPTO_BARS}, "next_page_token": None} + + df = provider._transform_data(raw, "btc/usd", asset_class="crypto") + + assert df["symbol"].to_list() == ["BTC/USD", "BTC/USD"] + + def test_lowercase_stock_symbol_uppercased_in_output(self, provider): + """The symbol column is uppercase for lowercase stock input.""" + raw = {"bars": STOCK_BARS, "next_page_token": None} + + df = provider._transform_data(raw, "aapl") + + assert df["symbol"].to_list() == ["AAPL", "AAPL"] + + +class TestAdjustmentParam: + """Tests for the stock price-adjustment parameter.""" + + def test_default_adjustment_is_raw(self, provider): + """The default sends Alpaca's own unadjusted default explicitly.""" + _, params = provider._bars_request("AAPL", "2024-01-01", 
"2024-01-07", "daily", "stock") + + assert params["adjustment"] == "raw" + + def test_custom_adjustment_forwarded(self): + """A custom adjustment is forwarded in the stock request params.""" + provider = AlpacaDataProvider(api_key="k", api_secret="s", adjustment="all") + + _, params = provider._bars_request("AAPL", "2024-01-01", "2024-01-07", "daily", "stock") + + assert params["adjustment"] == "all" + + def test_adjustment_not_sent_for_crypto(self, provider): + """Crypto has no adjustment concept, so the param must not be sent.""" + _, params = provider._bars_request("BTC/USD", "2024-01-01", "2024-01-07", "daily", "crypto") + + assert "adjustment" not in params + + +class TestRetryAfterHeaders: + """Tests for deriving retry_after from 429 rate-limit headers.""" + + def test_retry_after_header_honored(self, provider): + """An explicit Retry-After header becomes the retry_after value.""" + response = _page_response([], status=429, headers={"Retry-After": "7"}) + + with pytest.raises(RateLimitError) as exc_info: + provider._check_response_status(response, "AAPL") + + assert exc_info.value.retry_after == 7.0 + + def test_rate_limit_reset_header_honored(self, provider): + """X-RateLimit-Reset (epoch seconds) maps to a relative delay.""" + reset_at = time.time() + 30 + response = _page_response([], status=429, headers={"X-RateLimit-Reset": str(reset_at)}) + + with pytest.raises(RateLimitError) as exc_info: + provider._check_response_status(response, "AAPL") + + assert 0 < exc_info.value.retry_after <= 30.0 + + def test_missing_headers_fall_back(self, provider): + """Without usable headers the delay falls back to one rate-limit period.""" + response = _page_response([], status=429) + + with pytest.raises(RateLimitError) as exc_info: + provider._check_response_status(response, "AAPL") + + assert exc_info.value.retry_after == 60.0 + + +class TestRetryWiring: + """Tests for the per-page retry policy.""" + + def test_persistent_500_retries_three_times(self, provider): + """A 
persistent 500 is retried 3 times, then NetworkError is raised.""" + response = _page_response([], status=500) + + with ( + patch.object(provider.session, "get", return_value=response) as mock_get, + patch.object(provider.rate_limiter, "acquire") as mock_acquire, + ): + with pytest.raises(NetworkError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert mock_get.call_count == 3 + # Each attempt is a real request, so each acquires its own token. + assert mock_acquire.call_count == 3 + + def test_401_fails_fast_without_retry(self, provider): + """An auth failure is not retried.""" + response = _page_response([], status=401) + + with ( + patch.object(provider.session, "get", return_value=response) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(AuthenticationError): + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert mock_get.call_count == 1 + + def test_page_retry_preserves_earlier_pages(self, provider): + """A transient failure on page 2 retries only page 2, keeping page 1.""" + page1 = _page_response([STOCK_BARS[0]], "abc") + flaky = _page_response([], status=500) + page2 = _page_response([STOCK_BARS[1]], None) + + with ( + patch.object(provider.session, "get", side_effect=[page1, flaky, page2]) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + data = provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-07", "daily") + + assert data["bars"] == [STOCK_BARS[0], STOCK_BARS[1]] + # Page 1 was fetched once; only page 2 was retried. 
+ assert mock_get.call_count == 3 + assert mock_get.call_args_list[0].kwargs["params"] is not None + assert mock_get.call_args_list[1].kwargs["params"]["page_token"] == "abc" + assert mock_get.call_args_list[2].kwargs["params"]["page_token"] == "abc" + + +class TestErrorFallbacks: + """Tests for transport-failure and transform-failure fallbacks.""" + + def test_sync_transport_error_wrapped(self, provider): + """A transport-level failure wraps into NetworkError with its cause.""" + with ( + patch.object( + provider.session, "get", side_effect=httpx.ConnectError("boom") + ) as mock_get, + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(NetworkError, match="Request failed") as exc_info: + provider._fetch_raw_data("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert isinstance(exc_info.value.__cause__, httpx.ConnectError) + # NetworkError is transient, so the transport failure is retried. + assert mock_get.call_count == 3 + + @pytest.mark.asyncio + async def test_async_transport_error_wrapped(self, provider): + """The async transport failure wraps into NetworkError the same way.""" + with ( + patch.object( + provider, "_aget", new=AsyncMock(side_effect=httpx.ConnectError("boom")) + ) as mock_aget, + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(NetworkError, match="Request failed") as exc_info: + await provider._fetch_raw_data_async("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert isinstance(exc_info.value.__cause__, httpx.ConnectError) + assert mock_aget.call_count == 3 + + @pytest.mark.asyncio + async def test_async_429_maps_to_rate_limit(self, provider): + """An async 429 maps to RateLimitError after the per-page retries.""" + response = _page_response([], status=429, headers={"Retry-After": "0"}) + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=response)) as mock_aget, + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(RateLimitError): + await 
provider._fetch_raw_data_async("AAPL", "2024-01-01", "2024-01-05", "daily") + + assert mock_aget.call_count == 3 + + def test_malformed_bars_raise_data_validation(self, provider): + """Untransformable bars raise DataValidationError naming the symbol.""" + raw = { + "bars": [{"t": "not-a-timestamp", "o": 1, "h": 1, "l": 1, "c": 1, "v": 1}], + "next_page_token": None, + } + + with pytest.raises(DataValidationError, match="AAPL"): + provider._transform_data(raw, "AAPL") + + +class TestFetchOhlcvAsync: + """Tests for the public async fetch path.""" + + @pytest.mark.asyncio + async def test_fetch_ohlcv_async_stock(self, provider): + """The async public path returns the canonical schema for stocks.""" + response = _page_response(STOCK_BARS) + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=response)), + patch.object(provider.rate_limiter, "acquire"), + ): + df = await provider.fetch_ohlcv_async("AAPL", "2024-01-01", "2024-01-05") + + assert df.columns == ["timestamp", "symbol", "open", "high", "low", "close", "volume"] + assert df.height == 2 + assert df["symbol"].to_list() == ["AAPL", "AAPL"] + timestamps = df["timestamp"].to_list() + assert timestamps == sorted(timestamps) + + @pytest.mark.asyncio + async def test_fetch_ohlcv_async_crypto(self, provider): + """The async public path handles the crypto dict shape and casing.""" + response = _page_response({"BTC/USD": CRYPTO_BARS}) + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=response)), + patch.object(provider.rate_limiter, "acquire"), + ): + df = await provider.fetch_ohlcv_async( + "BTC/USD", "2024-01-01T00:00:00Z", "2024-01-07T00:00:00Z", "daily" + ) + + assert df.height == 2 + assert df["symbol"].to_list() == ["BTC/USD", "BTC/USD"] + + @pytest.mark.asyncio + async def test_fetch_ohlcv_async_validates_before_fetching(self, provider): + """Invalid bounds raise ValueError before any request goes out.""" + with patch.object(provider, "_aget", new=AsyncMock()) as mock_aget: 
+ with pytest.raises(ValueError, match="YYYY-MM-DD or RFC-3339"): + await provider.fetch_ohlcv_async("AAPL", "bad", "2024-01-05") + + mock_aget.assert_not_called() + + @pytest.mark.asyncio + async def test_fetch_ohlcv_async_open_breaker_blocks_fetch(self, provider): + """An OPEN breaker refuses the async fetch before any request.""" + provider.init_circuit_breaker() + provider.circuit_breaker.state = "OPEN" + provider.circuit_breaker.last_failure_time = time.time() + + with patch.object(provider, "_aget", new=AsyncMock()) as mock_aget: + with pytest.raises(CircuitBreakerOpenError): + await provider.fetch_ohlcv_async("AAPL", "2024-01-01", "2024-01-05") + + mock_aget.assert_not_called() + + @pytest.mark.asyncio + async def test_fetch_ohlcv_async_failure_counts_toward_breaker(self, provider): + """An async fetch failure increments the breaker's failure count.""" + response = _page_response([], status=404) + + with ( + patch.object(provider, "_aget", new=AsyncMock(return_value=response)), + patch.object(provider.rate_limiter, "acquire"), + ): + with pytest.raises(DataNotAvailableError): + await provider.fetch_ohlcv_async("AAPL", "2024-01-01", "2024-01-05") + + assert provider.circuit_breaker.failure_count == 1 + + +class TestMalformedRateLimitHeaders: + """Tests for unparsable rate-limit header values.""" + + def test_malformed_rate_limit_headers_fall_back(self, provider): + """Unparsable header values fall through to the default delay.""" + response = _page_response( + [], + status=429, + headers={"Retry-After": "soon", "X-RateLimit-Reset": "later"}, + ) + + with pytest.raises(RateLimitError) as exc_info: + provider._check_response_status(response, "AAPL") + + assert exc_info.value.retry_after == 60.0 + + +class TestAsyncBreakerWiring: + """Tests for the async circuit-breaker mixin plumbing.""" + + @pytest.mark.asyncio + async def test_with_circuit_breaker_async_lazy_init(self, provider): + """The async wrapper initializes the breaker when none exists yet.""" + del 
provider.circuit_breaker + + async def _ok(): + return 42 + + result = await provider._with_circuit_breaker_async(_ok) + + assert result == 42 + assert provider.circuit_breaker.state == "CLOSED" + + +class TestRegistryCaching: + """Tests for the lazy provider-class registry.""" + + def test_provider_classes_cached(self): + """Repeated lookups return the same cached registry mapping.""" + from ml4t.data.managers.provider_manager import ProviderManager + + first = ProviderManager._get_provider_classes() + second = ProviderManager._get_provider_classes() + + assert first is second + assert "alpaca" in first diff --git a/tests/test_base_provider_enhanced.py b/tests/test_base_provider_enhanced.py index 1ba7052..51b43b6 100644 --- a/tests/test_base_provider_enhanced.py +++ b/tests/test_base_provider_enhanced.py @@ -151,6 +151,55 @@ def success_func(): assert breaker.state == "CLOSED" assert breaker.failure_count == 0 + @pytest.mark.asyncio + async def test_call_async_success_and_failure_accounting(self): + """call_async mirrors call: successes pass through, failures count.""" + breaker = CircuitBreaker(failure_threshold=2) + + async def success_func(): + return "success" + + async def failing_func(): + raise Exception("Mock failure") + + assert await breaker.call_async(success_func) == "success" + assert breaker.state == "CLOSED" + + with pytest.raises(Exception, match="Mock failure"): + await breaker.call_async(failing_func) + assert breaker.failure_count == 1 + + with pytest.raises(Exception, match="Mock failure"): + await breaker.call_async(failing_func) + assert breaker.state == "OPEN" + + # While open, calls are refused before executing the function. 
+ with pytest.raises(CircuitBreakerOpenError): + await breaker.call_async(success_func) + + @pytest.mark.asyncio + async def test_call_async_half_open_recovery(self): + """After the reset timeout, call_async probes HALF_OPEN and recovers.""" + import time + + breaker = CircuitBreaker(failure_threshold=1, reset_timeout=0.05) + + async def failing_func(): + raise Exception("Mock failure") + + async def success_func(): + return "success" + + with pytest.raises(Exception, match="Mock failure"): + await breaker.call_async(failing_func) + assert breaker.state == "OPEN" + + time.sleep(0.1) + + assert await breaker.call_async(success_func) == "success" + assert breaker.state == "CLOSED" + assert breaker.failure_count == 0 + class TestBaseProvider: """Test enhanced BaseProvider functionality.""" From 4457accd06149b1ad206f644005d6ec7eb157606 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 10 Jun 2026 16:57:24 +0000 Subject: [PATCH 10/13] docs: document Alpaca provider setup and correct feed claims - Add docs/providers/alpaca.md plus README, mkdocs, and env-var entries - Correct the IEX delay claim and document the raw adjustment default - Refresh provider counts and catalog line counts --- README.md | 1 + docs/INTEGRATION_TESTING.md | 2 + docs/providers/README.md | 6 ++ docs/providers/alpaca.md | 142 ++++++++++++++++++++++++++++ mkdocs.yml | 1 + src/ml4t/data/providers/AGENT.md | 2 +- src/ml4t/data/providers/AGENTS.md | 2 +- src/ml4t/data/providers/__init__.py | 2 +- 8 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 docs/providers/alpaca.md diff --git a/README.md b/README.md index c9d56b1..f51e96f 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ fred = FREDProvider().fetch_series("GDP", "2020-01-01", "2024-12-31") | Provider | Coverage | |----------|----------| +| Alpaca | US equities + crypto (free IEX feed) | | EODHD | 60+ global exchanges | | Tiingo | US equities with quality focus | | Twelve Data | Multi-asset coverage | diff --git 
a/docs/INTEGRATION_TESTING.md b/docs/INTEGRATION_TESTING.md index 4cdd6f0..37538ac 100644 --- a/docs/INTEGRATION_TESTING.md +++ b/docs/INTEGRATION_TESTING.md @@ -170,6 +170,8 @@ pytest tests/integration/ -v -W default Required secrets: ``` +ALPACA_API_KEY +ALPACA_API_SECRET CRYPTOCOMPARE_API_KEY DATABENTO_API_KEY OANDA_API_KEY diff --git a/docs/providers/README.md b/docs/providers/README.md index 5726a77..23593c4 100644 --- a/docs/providers/README.md +++ b/docs/providers/README.md @@ -13,6 +13,7 @@ For detailed pricing, terms, and gap analysis, see [PROVIDER_AUDIT.md](PROVIDER_ | Provider | API Key | Free Tier | Best For | |----------|---------|-----------|----------| | [YahooFinance](yahoo.md) | No | Unlimited* | Quick start, US equities | +| [Alpaca](alpaca.md) | Yes | 200 req/min (IEX feed) | US equities + crypto, intraday | | [EODHD](eodhd.md) | Yes | 20 calls/day | Global equities (60+ exchanges) | | [Tiingo](tiingo.md) | Yes | 1,000 req/day | US equities alternative | | [Finnhub](finnhub.md) | Yes | 60 req/min | Company metrics, real-time | @@ -73,6 +74,7 @@ For detailed pricing, terms, and gap analysis, see [PROVIDER_AUDIT.md](PROVIDER_ | Provider | Minute | Hourly | Daily | Options | Fundamentals | |----------|--------|--------|-------|---------|--------------| | YahooFinance | ✅ (7d) | ✅ | ✅ | ❌* | ❌* | +| Alpaca | ✅ | ✅ | ✅ | ❌ | ❌ | | Databento | ✅ | ✅ | ✅ | ✅ (OPRA) | ❌ | | Massive | ✅ | ✅ | ✅ | ✅ | ✅ | | EODHD | ❌ | ❌ | ✅ | ✅ ($29.99) | ✅ ($59.99) | @@ -200,6 +202,9 @@ Create a `.env` file in your project root: ```bash # Free tier providers +ALPACA_API_KEY=your_key_here +ALPACA_API_SECRET=your_secret_here +# (Alpaca's own SDK names APCA_API_KEY_ID / APCA_API_SECRET_KEY also work) EODHD_API_KEY=your_key_here TIINGO_API_KEY=your_key_here FINNHUB_API_KEY=your_key_here @@ -251,6 +256,7 @@ See [Incremental Updates Guide](../storage/INCREMENTAL_ARCHITECTURE.md) for deta | Provider | Split Adjusted | Dividend Adjusted | 
|----------|----------------|-------------------| | YahooFinance | ✅ | ✅ | +| Alpaca | Opt-in (`adjustment=`) | Opt-in (`adjustment=`) | | EODHD | ✅ | ✅ | | Tiingo | ✅ | ✅ | | WikiPrices | ✅ | ✅ | diff --git a/docs/providers/alpaca.md b/docs/providers/alpaca.md new file mode 100644 index 0000000..015bf10 --- /dev/null +++ b/docs/providers/alpaca.md @@ -0,0 +1,142 @@ +# Alpaca Provider + +**Provider**: `AlpacaDataProvider` +**Website**: [alpaca.markets](https://alpaca.markets) +**API Key**: Required (key + secret pair) +**Free Tier**: 200 requests/min, real-time IEX feed + +--- + +## Overview + +Alpaca provides long-history, high-frequency US market data across equities and +crypto over a single historical REST API. One provider serves both asset +classes: plain tickers route to the stock bars endpoint, `BASE/QUOTE` symbols +route to the crypto bars endpoint. + +**Best For**: Free US intraday equities, US crypto bars + +**Pricing**: +| Tier | Price | Features | +|------|-------|----------| +| Basic | $0/mo | 200 req/min, real-time IEX feed, no recent-15-min SIP access | +| Algo Trader Plus | $99/mo | 10,000 req/min, full SIP (consolidated tape) | + +--- + +## Quick Start + +```python +import os +os.environ["ALPACA_API_KEY"] = "your_key_here" +os.environ["ALPACA_API_SECRET"] = "your_secret_here" + +from ml4t.data.providers import AlpacaDataProvider + +provider = AlpacaDataProvider() + +# US stocks +df = provider.fetch_ohlcv("AAPL", "2024-01-01", "2024-12-01") + +# Crypto (BASE/QUOTE symbol routes to the crypto endpoint) +df = provider.fetch_ohlcv("BTC/USD", "2024-01-01", "2024-01-31") + +# Intraday with RFC-3339 datetime bounds +df = provider.fetch_ohlcv( + "BTC/USD", "2024-01-01T00:00:00Z", "2024-01-01T01:00:00Z", frequency="minute" +) + +provider.close() +``` + +Async usage: + +```python +async with AlpacaDataProvider() as provider: + df = await provider.fetch_ohlcv_async("AAPL", "2024-01-01", "2024-12-01") +``` + +--- + +## Symbol Format + +| Asset class | Format 
| Examples | +|-------------|--------|----------| +| US stocks | Plain ticker | AAPL, MSFT | +| Crypto | BASE/QUOTE | BTC/USD, ETH/USD | + +Symbols are uppercased into requests and into the output `symbol` column; the +crypto slash is preserved (e.g. `BTC/USD`). + +--- + +## Supported Frequencies + +| Frequency | Availability | +|-----------|--------------| +| `daily` | ✅ | +| `hourly` | ✅ | +| `minute` | ✅ | + +`start`/`end` accept `YYYY-MM-DD` dates or RFC-3339 datetimes (both inclusive); +datetime bounds are the natural shape for sub-day minute/hour windows. + +--- + +## Feeds and Adjustment + +- `feed="iex"` (default): the free feed, real-time but served from a single + exchange (IEX, roughly 2-3% of US volume). The Basic plan additionally cannot + query the most recent 15 minutes of SIP data. +- `feed="sip"`: the consolidated tape (paid plans). +- `adjustment="raw"` (default, Alpaca's own default): stock bars are **not** + adjusted for splits or dividends. Pass `adjustment="split"`, `"dividend"`, or + `"all"` for adjusted bars. Crypto has no adjustment concept. + +```python +provider = AlpacaDataProvider(feed="sip", adjustment="all") +``` + +--- + +## API Key Setup + +```bash +# .env file (project convention) +ALPACA_API_KEY=your_key_here +ALPACA_API_SECRET=your_secret_here +``` + +Alpaca's own SDK/CLI names `APCA_API_KEY_ID` / `APCA_API_SECRET_KEY` are also +accepted. Get a free key at [alpaca.markets](https://alpaca.markets/). + +--- + +## Rate Limits + +| Tier | Limit | +|------|-------| +| Basic (free) | 200 req/min | +| Algo Trader Plus | 10,000 req/min | + +The provider throttles client-side to 200 req/min by default (override with the +`rate_limit` constructor argument), honors 429 `Retry-After`/rate-limit-reset +headers, and retries transient failures per pagination page. 
+ +--- + +## Not Yet Implemented + +| Feature | Priority | +|---------|----------| +| Quotes / trades (tick) endpoints | MEDIUM | +| Options bars | LOW | +| News API | LOW | + +--- + +## See Also + +- [Alpaca Market Data docs](https://docs.alpaca.markets/us/docs/about-market-data-api) +- [Provider README](README.md) +- [Provider Audit](PROVIDER_AUDIT.md) diff --git a/mkdocs.yml b/mkdocs.yml index 82a9d8f..e0ccd25 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -124,6 +124,7 @@ nav: - Providers: - providers/index.md - Yahoo Finance: providers/yahoo.md + - Alpaca: providers/alpaca.md - CoinGecko: providers/coingecko.md - FRED: providers/fred.md - Fama-French: providers/fama_french.md diff --git a/src/ml4t/data/providers/AGENT.md b/src/ml4t/data/providers/AGENT.md index 244cc30..08c2a07 100644 --- a/src/ml4t/data/providers/AGENT.md +++ b/src/ml4t/data/providers/AGENT.md @@ -13,7 +13,7 @@ | File | Lines | Purpose | |------|-------|---------| | yahoo.py | 603 | Yahoo Finance (free) | -| alpaca.py | 722 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | +| alpaca.py | 968 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | | binance.py | 410 | Binance authenticated | | binance_public.py | 1430 | Binance public API | | eodhd.py | 464 | EOD Historical Data | diff --git a/src/ml4t/data/providers/AGENTS.md b/src/ml4t/data/providers/AGENTS.md index 87395e1..976dd59 100644 --- a/src/ml4t/data/providers/AGENTS.md +++ b/src/ml4t/data/providers/AGENTS.md @@ -13,7 +13,7 @@ | File | Lines | Purpose | |------|-------|---------| | yahoo.py | 603 | Yahoo Finance (free) | -| alpaca.py | 722 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | +| alpaca.py | 968 | Alpaca US stocks + crypto (free IEX feed, two-cred auth) | | binance_api.py | 410 | Binance REST API | | binance_bulk.py | 1430 | Binance bulk historical archive | | eodhd.py | 464 | EOD Historical Data | diff --git a/src/ml4t/data/providers/__init__.py b/src/ml4t/data/providers/__init__.py index 
3de34a5..d1fb3cc 100644 --- a/src/ml4t/data/providers/__init__.py +++ b/src/ml4t/data/providers/__init__.py @@ -2,7 +2,7 @@ This module provides unified access to multiple financial data providers. -Available Providers (20 live + 3 synthetic/testing): +Available Providers (21 live + 3 synthetic/testing): - BaseProvider: Abstract base class for all providers - YahooFinanceProvider: Yahoo Finance (free, no API key) - AlpacaDataProvider: Alpaca US stocks and crypto (free IEX feed, two-credential auth) From 8cc9605d3929d9b000b8975eb8ccfe601331df7f Mon Sep 17 00:00:00 2001 From: Alejandro Date: Thu, 11 Jun 2026 00:01:30 +0000 Subject: [PATCH 11/13] chore: group ml4t imports as first-party in ruff isort - Declare the ml4t namespace known-first-party so its imports form a third section after stdlib and third-party blocks - Drop stale import-section comments splitting the ml4t block in examples --- examples/01_microstructure_tick_analysis.py | 4 +--- examples/02_cross_sectional_nasdaq100.py | 5 ++--- pyproject.toml | 7 +++++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/01_microstructure_tick_analysis.py b/examples/01_microstructure_tick_analysis.py index 88b790c..e22f4d7 100644 --- a/examples/01_microstructure_tick_analysis.py +++ b/examples/01_microstructure_tick_analysis.py @@ -21,12 +21,10 @@ import matplotlib.pyplot as plt import numpy as np import polars as pl + from ml4t.engineer.bars import DollarBar, VolumeBar from ml4t.engineer.features import microstructure as ms from ml4t.engineer.labeling import BarrierConfig, triple_barrier_labels - -# Import qeval modules -# Import qfeatures modules from ml4t.evaluation import Evaluator, Tier from ml4t.evaluation.splitters import PurgedWalkForwardCV diff --git a/examples/02_cross_sectional_nasdaq100.py b/examples/02_cross_sectional_nasdaq100.py index b3500c8..5973173 100644 --- a/examples/02_cross_sectional_nasdaq100.py +++ b/examples/02_cross_sectional_nasdaq100.py @@ -22,11 +22,10 @@ from 
pathlib import Path import matplotlib.pyplot as plt - -# Import qfeatures modules -import ml4t.engineer as qf import polars as pl import seaborn as sns + +import ml4t.engineer as qf from ml4t.evaluation import Evaluator, Tier from ml4t.evaluation.evaluation.stats import deflated_sharpe_ratio from ml4t.evaluation.evaluation.viz import create_factor_tearsheet diff --git a/pyproject.toml b/pyproject.toml index 98b08f4..b0dbad7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -236,6 +236,13 @@ ignore = [ "B904", # exception chaining (not always needed) ] +[tool.ruff.lint.isort] +# Treat the whole ml4t namespace (data, engineer, evaluation, ...) as +# first-party so its imports always form their own third section after the +# standard library and third-party blocks, even for ml4t distributions +# installed from outside this repo. +known-first-party = ["ml4t"] + [tool.ruff.lint.per-file-ignores] "tests/*" = ["ARG001", "ARG002", "ARG005", "B017", "F821", "SIM117"] # Test patterns "src/ml4t/data/futures/continuous_downloader.py" = ["ARG002"] # Public API placeholder From a71b0b7b1d0203dcb43d4f83b7b51cc51f7ea9c0 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Thu, 11 Jun 2026 00:13:18 +0000 Subject: [PATCH 12/13] chore: ignore local coverage and planning artifacts - Add coverage output and local working dirs to gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 15800f3..7960f05 100644 --- a/.gitignore +++ b/.gitignore @@ -34,7 +34,11 @@ htmlcov/ logs/ site/ src/ml4t/data/_version.py +coverage/ # Claude Code (local development only) CLAUDE.md .claude/ +.agentic_documentation/ +.skills-outputs/ +planning/ From 09f4d0c678dd72fad195599a6061b4c875d1df87 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Thu, 11 Jun 2026 00:58:43 +0000 Subject: [PATCH 13/13] feat: support 5m/15m/30m frequencies in Alpaca provider - Map 5m/15m/30m and their Nminute aliases to Alpaca timeframes - Add parametrized mapping tests and a live 15m 
integration test - Document the new frequencies and uniform aliases in provider docs --- docs/providers/alpaca.md | 12 +++++++++--- src/ml4t/data/providers/alpaca.py | 16 +++++++++++----- tests/integration/test_alpaca.py | 17 +++++++++++++++++ tests/test_alpaca_provider.py | 6 ++++++ 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/docs/providers/alpaca.md b/docs/providers/alpaca.md index 015bf10..f8b4a50 100644 --- a/docs/providers/alpaca.md +++ b/docs/providers/alpaca.md @@ -46,6 +46,9 @@ df = provider.fetch_ohlcv( "BTC/USD", "2024-01-01T00:00:00Z", "2024-01-01T01:00:00Z", frequency="minute" ) +# Multi-year intraday backfills with minute multiples (5m / 15m / 30m) +df = provider.fetch_ohlcv("AAPL", "2021-01-01", "2024-12-31", frequency="15m") + provider.close() ``` @@ -74,9 +77,12 @@ crypto slash is preserved (e.g. `BTC/USD`). | Frequency | Availability | |-----------|--------------| -| `daily` | ✅ | -| `hourly` | ✅ | -| `minute` | ✅ | +| `daily` / `1d` | ✅ | +| `hourly` / `1h` | ✅ | +| `minute` / `1m` | ✅ | +| `5m` / `5minute` | ✅ | +| `15m` / `15minute` | ✅ | +| `30m` / `30minute` | ✅ | `start`/`end` accept `YYYY-MM-DD` dates or RFC-3339 datetimes (both inclusive); datetime bounds are the natural shape for sub-day minute/hour windows. diff --git a/src/ml4t/data/providers/alpaca.py b/src/ml4t/data/providers/alpaca.py index 70dc68e..df0c965 100644 --- a/src/ml4t/data/providers/alpaca.py +++ b/src/ml4t/data/providers/alpaca.py @@ -2,8 +2,8 @@ Alpaca provides long-history, high-frequency market data across multiple asset classes including US equities and crypto, served over a historical REST API. -Daily, hourly, and minute OHLCV bars are supported for both asset classes from a -single symbol-routed provider. +Daily, hourly, and minute OHLCV bars (1/5/15/30-minute) are supported for both +asset classes from a single symbol-routed provider. 
API Documentation: https://docs.alpaca.markets/us/docs/about-market-data-api @@ -130,9 +130,9 @@ def _retry_wait(retry_state: RetryCallState) -> float: class AlpacaDataProvider(AsyncSessionMixin, BaseProvider): """Alpaca market data provider. - Supports equities and crypto with daily, hourly, and minute OHLCV bars over - Alpaca's historical REST API. Authentication uses two header credentials that - are wired onto both the sync and async HTTP sessions. + Supports equities and crypto with daily, hourly, and minute OHLCV bars + (1/5/15/30-minute) over Alpaca's historical REST API. Authentication uses two + header credentials that are wired onto both the sync and async HTTP sessions. Supports both sync and async operations: # Sync @@ -168,6 +168,12 @@ class AlpacaDataProvider(AsyncSessionMixin, BaseProvider): "minute": "1Min", "1m": "1Min", "1minute": "1Min", + "5m": "5Min", + "5minute": "5Min", + "15m": "15Min", + "15minute": "15Min", + "30m": "30Min", + "30minute": "30Min", } def __init__( diff --git a/tests/integration/test_alpaca.py b/tests/integration/test_alpaca.py index fd051fe..4cdb431 100644 --- a/tests/integration/test_alpaca.py +++ b/tests/integration/test_alpaca.py @@ -9,6 +9,7 @@ Test Coverage: - Stock daily OHLCV data (AAPL) + - Stock 15-minute OHLCV data (AAPL) - Crypto minute OHLCV data (BTC/USD) IMPORTANT: @@ -78,6 +79,22 @@ def test_fetch_stock_daily(self, provider): assert (df["high"] >= df["low"]).all(), "High should be >= Low" assert (df["symbol"] == "AAPL").all() + def test_fetch_stock_15_minute(self, provider): + """Fetch 15-minute stock bars for AAPL with a real API call.""" + df = provider.fetch_ohlcv( + symbol="AAPL", + start="2024-01-02T14:00:00Z", + end="2024-01-02T20:00:00Z", + frequency="15m", + ) + + assert isinstance(df, pl.DataFrame) + assert not df.is_empty(), "Should fetch some 15-minute data for AAPL" + assert all(col in df.columns for col in REQUIRED_COLS) + assert (df["high"] >= df["low"]).all(), "High should be >= Low" + # Bars 
must land on 15-minute boundaries. + assert (df["timestamp"].dt.minute() % 15 == 0).all() + def test_fetch_crypto_minute(self, provider): """Fetch minute crypto bars for BTC/USD with a real API call.""" df = provider.fetch_ohlcv( diff --git a/tests/test_alpaca_provider.py b/tests/test_alpaca_provider.py index b2e60e7..ee338f5 100644 --- a/tests/test_alpaca_provider.py +++ b/tests/test_alpaca_provider.py @@ -167,6 +167,12 @@ class TestFrequencyMapping: ("minute", "1Min"), ("1m", "1Min"), ("1minute", "1Min"), + ("5m", "5Min"), + ("5minute", "5Min"), + ("15m", "15Min"), + ("15minute", "15Min"), + ("30m", "30Min"), + ("30minute", "30Min"), ], ) def test_frequency_maps_to_timeframe(self, provider, frequency, expected):