diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py
index 2f3615a6357..807c96a356d 100644
--- a/qlib/data/dataset/loader.py
+++ b/qlib/data/dataset/loader.py
@@ -199,6 +199,40 @@ def __init__(
                     self.inst_processors
                 ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"
 
+    def _resolve_instruments(self, instruments):
+        """Turn the caller's ``instruments`` into the value handed to ``D.features``.
+
+        ``None`` is replaced by ``"all"`` after a warning.  A string is expanded
+        via ``D.instruments`` together with ``self.filter_pipe``.  Any other value
+        is returned untouched; a warning is emitted when ``self.filter_pipe`` is
+        set, since it is not applied on that path.
+        """
+        if instruments is None:
+            warnings.warn("`instruments` is not set, will load all stocks")
+            instruments = "all"
+        if not isinstance(instruments, str):
+            if self.filter_pipe is not None:
+                warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
+            return instruments
+        return D.instruments(instruments, filter_pipe=self.filter_pipe)
+
+    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
+        """Load all configured fields, resolving ``instruments`` a single time.
+
+        With grouped fields (``self.is_group``) each group is loaded separately
+        and the per-group frames are concatenated along ``axis=1`` under their
+        group names; otherwise the single ``(exprs, names)`` pair is loaded.
+        """
+        resolved = self._resolve_instruments(instruments)
+        if not self.is_group:
+            exprs, names = self.fields
+            return self._load_group_df(resolved, exprs, names, start_time, end_time)
+        # Every group reuses the same resolved instruments.
+        group_frames = {}
+        for grp, (exprs, names) in self.fields.items():
+            group_frames[grp] = self._load_group_df(resolved, exprs, names, start_time, end_time, grp)
+        return pd.concat(group_frames, axis=1)
+
     def load_group_df(
         self,
         instruments,
@@ -208,19 +242,38 @@ def load_group_df(
         end_time: Union[str, pd.Timestamp] = None,
         gp_name: str = None,
     ) -> pd.DataFrame:
-        if instruments is None:
-            warnings.warn("`instruments` is not set, will load all stocks")
-            instruments = "all"
-        if isinstance(instruments, str):
-            instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
-        elif self.filter_pipe is not None:
-            warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
+        """Resolve ``instruments`` and load the frame for one field group."""
+        resolved = self._resolve_instruments(instruments)
+        return self._load_group_df(resolved, exprs, names, start_time, end_time, gp_name)
 
+    def _load_group_df(
+        self,
+        instruments,
+        exprs: list,
+        names: list,
+        start_time: Union[str, pd.Timestamp] = None,
+        end_time: Union[str, pd.Timestamp] = None,
+        gp_name: str = None,
+    ) -> pd.DataFrame:
+        """Fetch one group's features for ``instruments`` without resolving them again.
+
+        ``freq`` is indexed by ``gp_name`` when ``self.freq`` is a dict, and
+        ``inst_processors`` is looked up by ``gp_name`` unless it is a list.
+        """
         freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
         inst_processors = (
-            self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
+            self.inst_processors
+            if isinstance(self.inst_processors, list)
+            else self.inst_processors.get(gp_name, [])
+        )
+        df = D.features(
+            instruments,
+            exprs,
+            start_time,
+            end_time,
+            freq=freq,
+            inst_processors=inst_processors,
         )
-        df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
         df.columns = names
         if self.swap_level:
             df = df.swaplevel().sort_index()  # NOTE: if swaplevel, return
diff --git a/tests/data_mid_layer_tests/test_dataloader.py b/tests/data_mid_layer_tests/test_dataloader.py
index 8646e785877..a54df9514c0 100644
--- a/tests/data_mid_layer_tests/test_dataloader.py
+++ b/tests/data_mid_layer_tests/test_dataloader.py
@@ -4,7 +4,9 @@
 import sys
 import unittest
 import qlib
+import pandas as pd
 from pathlib import Path
+from unittest.mock import patch
 
 sys.path.append(str(Path(__file__).resolve().parent))
 from qlib.data.dataset.loader import NestedDataLoader, QlibDataLoader
@@ -16,6 +18,33 @@
 
 
 class TestDataLoader(unittest.TestCase):
+    def test_group_loader_applies_filter_pipe_once(self):
+        """A grouped load resolves instruments once and reuses them per group."""
+        filter_pipe = [{"filter_type": "NameDFilter", "name_rule_re": "SH.*"}]
+        resolved = {"SH600000": [(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02"))]}
+        loader = QlibDataLoader(
+            config={
+                "feature": (["$close"], ["CLOSE"]),
+                "label": (["Ref($close, -1)"], ["LABEL0"]),
+            },
+            filter_pipe=filter_pipe,
+            swap_level=False,
+        )
+
+        def fake_features(instruments_arg, exprs, start_time, end_time, freq, inst_processors):
+            # Each group must receive the very object returned by D.instruments.
+            self.assertIs(instruments_arg, resolved)
+            return pd.DataFrame([[1.0]], index=pd.Index(["SH600000"], name="instrument"))
+
+        instruments_patch = patch("qlib.data.dataset.loader.D.instruments", return_value=resolved)
+        features_patch = patch("qlib.data.dataset.loader.D.features", side_effect=fake_features)
+        with instruments_patch as mock_instruments, features_patch as mock_features_fn:
+            df = loader.load(instruments="csi300", start_time="2020-01-01", end_time="2020-01-02")
+
+        mock_instruments.assert_called_once_with("csi300", filter_pipe=filter_pipe)
+        self.assertEqual(mock_features_fn.call_count, 2)
+        self.assertEqual(list(df.columns.get_level_values(0)), ["feature", "label"])
+
     def test_nested_data_loader(self):
         qlib.init(kernels=1)
         nd = NestedDataLoader(