Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ on:
branches: [main]
tags: ['v*']

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
pre-commit:
if: ${{ !github.event.pull_request.draft }}
Expand Down
1 change: 1 addition & 0 deletions changelog/23.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactored intake-esgf-based data requests to have a common base class (`ref_sample_data.data_request.base.IntakeESGFDataRequest`)
34 changes: 2 additions & 32 deletions scripts/fetch_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,13 @@
import pooch
import typer
import xarray as xr
from intake_esgf import ESGFCatalog

from ref_sample_data import CMIP6Request, DataRequest, Obs4MIPsRequest

OUTPUT_PATH = Path("data")
app = typer.Typer()


def fetch_datasets(request: DataRequest, quiet: bool) -> pd.DataFrame:
    """
    Search ESGF for the datasets described by *request* and download them.

    Parameters
    ----------
    request
        Describes the search facets, ensemble handling and time span
    quiet
        Whether to suppress progress messages from intake-esgf

    Returns
    -------
    Dataframe of catalogue metadata with a ``files`` column holding the
    local paths to the fetched datasets
    """
    catalog = ESGFCatalog()
    catalog.search(**request.facets)
    if request.remove_ensembles:
        catalog.remove_ensembles()

    # Join the downloaded file paths onto the catalogue metadata via the key column
    paths = pd.Series(
        catalog.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet),
        name="files",
    )
    datasets = catalog.df.merge(paths, left_on="key", right_index=True)

    if not request.time_span:
        return datasets

    datasets["time_start"] = request.time_span[0]
    datasets["time_end"] = request.time_span[1]
    return datasets


def deduplicate_datasets(datasets: pd.DataFrame) -> pd.DataFrame:
"""
Deduplicate a dataset collection.
Expand Down Expand Up @@ -90,15 +60,15 @@ def process_sample_data_request(
quiet
Whether to suppress progress messages
"""
datasets = fetch_datasets(request, quiet)
datasets = request.fetch_datasets()
datasets = deduplicate_datasets(datasets)

for _, dataset in datasets.iterrows():
for ds_filename in dataset["files"]:
ds_orig = xr.open_dataset(ds_filename)

if decimate:
ds_decimated = request.decimate_dataset(ds_orig, request.time_span)
ds_decimated = request.decimate_dataset(ds_orig)
else:
ds_decimated = ds_orig
if ds_decimated is None:
Expand Down
2 changes: 1 addition & 1 deletion src/ref_sample_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
from .data_request.cmip6 import CMIP6Request
from .data_request.obs4mips import Obs4MIPsRequest

__all__ = ["DataRequest", "CMIP6Request", "Obs4MIPsRequest"]
__all__ = ["CMIP6Request", "DataRequest", "Obs4MIPsRequest"]
38 changes: 34 additions & 4 deletions src/ref_sample_data/data_request/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pandas as pd
import xarray as xr
from intake_esgf import ESGFCatalog


class DataRequest(Protocol):
Expand All @@ -14,11 +15,15 @@ class DataRequest(Protocol):
differently to generate the sample data.
"""

facets: dict[str, str | tuple[str, ...]]
remove_ensembles: bool
time_span: tuple[str, str]
def fetch_datasets(self) -> pd.DataFrame:
"""
Fetch the datasets from the source

Returns a dataframe of the metadata and paths to the fetched datasets.
"""
...

def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
"""Downscale the dataset to a smaller size."""
...

Expand All @@ -27,3 +32,28 @@ def generate_filename(
) -> pathlib.Path:
"""Create the output filename for the dataset."""
...


class IntakeESGFDataRequest(DataRequest):
    """
    A data request that fetches datasets from ESGF using intake-esgf.

    Subclasses supply the search facets and any decimation logic; this base
    class provides the shared catalogue search / download step.
    """

    # ESGF search facets (e.g. variable, experiment); each value may be a
    # single value or a tuple of alternatives
    facets: dict[str, str | tuple[str, ...]]
    # Whether to collapse the search results down to a single ensemble member
    remove_ensembles: bool
    # (start, end) labels stamped onto the result; falsy to skip stamping
    time_span: tuple[str, str]

    def fetch_datasets(self, quiet: bool = True) -> pd.DataFrame:
        """
        Fetch the datasets from the ESGF.

        Parameters
        ----------
        quiet
            Whether to suppress progress messages from intake-esgf.
            Defaults to True, matching the previous hard-coded behaviour,
            so existing callers are unaffected.

        Returns
        -------
        Dataframe of catalogue metadata with a ``files`` column of local
        paths to the fetched datasets and, when ``time_span`` is set,
        ``time_start``/``time_end`` columns.
        """
        cat = ESGFCatalog()

        cat.search(**self.facets)
        if self.remove_ensembles:
            cat.remove_ensembles()

        path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet)
        merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True)
        if self.time_span:
            merged_df["time_start"] = self.time_span[0]
            merged_df["time_end"] = self.time_span[1]
        return merged_df
10 changes: 5 additions & 5 deletions src/ref_sample_data/data_request/cmip6.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import xarray as xr

from ref_sample_data.data_request.base import DataRequest
from ref_sample_data.data_request.base import IntakeESGFDataRequest
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear


Expand Down Expand Up @@ -37,7 +37,7 @@ def prefix_to_filename(ds, filename_prefix: str) -> str:
return filename


class CMIP6Request(DataRequest):
class CMIP6Request(IntakeESGFDataRequest):
"""
Represents a CMIP6 dataset request

Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu
assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message"
assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message"

def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
"""
Downscale the dataset to a smaller size.

Expand Down Expand Up @@ -115,8 +115,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non
else:
raise ValueError("Cannot decimate this grid: too many dimensions")

if "time" in dataset.dims and time_span is not None:
result = result.sel(time=slice(*time_span))
if "time" in dataset.dims and self.time_span is not None:
result = result.sel(time=slice(*self.time_span))
if result.time.size == 0:
result = None

Expand Down
10 changes: 5 additions & 5 deletions src/ref_sample_data/data_request/obs4mips.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import pandas as pd
import xarray as xr

from ref_sample_data.data_request.base import DataRequest
from ref_sample_data.data_request.base import IntakeESGFDataRequest
from ref_sample_data.data_request.cmip6 import prefix_to_filename
from ref_sample_data.resample import decimate_curvilinear, decimate_rectilinear


class Obs4MIPsRequest(DataRequest):
class Obs4MIPsRequest(IntakeESGFDataRequest):
"""
Represents a Obs4MIPs dataset request
"""
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tu
assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message"
assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message"

def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None:
def decimate_dataset(self, dataset: xr.Dataset) -> xr.Dataset | None:
"""
Downscale the dataset to a smaller size.

Expand Down Expand Up @@ -94,8 +94,8 @@ def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | Non
else:
raise ValueError("Cannot decimate this grid: too many dimensions")

if "time" in dataset.dims and time_span is not None:
result = result.sel(time=slice(*time_span))
if "time" in dataset.dims and self.time_span is not None:
result = result.sel(time=slice(*self.time_span))
if result.time.size == 0:
result = None

Expand Down
Loading