From 40d6b121c63e463499518276c905af151124c4cd Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Fri, 28 Mar 2025 19:16:21 -0600 Subject: [PATCH 1/3] refactor: Introduce a base class for handling intake-esgf related queries --- scripts/fetch_test_data.py | 174 ++++++++----------- src/ref_sample_data/__init__.py | 2 +- src/ref_sample_data/data_request/base.py | 38 +++- src/ref_sample_data/data_request/cmip6.py | 10 +- src/ref_sample_data/data_request/obs4mips.py | 10 +- 5 files changed, 119 insertions(+), 115 deletions(-) diff --git a/scripts/fetch_test_data.py b/scripts/fetch_test_data.py index 64cc9494..bda322ee 100755 --- a/scripts/fetch_test_data.py +++ b/scripts/fetch_test_data.py @@ -6,43 +6,13 @@ import pooch import typer import xarray as xr -from intake_esgf import ESGFCatalog -from ref_sample_data import CMIP6Request, DataRequest, Obs4MIPsRequest +from ref_sample_data import DataRequest, PMPRequest OUTPUT_PATH = Path("data") app = typer.Typer() -def fetch_datasets(request: DataRequest, quiet: bool) -> pd.DataFrame: - """ - Fetch the datasets from ESGF. - - Parameters - ---------- - request - The request object - quiet - Whether to suppress progress messages from intake-esgf - - Returns - ------- - Dataframe that contains metadata and paths to the fetched datasets - """ - cat = ESGFCatalog() - - cat.search(**request.facets) - if request.remove_ensembles: - cat.remove_ensembles() - - path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet) - merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True) - if request.time_span: - merged_df["time_start"] = request.time_span[0] - merged_df["time_end"] = request.time_span[1] - return merged_df - - def deduplicate_datasets(datasets: pd.DataFrame) -> pd.DataFrame: """ Deduplicate a dataset collection. @@ -90,7 +60,7 @@ def process_sample_data_request( quiet Whether to suppress progress messages """ - datasets = fetch_datasets(request, quiet) + datasets = request.fetch_datasets() datasets = deduplicate_datasets(datasets) for _, dataset in datasets.iterrows(): @@ -98,7 +68,7 @@ def process_sample_data_request( ds_orig = xr.open_dataset(ds_filename) if decimate: - ds_decimated = request.decimate_dataset(ds_orig, request.time_span) + ds_decimated = request.decimate_dataset(ds_orig) else: ds_decimated = ds_orig if ds_decimated is None: @@ -113,74 +83,78 @@ def process_sample_data_request( DATASETS_TO_FETCH = [ - # Example metric data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "tas", "tos", "rsut", "rlut", "rsdt"], - experiment_id=["ssp126", "historical"], - ), - remove_ensembles=True, - time_span=("2000", "2025"), - ), - # ESMValTool ECS data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "rlut", "rsdt", "rsut", "tas"], - experiment_id=["abrupt-4xCO2", "piControl"], - ), - remove_ensembles=True, - time_span=("0101", "0125"), - ), - # ESMValTool TCR data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "tas"], - experiment_id=["1pctCO2", "piControl"], - ), - remove_ensembles=True, - time_span=("0101", "0180"), - ), - # ILAMB data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "sftlf", "gpp", "pr"], - experiment_id=["historical"], - ), - remove_ensembles=True, - time_span=("2000", "2025"), - ), - # PMP PDO data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "ts"], - experiment_id=["historical", "hist-GHG"], - variant_label=["r1i1p1f1", "r2i1p1f1"], - ), - remove_ensembles=False, - time_span=("2000", "2025"), - ), - # Obs4MIPs AIRS data - Obs4MIPsRequest( - facets=dict( - project="obs4MIPs", - institution_id="NASA-JPL", - frequency="mon", - source_id="AIRS-2-1", - variable_id="ta", - ), - remove_ensembles=False, - time_span=("2002", "2016"), + # # Example metric data + # CMIP6Request( + # facets=dict( + # source_id="ACCESS-ESM1-5", + # frequency=["fx", "mon"], + # variable_id=["areacella", "tas", "tos", "rsut", "rlut", "rsdt"], + # experiment_id=["ssp126", "historical"], + # ), + # remove_ensembles=True, + # time_span=("2000", "2025"), + # ), + # # ESMValTool ECS data + # CMIP6Request( + # facets=dict( + # source_id="ACCESS-ESM1-5", + # frequency=["fx", "mon"], + # variable_id=["areacella", "rlut", "rsdt", "rsut", "tas"], + # experiment_id=["abrupt-4xCO2", "piControl"], + # ), + # remove_ensembles=True, + # time_span=("0101", "0125"), + # ), + # # ESMValTool TCR data + # CMIP6Request( + # facets=dict( + # source_id="ACCESS-ESM1-5", + # frequency=["fx", "mon"], + # variable_id=["areacella", "tas"], + # experiment_id=["1pctCO2", "piControl"], + # ), + # remove_ensembles=True, + # time_span=("0101", "0180"), + # ), + # # ILAMB data + # CMIP6Request( + # facets=dict( + # source_id="ACCESS-ESM1-5", + # frequency=["fx", "mon"], + # variable_id=["areacella", "sftlf", "gpp", "pr"], + # experiment_id=["historical"], + # ), + # remove_ensembles=True, + # time_span=("2000", "2025"), + # ), + # # PMP PDO data + # CMIP6Request( + # facets=dict( + # source_id="ACCESS-ESM1-5", + # frequency=["fx", "mon"], + # variable_id=["areacella", "ts"], + # experiment_id=["historical", "hist-GHG"], + # variant_label=["r1i1p1f1", "r2i1p1f1"], + # ), + # remove_ensembles=False, + # time_span=("2000", "2025"), + # ), + PMPRequest( + url="https://pcmdiweb.llnl.gov/pss/pmpdata/obs4MIPs_PCMDI_monthly/NOAA-ESRL-PSD/20CR/mon/psl/gn/v20210727/psl_mon_20CR_PCMDI_gn_187101-201212.nc", + hash="md5:570ce90b3afd1d0b31690ae5dbe32d31", ), + # # Obs4MIPs AIRS data + # Obs4MIPsRequest( + # facets=dict( + # project="obs4MIPs", + # institution_id="NASA-JPL", + # frequency="mon", + # source_id="AIRS-2-1", + # variable_id="ta", + # ), + # remove_ensembles=False, + # time_span=("2002", "2016"), + # ), ] diff --git a/src/ref_sample_data/__init__.py b/src/ref_sample_data/__init__.py index c33270f0..522d334d 100644 --- a/src/ref_sample_data/__init__.py +++ b/src/ref_sample_data/__init__.py @@ -11,4 +11,4 @@ from .data_request.cmip6 import CMIP6Request from .data_request.obs4mips import Obs4MIPsRequest -__all__ = ["DataRequest", "CMIP6Request", "Obs4MIPsRequest"] +__all__ = ["CMIP6Request", "DataRequest", "Obs4MIPsRequest"] diff --git a/src/ref_sample_data/data_request/base.py b/src/ref_sample_data/data_request/base.py index b387a43e..cade7859 100644 --- a/src/ref_sample_data/data_request/base.py +++ b/src/ref_sample_data/data_request/base.py @@ -3,6 +3,7 @@ import pandas as pd import xarray as xr +from intake_esgf import ESGFCatalog class DataRequest(Protocol): @@ -14,11 +15,15 @@ class DataRequest(Protocol): differently to generate the sample data. """ - facets: dict[str, str | tuple[str, ...]] - remove_ensembles: bool - time_span: tuple[str, str] + def fetch_datasets(self) -> pd.DataFrame: + """ + Fetch the datasets from the source + + Returns a dataframe of the metadata and paths to the fetched datasets. + """ + ... - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None: """Downscale the dataset to a smaller size.""" ... @@ -27,3 +32,28 @@ def generate_filename( ) -> pathlib.Path: """Create the output filename for the dataset.""" ... + + +class IntakeESGFDataRequest(DataRequest): + """ + A data request that fetches datasets from ESGF using intake-esgf. + """ + + facets: dict[str, str | tuple[str, ...]] + remove_ensembles: bool + time_span: tuple[str, str] + + def fetch_datasets(self) -> pd.DataFrame: + """Fetch the datasets from the ESGF.""" + cat = ESGFCatalog() + + cat.search(**self.facets) + if self.remove_ensembles: + cat.remove_ensembles() + + path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=True) + merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True) + if self.time_span: + merged_df["time_start"] = self.time_span[0] + merged_df["time_end"] = self.time_span[1] + return merged_df diff --git a/src/ref_sample_data/data_request/cmip6.py b/src/ref_sample_data/data_request/cmip6.py index 80405b5a..641b2be4 100644 --- a/src/ref_sample_data/data_request/cmip6.py +++ b/src/ref_sample_data/data_request/cmip6.py @@ -6,7 +6,7 @@ import pandas as pd import xarray as xr -from ref_sample_data.data_request.base import DataRequest +from ref_sample_data.data_request.base import IntakeESGFDataRequest from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear @@ -37,7 +37,7 @@ def prefix_to_filename(ds, filename_prefix: str) -> str: return filename -class CMIP6Request(DataRequest): +class CMIP6Request(IntakeESGFDataRequest): """ Represents a CMIP6 dataset request @@ -86,7 +86,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message" assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message" - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None: """ Downscale the dataset to a smaller size. @@ -115,8 +115,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non else: raise ValueError("Cannot decimate this grid: too many dimensions") - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) + if "time" in dataset.dims and self.time_span is not None: + result = result.sel(time=slice(*self.time_span)) if result.time.size == 0: result = None diff --git a/src/ref_sample_data/data_request/obs4mips.py b/src/ref_sample_data/data_request/obs4mips.py index 72303fea..88656429 100644 --- a/src/ref_sample_data/data_request/obs4mips.py +++ b/src/ref_sample_data/data_request/obs4mips.py @@ -6,12 +6,12 @@ import pandas as pd import xarray as xr -from ref_sample_data.data_request.base import DataRequest +from ref_sample_data.data_request.base import IntakeESGFDataRequest from ref_sample_data.data_request.cmip6 import prefix_to_filename from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear -class Obs4MIPsRequest(DataRequest): +class Obs4MIPsRequest(IntakeESGFDataRequest): """ Represents a Obs4MIPs dataset request """ @@ -65,7 +65,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message" assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message" - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None: """ Downscale the dataset to a smaller size. @@ -94,8 +94,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non else: raise ValueError("Cannot decimate this grid: too many dimensions") - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) + if "time" in dataset.dims and self.time_span is not None: + result = result.sel(time=slice(*self.time_span)) if result.time.size == 0: result = None From b58036bd0618db277a53a15229cb675e0a1304a6 Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Fri, 28 Mar 2025 19:20:44 -0600 Subject: [PATCH 2/3] chore: Changelog --- changelog/23.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/23.improvement.md diff --git a/changelog/23.improvement.md b/changelog/23.improvement.md new file mode 100644 index 00000000..7d73a4df --- /dev/null +++ b/changelog/23.improvement.md @@ -0,0 +1 @@ +Refactored intake-esgf-based data requests to have a common base class (`ref_sample_data.data_request.base.IntakeESGFDataRequest`) From fe33e1e70c87d51ba32fd94aecc850bf9ae87a2c Mon Sep 17 00:00:00 2001 From: Jared Lewis Date: Fri, 28 Mar 2025 19:23:10 -0600 Subject: [PATCH 3/3] chore: Cancel any in progress workflows --- .github/workflows/ci.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 46629d46..1b529ef2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,6 +6,10 @@ on: branches: [main] tags: ['v*'] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: pre-commit: if: ${{ !github.event.pull_request.draft }}