diff --git a/.github/actions/regenerate/action.yml b/.github/actions/regenerate/action.yml index dfc25062..0ad55a89 100644 --- a/.github/actions/regenerate/action.yml +++ b/.github/actions/regenerate/action.yml @@ -6,9 +6,12 @@ runs: - uses: ./.github/actions/setup with: python-version: 3.12 + cache-esgf: true - name: Verify registry shell: bash + env: + QUIET: true run: | git config --global user.name "$GITHUB_ACTOR" git config --global user.email "$CI_COMMIT_EMAIL" diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index 898b566e..f4bbf7c0 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -9,6 +9,10 @@ inputs: description: "The version of uv to use" required: true default: ">=0.4.20" + cache-esgf: + description: "Cache any downloaded ESGF data" + required: false + default: "false" runs: using: "composite" @@ -27,3 +31,10 @@ runs: shell: bash run: | uv sync --all-extras --dev --locked + - name: Cache downloaded ESGF data + uses: actions/cache@v4 + if: ${{ inputs.cache-esgf == 'true' }} + with: + path: | + ~/.esgf + key: esgf diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 82a109de..46629d46 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -36,8 +36,11 @@ jobs: - uses: ./.github/actions/setup with: python-version: ${{ matrix.python-version }} + cache-esgf: true - name: Verify registry + env: + QUIET: true run: | make fetch-test-data git diff --exit-code diff --git a/.gitignore b/.gitignore index 3eb1fc7c..f281cf16 100644 --- a/.gitignore +++ b/.gitignore @@ -151,4 +151,5 @@ dmypy.json # Generated output out +data-raw .ref diff --git a/changelog/16.feature.md b/changelog/16.feature.md new file mode 100644 index 00000000..64559109 --- /dev/null +++ b/changelog/16.feature.md @@ -0,0 +1 @@ +Allow for the fetching of non-decimated datasets diff --git a/pyproject.toml b/pyproject.toml index 841946eb..fb7bdf31 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "cmip-ref-sample-data" +name = "ref-sample-data" version = "0.3.2" description = "CMIP Rapid Evaluation Framework Sample Data" readme = "README.md" @@ -13,6 +13,7 @@ dependencies = [ "matplotlib>=3.10.0", "scipy>=1.15.0", "xarray>=2024.10.0", + "typer>=0.15.1", ] [project.license] @@ -29,9 +30,12 @@ dev-dependencies = [ "bump-my-version>=0.29.0", ] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" [tool.coverage.run] -source = ["packages"] +source = ["src"] branch = true [tool.coverage.report] diff --git a/ruff.toml b/ruff.toml index 9268df3f..3a5a3cc6 100644 --- a/ruff.toml +++ b/ruff.toml @@ -23,6 +23,7 @@ ignore = [ "D200", "D400", "UP007", + "S101" # Use of `assert` detected ] [lint.per-file-ignores] diff --git a/scripts/fetch_test_data.py b/scripts/fetch_test_data.py old mode 100644 new mode 100755 index 47a3bbfe..64cc9494 --- a/scripts/fetch_test_data.py +++ b/scripts/fetch_test_data.py @@ -1,286 +1,33 @@ -import os import pathlib -from abc import ABC, abstractmethod from pathlib import Path -from typing import Any +from typing import Annotated import pandas as pd import pooch +import typer import xarray as xr from intake_esgf import ESGFCatalog -OUTPUT_PATH = Path("data") - - -class DataRequest(ABC): - """ - Represents a request for a dataset - - A polymorphic association is used to capture the different types of datasets as each - dataset type may have different metadata fields and may need to be handled - differently to generate the sample data. 
- """ - - def __init__(self, remove_ensembles: bool, time_span: tuple[str, str]): - self.remove_ensembles = remove_ensembles - self.time_span = time_span - - @abstractmethod - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: - """Downscale the dataset to a smaller size.""" - pass - - @abstractmethod - def create_out_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: str) -> pathlib.Path: - """Create the output filename for the dataset.""" - pass - - -class CMIP6Request(DataRequest): - """ - Represents a CMIP6 dataset request - - """ - - def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tuple[str, str] | None): - self.avail_facets = [ - "mip_era", - "activity_drs", - "institution_id", - "source_id", - "experiment_id", - "member_id", - "table_id", - "variable_id", - "grid_label", - "version", - "data_node", - ] - - self.facets = facets - - super().__init__(remove_ensembles, time_span) - - self.cmip6_path_items = [ - "mip_era", - "activity_drs", - "institution_id", - "source_id", - "experiment_id", - "member_id", - "table_id", - "variable_id", - "grid_label", - ] - - self.cmip6_filename_paths = [ - "variable_id", - "table_id", - "source_id", - "experiment_id", - "member_id", - "grid_label", - ] - - assert all(key in self.avail_facets for key in self.cmip6_path_items), "Error message" - assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message" - - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: - """ - Downscale the dataset to a smaller size. 
- - Parameters - ---------- - dataset - The dataset to downscale - time_span - The time span to extract from a dataset - - Returns - ------- - xr.Dataset - The downscaled dataset - """ - has_latlon = "lat" in dataset.dims and "lon" in dataset.dims - has_ij = "i" in dataset.dims and "j" in dataset.dims - - if has_latlon: - assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1 - result = dataset.interp(lat=dataset.lat[:10], lon=dataset.lon[:10]) - elif has_ij: - # 2d lat/lon grid (generally ocean variables) - # Choose a starting point around the middle of the grid to maximise chance that it has values - # TODO: Be smarter about this? - j_midpoint = len(dataset.j) // 2 - result = dataset.interp(i=dataset.i[:10], j=dataset.j[j_midpoint : j_midpoint + 10]) - else: - raise ValueError("Cannot decimate this grid: too many dimensions") - - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) - if result.time.size == 0: - result = None - - return result - - def create_out_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: str) -> pathlib.Path: - """ - Create the output filename for the dataset. 
- - Parameters - ---------- - ds - Loaded dataset - - Returns - ------- - The output filename - """ - output_path = ( - Path(os.path.join(*[metadata[item] for item in self.cmip6_path_items])) - / f"v{metadata['version']}" - ) - filename_prefix = "_".join([metadata[item] for item in self.cmip6_filename_paths]) - - if "time" in ds.dims: - time_range = ( - f"{ds.time.min().dt.strftime('%Y%m').item()}-{ds.time.max().dt.strftime('%Y%m').item()}" - ) - filename = f"{filename_prefix}_{time_range}.nc" - else: - filename = f"{filename_prefix}.nc" - - return output_path / filename +from ref_sample_data import CMIP6Request, DataRequest, Obs4MIPsRequest - -class Obs4MIPsRequest(DataRequest): - """ - Represents a Obs4MIPs dataset request - - """ - - def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tuple[str, str] | None): - self.avail_facets = [ - "activity_id", - "institution_id", - "source_id", - "frequency", - "variable_id", - "grid_label", - "version", - "data_node", - ] - - self.facets = facets - - super().__init__(remove_ensembles, time_span) - - self.obs4mips_path_items = [ - "activity_id", - "institution_id", - "source_id", - "variable_id", - "grid_label", - ] - - self.obs4mips_filename_paths = [ - "variable_id", - "source_id", - "grid_label", - ] - - assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message" - assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message" - - def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: - """ - Downscale the dataset to a smaller size. 
- - Parameters - ---------- - dataset - The dataset to downscale - time_span - The time span to extract from a dataset - - Returns - ------- - xr.Dataset - The downscaled dataset - """ - has_latlon = "lat" in dataset.dims and "lon" in dataset.dims - has_ij = "i" in dataset.dims and "j" in dataset.dims - - if has_latlon: - assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1 - result = dataset.interp(lat=dataset.lat[:10], lon=dataset.lon[:10]) - elif has_ij: - # 2d lat/lon grid (generally ocean variables) - # Choose a starting point around the middle of the grid to maximise chance that it has values - # TODO: Be smarter about this? - j_midpoint = len(dataset.j) // 2 - result = dataset.interp(i=dataset.i[:10], j=dataset.j[j_midpoint : j_midpoint + 10]) - else: - raise ValueError("Cannot decimate this grid: too many dimensions") - - if "time" in dataset.dims and time_span is not None: - result = result.sel(time=slice(*time_span)) - if result.time.size == 0: - result = None - - return result - - def create_out_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: str) -> pathlib.Path: - """ - Create the output filename for the dataset. 
- - Parameters - ---------- - ds - Loaded dataset - - Returns - ------- - The output filename - """ - output_path = ( - Path(os.path.join(*[metadata[item] for item in self.obs4mips_path_items])) - / f"v{metadata['version']}" - ) - if ds_filename.name.split("_")[0] == ds.variable_id: - filename_prefix = "_".join([metadata[item] for item in self.obs4mips_filename_paths]) - else: - filename_prefix = ds_filename.name.split("_")[0] + "_" - filename_prefix += "_".join( - [metadata[item] for item in self.obs4mips_filename_paths if item != "variable_id"] - ) - - if "time" in ds.dims: - time_range = ( - f"{ds.time.min().dt.strftime('%Y%m').item()}-{ds.time.max().dt.strftime('%Y%m').item()}" - ) - filename = f"{filename_prefix}_{time_range}.nc" - else: - filename = f"{filename_prefix}.nc" - - return output_path / filename +OUTPUT_PATH = Path("data") +app = typer.Typer() -def fetch_datasets(request: DataRequest) -> pd.DataFrame: +def fetch_datasets(request: DataRequest, quiet: bool) -> pd.DataFrame: """ Fetch the datasets from ESGF. Parameters ---------- - search_facets - Facets to search for - remove_ensembles - Whether to remove ensembles from the dataset - (i.e. 
include only a single ensemble member) + request + The request object + quiet + Whether to suppress progress messages from intake-esgf Returns ------- - List of paths to the fetched datasets + Dataframe that contains metadata and paths to the fetched datasets """ cat = ESGFCatalog() @@ -288,7 +35,7 @@ def fetch_datasets(request: DataRequest) -> pd.DataFrame: if request.remove_ensembles: cat.remove_ensembles() - path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False) + path_dict = cat.to_path_dict(prefer_streaming=False, minimal_keys=False, quiet=quiet) merged_df = cat.df.merge(pd.Series(path_dict, name="files"), left_on="key", right_index=True) if request.time_span: merged_df["time_start"] = request.time_span[0] @@ -296,7 +43,7 @@ def fetch_datasets(request: DataRequest) -> pd.DataFrame: return merged_df -def deduplicate_datasets(request: DataRequest) -> pd.DataFrame: +def deduplicate_datasets(datasets: pd.DataFrame) -> pd.DataFrame: """ Deduplicate a dataset collection. @@ -313,7 +60,6 @@ def deduplicate_datasets(request: DataRequest) -> pd.DataFrame: pd.DataFrame The deduplicated dataset collection spanning the times requested """ - datasets = fetch_datasets(request) def _deduplicate_group(group: pd.DataFrame) -> pd.DataFrame: first = group.iloc[0].copy() @@ -325,106 +71,131 @@ def _deduplicate_group(group: pd.DataFrame) -> pd.DataFrame: return datasets.groupby("key").apply(_deduplicate_group, include_groups=False).reset_index() -def create_sample_dataset(request: DataRequest): +def process_sample_data_request( + request: DataRequest, decimate: bool, output_directory: Path, quiet: bool +) -> None: """ - Create the output filename for the dataset. + Fetch and create sample datasets Parameters ---------- - ds - Loaded dataset - - Returns - ------- - The output filename + request + The request to execute + + This may be different types of requests, such as CMIP6Request or Obs4MIPsRequest. 
+ decimate + Whether to decimate the datasets + output_directory + The directory to write the output to + quiet + Whether to suppress progress messages """ - datasets = deduplicate_datasets(request) + datasets = fetch_datasets(request, quiet) + datasets = deduplicate_datasets(datasets) + for _, dataset in datasets.iterrows(): for ds_filename in dataset["files"]: ds_orig = xr.open_dataset(ds_filename) - ds_decimated = request.decimate_dataset(ds_orig, request.time_span) + + if decimate: + ds_decimated = request.decimate_dataset(ds_orig, request.time_span) + else: + ds_decimated = ds_orig if ds_decimated is None: continue - output_filename = OUTPUT_PATH / request.create_out_filename(dataset, ds_decimated, ds_filename) + output_filename = output_directory / request.generate_filename(dataset, ds_decimated, ds_filename) output_filename.parent.mkdir(parents=True, exist_ok=True) ds_decimated.to_netcdf(output_filename) # Regenerate the registry.txt file - pooch.make_registry(OUTPUT_PATH, "registry.txt") + pooch.make_registry(str(OUTPUT_PATH), "registry.txt") -if __name__ == "__main__": - datasets_to_fetch = [ - # Example metric data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "tas", "tos", "rsut", "rlut", "rsdt"], - experiment_id=["ssp126", "historical"], - ), - remove_ensembles=True, - time_span=("2000", "2025"), +DATASETS_TO_FETCH = [ + # Example metric data + CMIP6Request( + facets=dict( + source_id="ACCESS-ESM1-5", + frequency=["fx", "mon"], + variable_id=["areacella", "tas", "tos", "rsut", "rlut", "rsdt"], + experiment_id=["ssp126", "historical"], ), - # ESMValTool ECS data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "rlut", "rsdt", "rsut", "tas"], - experiment_id=["abrupt-4xCO2", "piControl"], - ), - remove_ensembles=True, - time_span=("0101", "0125"), + remove_ensembles=True, + time_span=("2000", "2025"), + ), + # ESMValTool 
ECS data + CMIP6Request( + facets=dict( + source_id="ACCESS-ESM1-5", + frequency=["fx", "mon"], + variable_id=["areacella", "rlut", "rsdt", "rsut", "tas"], + experiment_id=["abrupt-4xCO2", "piControl"], ), - # ESMValTool TCR data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "tas"], - experiment_id=["1pctCO2", "piControl"], - ), - remove_ensembles=True, - time_span=("0101", "0180"), + remove_ensembles=True, + time_span=("0101", "0125"), + ), + # ESMValTool TCR data + CMIP6Request( + facets=dict( + source_id="ACCESS-ESM1-5", + frequency=["fx", "mon"], + variable_id=["areacella", "tas"], + experiment_id=["1pctCO2", "piControl"], ), - # ILAMB data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "sftlf", "gpp", "pr"], - experiment_id=["historical"], - ), - remove_ensembles=True, - time_span=("2000", "2025"), + remove_ensembles=True, + time_span=("0101", "0180"), + ), + # ILAMB data + CMIP6Request( + facets=dict( + source_id="ACCESS-ESM1-5", + frequency=["fx", "mon"], + variable_id=["areacella", "sftlf", "gpp", "pr"], + experiment_id=["historical"], ), - # PMP PDO data - CMIP6Request( - facets=dict( - source_id="ACCESS-ESM1-5", - frequency=["fx", "mon"], - variable_id=["areacella", "ts"], - experiment_id=["historical", "hist-GHG"], - variant_label=["r1i1p1f1", "r2i1p1f1"], - ), - remove_ensembles=False, - time_span=("2000", "2025"), + remove_ensembles=True, + time_span=("2000", "2025"), + ), + # PMP PDO data + CMIP6Request( + facets=dict( + source_id="ACCESS-ESM1-5", + frequency=["fx", "mon"], + variable_id=["areacella", "ts"], + experiment_id=["historical", "hist-GHG"], + variant_label=["r1i1p1f1", "r2i1p1f1"], ), - # Obs4MIPs AIRS data - Obs4MIPsRequest( - facets=dict( - project="obs4MIPs", - institution_id="NASA-JPL", - frequency="mon", - source_id="AIRS-2-1", - variable_id="ta", - ), - remove_ensembles=False, - time_span=("2002", 
"2016"), + remove_ensembles=False, + time_span=("2000", "2025"), + ), + # Obs4MIPs AIRS data + Obs4MIPsRequest( + facets=dict( + project="obs4MIPs", + institution_id="NASA-JPL", + frequency="mon", + source_id="AIRS-2-1", + variable_id="ta", ), - ] + remove_ensembles=False, + time_span=("2002", "2016"), + ), +] + + +@app.command() +def create_sample_data( + decimate: bool = True, + output: Path = OUTPUT_PATH, + quiet: Annotated[bool, typer.Argument(envvar="QUIET")] = False, +) -> None: + """Fetch and create sample datasets""" + for dataset_requested in DATASETS_TO_FETCH: + process_sample_data_request( + dataset_requested, decimate=decimate, output_directory=pathlib.Path(output), quiet=quiet + ) - for dataset_requested in datasets_to_fetch: - create_sample_dataset(dataset_requested) + +if __name__ == "__main__": + app() diff --git a/src/ref_sample_data/__init__.py b/src/ref_sample_data/__init__.py new file mode 100644 index 00000000..c33270f0 --- /dev/null +++ b/src/ref_sample_data/__init__.py @@ -0,0 +1,14 @@ +""" +REF sample data +""" + +import importlib.metadata + +__version__ = importlib.metadata.version("ref_sample_data") + + +from .data_request.base import DataRequest +from .data_request.cmip6 import CMIP6Request +from .data_request.obs4mips import Obs4MIPsRequest + +__all__ = ["DataRequest", "CMIP6Request", "Obs4MIPsRequest"] diff --git a/src/ref_sample_data/data_request/__init__.py b/src/ref_sample_data/data_request/__init__.py new file mode 100644 index 00000000..30cc3a36 --- /dev/null +++ b/src/ref_sample_data/data_request/__init__.py @@ -0,0 +1,5 @@ +""" +Data requests + +Provides an abstraction over the different possible data queries that intake-esgf can perform +""" diff --git a/src/ref_sample_data/data_request/base.py b/src/ref_sample_data/data_request/base.py new file mode 100644 index 00000000..a2f7e751 --- /dev/null +++ b/src/ref_sample_data/data_request/base.py @@ -0,0 +1,27 @@ +import pathlib +from typing import Protocol + +import pandas as pd 
+import xarray as xr + + +class DataRequest(Protocol): + """ + Represents a request for a dataset + + A polymorphic association is used to capture the different types of datasets as each + dataset type may have different metadata fields and may need to be handled + differently to generate the sample data. + """ + + facets: dict[str, str | tuple[str, ...]] + remove_ensembles: bool + time_span: tuple[str, str] | None + + def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + """Downscale the dataset to a smaller size.""" + ... + + def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: pathlib.Path) -> pathlib.Path: + """Create the output filename for the dataset.""" + ... diff --git a/src/ref_sample_data/data_request/cmip6.py b/src/ref_sample_data/data_request/cmip6.py new file mode 100644 index 00000000..922dfa62 --- /dev/null +++ b/src/ref_sample_data/data_request/cmip6.py @@ -0,0 +1,125 @@ +import os.path +from pathlib import Path +from typing import Any + +import pandas as pd +import xarray as xr + +from ref_sample_data.data_request.base import DataRequest + + +class CMIP6Request(DataRequest): + """ + Represents a CMIP6 dataset request + + """ + + cmip6_path_items = ( + "mip_era", + "activity_drs", + "institution_id", + "source_id", + "experiment_id", + "member_id", + "table_id", + "variable_id", + "grid_label", + ) + + cmip6_filename_paths = ( + "variable_id", + "table_id", + "source_id", + "experiment_id", + "member_id", + "grid_label", + ) + + def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tuple[str, str] | None): + self.avail_facets = [ + "mip_era", + "activity_drs", + "institution_id", + "source_id", + "experiment_id", + "member_id", + "table_id", + "variable_id", + "grid_label", + "version", + "data_node", + ] + + self.facets = facets + self.remove_ensembles = remove_ensembles + self.time_span = time_span + + assert all(key in self.avail_facets for key in 
self.cmip6_path_items), "Error message" + assert all(key in self.avail_facets for key in self.cmip6_filename_paths), "Error message" + + def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + """ + Downscale the dataset to a smaller size. + + Parameters + ---------- + dataset + The dataset to downscale + time_span + The time span to extract from a dataset + + Returns + ------- + xr.Dataset + The downscaled dataset + """ + has_latlon = "lat" in dataset.dims and "lon" in dataset.dims + has_ij = "i" in dataset.dims and "j" in dataset.dims + + if has_latlon: + assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1 + result = dataset.interp(lat=dataset.lat[:10], lon=dataset.lon[:10]) + elif has_ij: + # 2d lat/lon grid (generally ocean variables) + # Choose a starting point around the middle of the grid to maximise chance that it has values + # TODO: Be smarter about this? + j_midpoint = len(dataset.j) // 2 + result = dataset.interp(i=dataset.i[:10], j=dataset.j[j_midpoint : j_midpoint + 10]) + else: + raise ValueError("Cannot decimate this grid: too many dimensions") + + if "time" in dataset.dims and time_span is not None: + result = result.sel(time=slice(*time_span)) + if result.time.size == 0: + result = None + + return result + + def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: str) -> Path: + """ + Create the output filename for the dataset. 
+ + Parameters + ---------- + ds + Loaded dataset + + Returns + ------- + The output filename + """ + output_path = ( + Path(os.path.join(*[metadata[item] for item in self.cmip6_path_items])) + / f"v{metadata['version']}" + ) + filename_prefix = "_".join([metadata[item] for item in self.cmip6_filename_paths]) + + if "time" in ds.dims: + time_range = ( + f"{ds.time.min().dt.strftime('%Y%m').item()}-{ds.time.max().dt.strftime('%Y%m').item()}" + ) + filename = f"{filename_prefix}_{time_range}.nc" + else: + filename = f"{filename_prefix}.nc" + + return output_path / filename diff --git a/src/ref_sample_data/data_request/obs4mips.py b/src/ref_sample_data/data_request/obs4mips.py new file mode 100644 index 00000000..1d9c6fe9 --- /dev/null +++ b/src/ref_sample_data/data_request/obs4mips.py @@ -0,0 +1,136 @@ +import os.path +from pathlib import Path +from typing import Any + +import pandas as pd +import xarray as xr + +from ref_sample_data.data_request.base import DataRequest + + +class Obs4MIPsRequest(DataRequest): + """ + Represents a Obs4MIPs dataset request + """ + + obs4mips_path_items = ( + "activity_id", + "institution_id", + "source_id", + "variable_id", + "grid_label", + ) + + obs4mips_filename_paths = ( + "variable_id", + "source_id", + "grid_label", + ) + + def __init__(self, facets: dict[str, Any], remove_ensembles: bool, time_span: tuple[str, str] | None): + self.avail_facets = [ + "activity_id", + "institution_id", + "source_id", + "frequency", + "variable_id", + "grid_label", + "version", + "data_node", + ] + + self.facets = facets + self.remove_ensembles = remove_ensembles + self.time_span = time_span + + # NOTE: DataRequest is a typing.Protocol base; there is no superclass initialiser to chain to + + self.obs4mips_path_items = [ + "activity_id", + "institution_id", + "source_id", + "variable_id", + "grid_label", + ] + + self.obs4mips_filename_paths = [ + "variable_id", + "source_id", + "grid_label", + ] + + assert all(key in self.avail_facets for key in self.obs4mips_path_items), "Error message" + 
assert all(key in self.avail_facets for key in self.obs4mips_filename_paths), "Error message" + + def decimate_dataset(self, dataset: xr.Dataset, time_span: tuple[str, str] | None) -> xr.Dataset | None: + """ + Downscale the dataset to a smaller size. + + Parameters + ---------- + dataset + The dataset to downscale + time_span + The time span to extract from a dataset + + Returns + ------- + xr.Dataset + The downscaled dataset + """ + has_latlon = "lat" in dataset.dims and "lon" in dataset.dims + has_ij = "i" in dataset.dims and "j" in dataset.dims + + if has_latlon: + assert len(dataset.lat.dims) == 1 and len(dataset.lon.dims) == 1 + result = dataset.interp(lat=dataset.lat[:10], lon=dataset.lon[:10]) + elif has_ij: + # 2d lat/lon grid (generally ocean variables) + # Choose a starting point around the middle of the grid to maximise chance that it has values + # TODO: Be smarter about this? + j_midpoint = len(dataset.j) // 2 + result = dataset.interp(i=dataset.i[:10], j=dataset.j[j_midpoint : j_midpoint + 10]) + else: + raise ValueError("Cannot decimate this grid: too many dimensions") + + if "time" in dataset.dims and time_span is not None: + result = result.sel(time=slice(*time_span)) + if result.time.size == 0: + result = None + + return result + + def generate_filename(self, metadata: pd.Series, ds: xr.Dataset, ds_filename: str) -> Path: + """ + Create the output filename for the dataset. 
+ + Parameters + ---------- + ds + Loaded dataset + + Returns + ------- + The output filename + """ + output_path = ( + Path(os.path.join(*[metadata[item] for item in self.obs4mips_path_items])) + / f"v{metadata['version']}" + ) + if ds_filename.name.split("_")[0] == ds.variable_id: + filename_prefix = "_".join([metadata[item] for item in self.obs4mips_filename_paths]) + else: + filename_prefix = ds_filename.name.split("_")[0] + "_" + filename_prefix += "_".join( + [metadata[item] for item in self.obs4mips_filename_paths if item != "variable_id"] + ) + + if "time" in ds.dims: + time_range = ( + f"{ds.time.min().dt.strftime('%Y%m').item()}-{ds.time.max().dt.strftime('%Y%m').item()}" + ) + filename = f"{filename_prefix}_{time_range}.nc" + else: + filename = f"{filename_prefix}.nc" + + return output_path / filename diff --git a/uv.lock b/uv.lock index 4f771d9a..99e2665b 100644 --- a/uv.lock +++ b/uv.lock @@ -241,47 +241,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/41/e1d85ca3cab0b674e277c8c4f678cf66a91cd2cecf93df94353a606fe0db/cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e", size = 22021 }, ] -[[package]] -name = "cmip-ref-sample-data" -version = "0.3.2" -source = { virtual = "." 
} -dependencies = [ - { name = "intake-esgf" }, - { name = "matplotlib" }, - { name = "pooch" }, - { name = "scipy" }, - { name = "xarray" }, -] - -[package.dev-dependencies] -dev = [ - { name = "bump-my-version" }, - { name = "liccheck" }, - { name = "pip" }, - { name = "pre-commit" }, - { name = "ruff" }, - { name = "towncrier" }, -] - -[package.metadata] -requires-dist = [ - { name = "intake-esgf" }, - { name = "matplotlib", specifier = ">=3.10.0" }, - { name = "pooch" }, - { name = "scipy", specifier = ">=1.15.0" }, - { name = "xarray", specifier = ">=2024.10.0" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "bump-my-version", specifier = ">=0.29.0" }, - { name = "liccheck", specifier = ">=0.9.2" }, - { name = "pip", specifier = ">=24.3.1" }, - { name = "pre-commit", specifier = ">=3.3.1" }, - { name = "ruff", specifier = ">=0.6.9" }, - { name = "towncrier", specifier = ">=24.8.0" }, -] - [[package]] name = "colorama" version = "0.4.6" @@ -1478,6 +1437,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ad/3f/11dd4cd4f39e05128bfd20138faea57bec56f9ffba6185d276e3107ba5b2/questionary-2.1.0-py3-none-any.whl", hash = "sha256:44174d237b68bc828e4878c763a9ad6790ee61990e0ae72927694ead57bab8ec", size = 36747 }, ] +[[package]] +name = "ref-sample-data" +version = "0.3.2" +source = { editable = "." 
} +dependencies = [ + { name = "intake-esgf" }, + { name = "matplotlib" }, + { name = "pooch" }, + { name = "scipy" }, + { name = "typer" }, + { name = "xarray" }, +] + +[package.dev-dependencies] +dev = [ + { name = "bump-my-version" }, + { name = "liccheck" }, + { name = "pip" }, + { name = "pre-commit" }, + { name = "ruff" }, + { name = "towncrier" }, +] + +[package.metadata] +requires-dist = [ + { name = "intake-esgf" }, + { name = "matplotlib", specifier = ">=3.10.0" }, + { name = "pooch" }, + { name = "scipy", specifier = ">=1.15.0" }, + { name = "typer", specifier = ">=0.15.1" }, + { name = "xarray", specifier = ">=2024.10.0" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "bump-my-version", specifier = ">=0.29.0" }, + { name = "liccheck", specifier = ">=0.9.2" }, + { name = "pip", specifier = ">=24.3.1" }, + { name = "pre-commit", specifier = ">=3.3.1" }, + { name = "ruff", specifier = ">=0.6.9" }, + { name = "towncrier", specifier = ">=24.8.0" }, +] + [[package]] name = "requests" version = "2.32.3" @@ -1605,6 +1607,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/23/8146aad7d88f4fcb3a6218f41a60f6c2d4e3a72de72da1825dc7c8f7877c/semantic_version-2.10.0-py2.py3-none-any.whl", hash = "sha256:de78a3b8e0feda74cabc54aab2da702113e33ac9d9eb9d2389bcf1f58b7d9177", size = 15552 }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, +] + [[package]] name = "six" version = 
"1.16.0" @@ -1704,6 +1715,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, ] +[[package]] +name = "typer" +version = "0.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/dca7b219718afd37a0068f4f2530a727c2b74a8b6e8e0c0080a4c0de4fcd/typer-0.15.1.tar.gz", hash = "sha256:a0588c0a7fa68a1978a069818657778f86abe6ff5ea6abf472f940a08bfe4f0a", size = 99789 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/cc/0a838ba5ca64dc832aa43f727bd586309846b0ffb2ce52422543e6075e8a/typer-0.15.1-py3-none-any.whl", hash = "sha256:7994fb7b8155b64d3402518560648446072864beefd44aa2dc36972a5972e847", size = 44908 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"