diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 46629d46..1b529ef2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,6 +6,10 @@ on: branches: [main] tags: ['v*'] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: pre-commit: if: ${{ !github.event.pull_request.draft }} diff --git a/changelog/23.improvement.md b/changelog/23.improvement.md new file mode 100644 index 00000000..7d73a4df --- /dev/null +++ b/changelog/23.improvement.md @@ -0,0 +1 @@ +Refactored intake-esgf-based data requests to have a common base class (`ref_sample_data.data_request.base.IntakeESGFDataRequest`) diff --git a/scripts/fetch_test_data.py b/scripts/fetch_test_data.py index fd9ad33f..d9064479 100755 --- a/scripts/fetch_test_data.py +++ b/scripts/fetch_test_data.py @@ -6,7 +6,6 @@ import pooch import typer import xarray as xr -from intake_esgf import ESGFCatalog from ref_sample_data import CMIP6Request, DataRequest, Obs4MIPsRequest @@ -14,35 +13,6 @@ app = typer.Typer() -def fetch_datasets(request: DataRequest, quiet: bool) -> pd.DataFrame: - """ - Fetch the datasets from ESGF. - - Parameters - ---------- - request - The request object - quiet - Whether to suppress progress messages from intake-esgf - - Returns - ------- - Dataframe that contains metadata and paths to the fetched datasets - """ - cat = ESGFCatalog() - - cat.search(**request.facets) - if request.remove_ensembles: - cat.remove_ensembles() - - path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet) - merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True) - if request.time_span: - merged_df["time_start"] = request.time_span[0] - merged_df["time_end"] = request.time_span[1] - return merged_df - - def deduplicate_datasets(datasets: pd.DataFrame) -> pd.DataFrame: """ Deduplicate a dataset collection. 
@@ -90,7 +60,7 @@ def process_sample_data_request( quiet Whether to suppress progress messages """ - datasets = fetch_datasets(request, quiet) + datasets = request.fetch_datasets(quiet=quiet) datasets = deduplicate_datasets(datasets) for _, dataset in datasets.iterrows(): @@ -98,7 +68,7 @@ def process_sample_data_request( ds_orig = xr.open_dataset(ds_filename) if decimate: - ds_decimated = request.decimate_dataset(ds_orig, request.time_span) + ds_decimated = request.decimate_dataset(ds_orig) else: ds_decimated = ds_orig if ds_decimated is None: diff --git a/src/ref_sample_data/__init__.py b/src/ref_sample_data/__init__.py index c33270f0..522d334d 100644 --- a/src/ref_sample_data/__init__.py +++ b/src/ref_sample_data/__init__.py @@ -11,4 +11,4 @@ from .data_request.cmip6 import CMIP6Request from .data_request.obs4mips import Obs4MIPsRequest -__all__ = ["DataRequest", "CMIP6Request", "Obs4MIPsRequest"] +__all__ = ["CMIP6Request", "DataRequest", "Obs4MIPsRequest"] diff --git a/src/ref_sample_data/data_request/base.py b/src/ref_sample_data/data_request/base.py index b387a43e..cade7859 100644 --- a/src/ref_sample_data/data_request/base.py +++ b/src/ref_sample_data/data_request/base.py @@ -3,6 +3,7 @@ import pandas as pd import xarray as xr +from intake_esgf import ESGFCatalog class DataRequest(Protocol): @@ -14,11 +15,15 @@ class DataRequest(Protocol): differently to generate the sample data. """ - facets: dict[str, str | tuple[str, ...]] - remove_ensembles: bool - time_span: tuple[str, str] + def fetch_datasets(self, quiet: bool = True) -> pd.DataFrame: + """ + Fetch the datasets from the source + + Returns a dataframe of the metadata and paths to the fetched datasets. + """ + ... - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None: """Downscale the dataset to a smaller size.""" ... 
@@ -27,3 +32,28 @@ def generate_filename( ) -> pathlib.Path: """Create the output filename for the dataset.""" ... + + +class IntakeESGFDataRequest(DataRequest): + """ + A data request that fetches datasets from ESGF using intake-esgf. + """ + + facets: dict[str, str | tuple[str, ...]] + remove_ensembles: bool + time_span: tuple[str, str] + + def fetch_datasets(self, quiet: bool = True) -> pd.DataFrame: + """Fetch the datasets from the ESGF.""" + cat = ESGFCatalog() + + cat.search(**self.facets) + if self.remove_ensembles: + cat.remove_ensembles() + + path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet) + merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True) + if self.time_span: + merged_df["time_start"] = self.time_span[0] + merged_df["time_end"] = self.time_span[1] + return merged_df diff --git a/src/ref_sample_data/data_request/cmip6.py b/src/ref_sample_data/data_request/cmip6.py index 80405b5a..641b2be4 100644 --- a/src/ref_sample_data/data_request/cmip6.py +++ b/src/ref_sample_data/data_request/cmip6.py @@ -6,7 +6,7 @@ import pandas as pd import xarray as xr -from ref_sample_data.data_request.base import DataRequest +from ref_sample_data.data_request.base import IntakeESGFDataRequest from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear @@ -37,7 +37,7 @@ def prefix_to_filename(ds, filename_prefix: str) -> str: return filename -class CMIP6Request(DataRequest): +class CMIP6Request(IntakeESGFDataRequest): """ Represents a CMIP6 dataset request @@ -86,7 +86,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message" assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message" - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | 
None: """ Downscale the dataset to a smaller size. @@ -115,8 +115,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non else: raise ValueError("Cannot decimate this grid: too many dimensions") - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) + if "time" in dataset.dims and self.time_span is not None: + result = result.sel(time=slice(*self.time_span)) if result.time.size == 0: result = None diff --git a/src/ref_sample_data/data_request/obs4mips.py b/src/ref_sample_data/data_request/obs4mips.py index 72303fea..88656429 100644 --- a/src/ref_sample_data/data_request/obs4mips.py +++ b/src/ref_sample_data/data_request/obs4mips.py @@ -6,12 +6,12 @@ import pandas as pd import xarray as xr -from ref_sample_data.data_request.base import DataRequest +from ref_sample_data.data_request.base import IntakeESGFDataRequest from ref_sample_data.data_request.cmip6 import prefix_to_filename from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear -class Obs4MIPsRequest(DataRequest): +class Obs4MIPsRequest(IntakeESGFDataRequest): """ Represents a Obs4MIPs dataset request """ @@ -65,7 +65,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message" assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message" - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None: """ Downscale the dataset to a smaller size. 
@@ -94,8 +94,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non else: raise ValueError("Cannot decimate this grid: too many dimensions") - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) + if "time" in dataset.dims and self.time_span is not None: + result = result.sel(time=slice(*self.time_span)) if result.time.size == 0: result = None