From eec5c2e5cbb6f7312829010c9a3b5737a17ef379 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:12:06 +0100 Subject: [PATCH 01/21] Patch simple_UV_dataset metadata --- src/parcels/_datasets/structured/generated.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parcels/_datasets/structured/generated.py b/src/parcels/_datasets/structured/generated.py index c0477828b..4ee9eaf3a 100644 --- a/src/parcels/_datasets/structured/generated.py +++ b/src/parcels/_datasets/structured/generated.py @@ -34,6 +34,7 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"): cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), + node_coordinates=("lon", "lat"), face_dimensions=( FaceNodePadding("XC", "XG", Padding.LOW), FaceNodePadding("YC", "YG", Padding.LOW), From e02e9e6ec766594c3cb1d9d9edce57af3357a35e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 18 May 2026 14:05:54 +0200 Subject: [PATCH 02/21] Add SGRID Xarray index From ee414c3d13c1a32ebfd5b26be4247c51888e34f1 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 18 May 2026 14:49:14 +0200 Subject: [PATCH 03/21] Move sgrid files From 0f24d47774997bf934b95cfdfee3c5fc3b2e2d63 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 18 May 2026 14:57:25 +0200 Subject: [PATCH 04/21] Update sgrid API --- src/parcels/_core/fieldset.py | 2 +- src/parcels/_core/utils/sgrid.py | 36 ++++++++++++++++---------------- tests/utils/test_sgrid.py | 14 ++++++------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index a44472e00..1da2f25ed 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -280,7 +280,7 @@ def from_sgrid_conventions( ds = ds.rename({time_dim: "time"}) # Parse SGRID metadata 
and get xgcm kwargs - _, xgcm_kwargs = sgrid.parse_sgrid(ds) + _, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) # Add time axis to xgcm_kwargs if present if "time" in ds.dims: diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py index ed8edce60..bac4f40c8 100644 --- a/src/parcels/_core/utils/sgrid.py +++ b/src/parcels/_core/utils/sgrid.py @@ -22,7 +22,7 @@ from parcels._python import repr_from_dunder_dict -RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" +_RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" Dim = str @@ -47,7 +47,7 @@ class Padding(enum.Enum): } -class AttrsSerializable(Protocol): +class _AttrsSerializable(Protocol): def to_attrs(self) -> dict[str, str | int]: ... @classmethod @@ -78,7 +78,7 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... } -class SGrid2DMetadata(AttrsSerializable): +class SGrid2DMetadata(_AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -164,8 +164,8 @@ def from_attrs(cls, attrs): # type: ignore[override] topology_dimension=attrs["topology_dimension"], node_dimensions=cast(tuple[Dim, Dim], load_mappings(attrs["node_dimensions"])), face_dimensions=cast(tuple[FaceNodePadding, FaceNodePadding], load_mappings(attrs["face_dimensions"])), - node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), - vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")), + node_coordinates=_maybe_load_mappings(attrs.get("node_coordinates")), + vertical_dimensions=_maybe_load_mappings(attrs.get("vertical_dimensions")), ) except Exception as e: raise SGridParsingException(f"Failed to parse Grid2DMetadata from {attrs=!r}") from e @@ -200,7 +200,7 @@ def get_value_by_id(self, id: str) -> Dim | Padding: return _ID_FETCHERS_GRID2DMETADATA[id](self) -class SGrid3DMetadata(AttrsSerializable): +class SGrid3DMetadata(_AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -282,7 +282,7 @@ def from_attrs(cls, attrs): # type: 
ignore[override] volume_dimensions=cast( tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding], load_mappings(attrs["volume_dimensions"]) ), - node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), + node_coordinates=_maybe_load_mappings(attrs.get("node_coordinates")), ) except Exception as e: raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e @@ -336,7 +336,7 @@ def __str__(self) -> str: @classmethod def load(cls, s: str) -> Self: - match = re.match(RE_FACE_NODE_PADDING, s) + match = re.match(_RE_FACE_NODE_PADDING, s) if not match: raise ValueError(f"String {s!r} does not match expected format for FaceNodePadding") face = match.group(1) @@ -359,12 +359,12 @@ def dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: @overload -def maybe_dump_mappings(parts: None) -> None: ... +def _maybe_dump_mappings(parts: None) -> None: ... @overload -def maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ... +def _maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ... -def maybe_dump_mappings(parts): +def _maybe_dump_mappings(parts): if parts is None: return None return dump_mappings(parts) @@ -383,7 +383,7 @@ def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: ret = [] while s: # find next part - match = re.match(RE_FACE_NODE_PADDING, s) + match = re.match(_RE_FACE_NODE_PADDING, s) if match and match.start() == 0: # match found at start, take that as next part part = match.group(0) @@ -414,12 +414,12 @@ def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: @overload -def maybe_load_mappings(s: None) -> None: ... +def _maybe_load_mappings(s: None) -> None: ... @overload -def maybe_load_mappings(s: Hashable) -> tuple[FaceNodePadding | Dim, ...]: ... +def _maybe_load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: ... 
-def maybe_load_mappings(s): +def _maybe_load_mappings(s): if s is None: return None return load_mappings(s) @@ -453,7 +453,7 @@ def get_grid_topology(ds: xr.Dataset) -> xr.DataArray | None: return None -def parse_sgrid(ds: xr.Dataset): +def xgcm_parse_sgrid(ds: xr.Dataset): # Function similar to that provided in `xgcm.metadata_parsers. # Might at some point be upstreamed to xgcm directly try: @@ -499,7 +499,7 @@ def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset: return ds -def get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]: +def _get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]: dims = set() dims.update(set(grid.node_dimensions)) @@ -724,7 +724,7 @@ def _metadata_rename(grid, names_dict): names_dict = names_dict.copy() assert len(names_dict) == len(set(names_dict.values())), "names_dict contains duplicate target dimension names" - existing_names = get_unique_names(grid) + existing_names = _get_unique_names(grid) for name in names_dict.keys(): if name not in existing_names: raise ValueError(f"Name {name!r} not found in names defined in SGrid metadata {existing_names!r}") diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py index 3b8292929..1f7fa9fc3 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/utils/test_sgrid.py @@ -89,7 +89,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.SGrid2DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid.get_unique_names(grid) & set(ds.dims) == set()) + assume(sgrid._get_unique_names(grid) & set(ds.dims) == set()) renamings = {} if grid.vertical_dimensions is None: @@ -114,7 +114,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.SGrid3DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid.get_unique_names(grid) & set(ds.dims) == set()) + assume(sgrid._get_unique_names(grid) & set(ds.dims) == set()) renamings = {} for old, new in 
zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True): @@ -242,7 +242,7 @@ def test_grid3Dmetadata_standard_names(grid: sgrid.SGrid3DMetadata): @given(pst.sgrid.grid_metadata) -def test_parse_grid_attrs(grid: sgrid.AttrsSerializable): +def test_parse_grid_attrs(grid: sgrid._AttrsSerializable): attrs = grid.to_attrs() parsed = sgrid.parse_grid_attrs(attrs) assert parsed == grid @@ -254,7 +254,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_2d_ds(grid_metadata) - _, xgcm_kwargs = sgrid.parse_sgrid(ds) + _, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) for obj, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): @@ -276,7 +276,7 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_3d_ds(grid_metadata) - ds, xgcm_kwargs = sgrid.parse_sgrid(ds) + ds, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) for obj, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): @@ -294,12 +294,12 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): + [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]], ) def test_rename(grid): - dims = sgrid.get_unique_names(grid) + dims = sgrid._get_unique_names(grid) dims_dict = {dim: f"new_{dim}" for dim in dims} dims_dict_inv = {v: k for k, v in dims_dict.items()} grid_new = grid.rename(dims_dict) - assert dims & set(sgrid.get_unique_names(grid_new)) == set() + assert dims & set(sgrid._get_unique_names(grid_new)) == set() assert grid == grid_new.rename(dims_dict_inv) From 17e4ff82d507c448498771a6b0e03ae6ded3d23a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> 
Date: Mon, 18 May 2026 15:02:03 +0200 Subject: [PATCH 05/21] Move sgrid file to own subpackage --- src/parcels/_core/fieldset.py | 2 +- src/parcels/_datasets/structured/generated.py | 17 ++++------ src/parcels/_datasets/structured/generic.py | 34 +++++++------------ .../_datasets/structured/strategies.py | 5 ++- src/parcels/_sgrid/__init__.py | 27 +++++++++++++++ src/parcels/_sgrid/accessor.py | 0 .../{_core/utils/sgrid.py => _sgrid/core.py} | 0 src/parcels/_sgrid/index.py | 0 src/parcels/_strategies/sgrid.py | 2 +- src/parcels/convert.py | 2 +- tests/datasets/test_strategies.py | 11 +++--- tests/test_convert.py | 2 +- tests/utils/test_sgrid.py | 19 ++++++----- 13 files changed, 67 insertions(+), 54 deletions(-) create mode 100644 src/parcels/_sgrid/__init__.py create mode 100644 src/parcels/_sgrid/accessor.py rename src/parcels/{_core/utils/sgrid.py => _sgrid/core.py} (100%) create mode 100644 src/parcels/_sgrid/index.py diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 1da2f25ed..96079df1f 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -10,8 +10,8 @@ import xarray as xr import xgcm +import parcels._sgrid as sgrid from parcels._core.field import Field, VectorField -from parcels._core.utils import sgrid from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible diff --git a/src/parcels/_datasets/structured/generated.py b/src/parcels/_datasets/structured/generated.py index 4ee9eaf3a..665bcaf0f 100644 --- a/src/parcels/_datasets/structured/generated.py +++ b/src/parcels/_datasets/structured/generated.py @@ -3,12 +3,7 @@ import numpy as np import xarray as xr -from parcels._core.utils.sgrid import ( - FaceNodePadding, - Padding, - SGrid2DMetadata, - _attach_sgrid_metadata, -) +import parcels._sgrid as sgrid from parcels._core.utils.time import 
timedelta_to_float @@ -29,17 +24,17 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"): "lon": (["XG"], np.linspace(-max_lon, max_lon, dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}), }, ).pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), node_coordinates=("lon", "lat"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.LOW), - FaceNodePadding("YC", "YG", Padding.LOW), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW), ), - vertical_dimensions=(FaceNodePadding("ZC", "depth", Padding.BOTH),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "depth", sgrid.Padding.BOTH),), ), ) diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index c26d8e9ce..8b4958b0e 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -1,15 +1,7 @@ import numpy as np import xarray as xr -from parcels._core.utils.sgrid import ( - FaceNodePadding, - Padding, - SGrid2DMetadata, - _attach_sgrid_metadata, -) -from parcels._core.utils.sgrid import ( - rename as sgrid_rename, -) +import parcels._sgrid as sgrid from . 
import T, X, Y, Z @@ -248,42 +240,42 @@ def _unrolled_cone_curvilinear_grid(): "ds_2d_padded_high": ( datasets["ds_2d_left"] .pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.HIGH), - FaceNodePadding("YC", "YG", Padding.HIGH), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), ), ) .pipe( - sgrid_rename, + sgrid.rename, _COMODO_TO_2D_SGRID, ) ), "ds_2d_padded_low": ( datasets["ds_2d_right"] .pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.LOW), - FaceNodePadding("YC", "YG", Padding.LOW), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.LOW),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.LOW),), ), ) .pipe( - sgrid_rename, + sgrid.rename, _COMODO_TO_2D_SGRID, ) ), diff --git a/src/parcels/_datasets/structured/strategies.py b/src/parcels/_datasets/structured/strategies.py index c66dd0916..a001db76b 100644 --- a/src/parcels/_datasets/structured/strategies.py +++ b/src/parcels/_datasets/structured/strategies.py @@ -3,9 +3,8 @@ from hypothesis import strategies as st from hypothesis.extra.numpy import arrays as np_arrays +import parcels._sgrid as sgrid import parcels._strategies as pst -from parcels._core.utils import sgrid -from parcels._core.utils.sgrid import 
_attach_sgrid_metadata def _face_size(node_size: int, padding: sgrid.Padding) -> int: @@ -90,4 +89,4 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset } ds = xr.Dataset(data_vars=data_vars, coords=coords) - return _attach_sgrid_metadata(ds, grid) + return sgrid._attach_sgrid_metadata(ds, grid) diff --git a/src/parcels/_sgrid/__init__.py b/src/parcels/_sgrid/__init__.py new file mode 100644 index 000000000..e3a704898 --- /dev/null +++ b/src/parcels/_sgrid/__init__.py @@ -0,0 +1,27 @@ +from parcels._sgrid.core import ( + FaceNodePadding, + Padding, + SGrid2DMetadata, + SGrid3DMetadata, + _attach_sgrid_metadata, + dump_mappings, + get_grid_topology, + load_mappings, + parse_grid_attrs, + rename, + xgcm_parse_sgrid, +) + +__all__ = [ + "FaceNodePadding", + "Padding", + "SGrid2DMetadata", + "SGrid3DMetadata", + "_attach_sgrid_metadata", + "dump_mappings", + "get_grid_topology", + "load_mappings", + "parse_grid_attrs", + "rename", + "xgcm_parse_sgrid", +] diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_sgrid/core.py similarity index 100% rename from src/parcels/_core/utils/sgrid.py rename to src/parcels/_sgrid/core.py diff --git a/src/parcels/_sgrid/index.py b/src/parcels/_sgrid/index.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/parcels/_strategies/sgrid.py b/src/parcels/_strategies/sgrid.py index f413ab18b..010b23ed8 100644 --- a/src/parcels/_strategies/sgrid.py +++ b/src/parcels/_strategies/sgrid.py @@ -11,7 +11,7 @@ import xarray.testing.strategies as xr_st from hypothesis import strategies as st -from parcels._core.utils import sgrid +import parcels._sgrid as sgrid padding = st.sampled_from(sgrid.Padding) dimension_name = xr_st.names().filter( diff --git a/src/parcels/convert.py b/src/parcels/convert.py index 600262bd6..15ccb8c36 100644 --- a/src/parcels/convert.py +++ 
b/src/parcels/convert.py @@ -20,7 +20,7 @@ import numpy as np import xarray as xr -from parcels._core.utils import sgrid +import parcels._sgrid as sgrid from parcels._logger import logger if typing.TYPE_CHECKING: diff --git a/tests/datasets/test_strategies.py b/tests/datasets/test_strategies.py index 38083cb54..2079ee86c 100644 --- a/tests/datasets/test_strategies.py +++ b/tests/datasets/test_strategies.py @@ -6,8 +6,7 @@ from hypothesis import given, settings from hypothesis.errors import NonInteractiveExampleWarning -from parcels._core.utils import sgrid -from parcels._core.utils.sgrid import get_grid_topology, parse_grid_attrs +import parcels._sgrid as sgrid from parcels._datasets.structured.strategies import _face_size, sgrid_dataset @@ -50,13 +49,13 @@ def test_sgrid_dataset_returns_dataset(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_has_grid_topology(ds): - assert get_grid_topology(ds) is not None + assert sgrid.get_grid_topology(ds) is not None @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_node_coordinates_present(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) assert meta.node_coordinates is not None for coord_name in meta.node_coordinates: assert coord_name in ds.coords @@ -65,7 +64,7 @@ def test_sgrid_dataset_node_coordinates_present(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_coordinate_shapes(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) coord_name1, coord_name2 = meta.node_coordinates node_dim1, node_dim2 = meta.node_dimensions coord1 = ds.coords[coord_name1] @@ -86,7 +85,7 @@ def test_sgrid_dataset_has_at_least_one_field(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_field_dims_are_valid(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = 
sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) valid_dims = set(meta.node_dimensions) valid_dims.add(meta.face_dimensions[0].face) valid_dims.add(meta.face_dimensions[1].face) diff --git a/tests/test_convert.py b/tests/test_convert.py index 88bac2c24..137ad7978 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -3,10 +3,10 @@ import xarray as xr import parcels +import parcels._sgrid as sgrid import parcels.convert as convert import parcels.tutorial from parcels import FieldSet -from parcels._core.utils import sgrid from parcels._datasets.remote import open_remote_dataset from parcels._datasets.structured.circulation_models import datasets as datasets_circulation_models from parcels.interpolators._xinterpolators import _get_offsets_dictionary diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py index 1f7fa9fc3..de2294194 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/utils/test_sgrid.py @@ -6,8 +6,9 @@ import xgcm from hypothesis import assume, example, given +import parcels._sgrid as sgrid import parcels._strategies as pst -from parcels._core.utils import sgrid +from parcels._sgrid.core import SGRID_PADDING_TO_XGCM_POSITION, _get_unique_names def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool): @@ -89,7 +90,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.SGrid2DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid._get_unique_names(grid) & set(ds.dims) == set()) + assume(_get_unique_names(grid) & set(ds.dims) == set()) renamings = {} if grid.vertical_dimensions is None: @@ -114,7 +115,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.SGrid3DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid._get_unique_names(grid) & set(ds.dims) == set()) + assume(_get_unique_names(grid) & set(ds.dims) == set()) renamings = {} for old, new in zip(["XG", "YG", "ZG"], 
grid.node_dimensions, strict=True): @@ -242,7 +243,7 @@ def test_grid3Dmetadata_standard_names(grid: sgrid.SGrid3DMetadata): @given(pst.sgrid.grid_metadata) -def test_parse_grid_attrs(grid: sgrid._AttrsSerializable): +def test_parse_grid_attrs(grid): attrs = grid.to_attrs() parsed = sgrid.parse_grid_attrs(attrs) assert parsed == grid @@ -260,7 +261,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): for obj, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): coords = grid.axes[axis].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node if grid_metadata.vertical_dimensions is None: assert "Z" not in grid.axes @@ -268,7 +269,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): obj = grid_metadata.vertical_dimensions[0] coords = grid.axes["Z"].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @given(pst.sgrid.grid3Dmetadata()) @@ -282,7 +283,7 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): for obj, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): coords = grid.axes[axis].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @pytest.mark.parametrize( @@ -294,12 +295,12 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): + [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]], ) def test_rename(grid): - dims = sgrid._get_unique_names(grid) + dims = _get_unique_names(grid) dims_dict = {dim: f"new_{dim}" for dim in dims} dims_dict_inv = {v: k for k, v in dims_dict.items()} grid_new = grid.rename(dims_dict) - assert dims 
& set(sgrid._get_unique_names(grid_new)) == set() + assert dims & set(_get_unique_names(grid_new)) == set() assert grid == grid_new.rename(dims_dict_inv) From 2f1104f7b6acb2cc7be8fa574931815a021e2b7d Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 18 May 2026 16:47:00 +0200 Subject: [PATCH 06/21] Move test_sgrid.py --- tests/{utils => sgrid}/test_sgrid.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{utils => sgrid}/test_sgrid.py (100%) diff --git a/tests/utils/test_sgrid.py b/tests/sgrid/test_sgrid.py similarity index 100% rename from tests/utils/test_sgrid.py rename to tests/sgrid/test_sgrid.py From 116e162b90a2dae33bd7d5fe15f757d7a3abe70c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 18 May 2026 15:33:02 +0200 Subject: [PATCH 07/21] Fix typo --- src/parcels/_strategies/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parcels/_strategies/__init__.py b/src/parcels/_strategies/__init__.py index 6cc66ee5a..15536e269 100644 --- a/src/parcels/_strategies/__init__.py +++ b/src/parcels/_strategies/__init__.py @@ -4,7 +4,7 @@ import hypothesis # noqa: F401 except ImportError as err: err.add_note( - "To use strategies you must have hypothesis installed. Install it from PyPI, Conda, or using your preffered package manager." + "To use strategies you must have hypothesis installed. Install it from PyPI, Conda, or using your preferred package manager." 
) raise err From a24caa62c6af213a45a856679f729e844084df89 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 10:51:14 +0200 Subject: [PATCH 08/21] SGRID xarray accessor tests --- tests/sgrid/test_accessor.py | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/sgrid/test_accessor.py diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py new file mode 100644 index 000000000..40344c40e --- /dev/null +++ b/tests/sgrid/test_accessor.py @@ -0,0 +1,40 @@ +import hypothesis.strategies as st +import xarray as xr +from hypothesis import given + +import parcels._strategies as pst +from parcels._datasets.structured.strategies import sgrid_dataset +from parcels._sgrid import SGrid2DMetadata + + +@st.composite +def grid_and_dataset(draw) -> tuple[SGrid2DMetadata, xr.Dataset]: + # used only for test_metadata - for all other tests we can simply do `ds.sgrid.metadata` to get the metadata + metadata_2d = draw( + pst.sgrid.grid_metadata.filter( + # parcels can only generate 2D Sgrid datasets, that also have coordinates + lambda meta: isinstance(meta, SGrid2DMetadata) and meta.node_coordinates is not None + ) + ) + ds = draw(sgrid_dataset(metadata_2d)) + return metadata_2d, ds + + +@given(grid_and_dataset()) +def test_metadata(metadata_ds): + metadata = metadata_ds[0] + ds = metadata_ds[1] + parsed_metadata = ds.sgrid.metadata + assert parsed_metadata == metadata + + +@given(selection_axis=st.sampled_from(["X", "Y", "Z"]), ds=sgrid_dataset()) +def test_isel(selection_axis, ds): + # TODO: Add skip if Z but no Z dimension in ds + + # select nodes in axis direction + # assert consistent + + # select edges + # assert consistent + ... 
From 286e67561536050bfa567066b622bfaedae99c42 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 11:18:28 +0200 Subject: [PATCH 09/21] Add SGRID xarray accessor - metadata, rename() Adds an xarray accessor for SGRID focussing on `metadata` and `rename()`. Removes old functions which handled this (i.e., get_grid_topology, parse_grid_attrs, and parcels.sgrid.rename(ds,...)), and updates all references --- src/parcels/_core/fieldset.py | 6 +--- src/parcels/_datasets/structured/generic.py | 6 ++-- src/parcels/_sgrid/__init__.py | 10 ++---- src/parcels/_sgrid/accessor.py | 38 +++++++++++++++++++++ src/parcels/_sgrid/core.py | 31 +---------------- src/parcels/_sgrid/index.py | 0 src/parcels/convert.py | 2 +- tests/datasets/test_strategies.py | 8 ++--- tests/sgrid/test_sgrid.py | 12 +++---- tests/test_convert.py | 5 ++- 10 files changed, 58 insertions(+), 60 deletions(-) delete mode 100644 src/parcels/_sgrid/index.py diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 96079df1f..f83df364f 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -475,11 +475,7 @@ def _select_uxinterpolator(da: ux.UxDataArray): # TODO: Refactor later into something like `parcels._metadata.discover(dataset)` helper that can be used to discover important metadata like this. I think this whole metadata handling should be refactored into its own module. 
def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: """Small helper to inspect SGRID metadata and dataset metadata to determine mesh type.""" - grid_da = sgrid.get_grid_topology(ds_sgrid) - if grid_da is None: - raise ValueError("Dataset does not contain SGRID grid topology metadata (cf_role='grid_topology').") - - sgrid_metadata = sgrid.parse_grid_attrs(grid_da.attrs) + sgrid_metadata = ds_sgrid.sgrid.metadata fpoint_x, fpoint_y = sgrid_metadata.node_coordinates diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index 8b4958b0e..98811e293 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -253,8 +253,7 @@ def _unrolled_cone_curvilinear_grid(): vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), ), ) - .pipe( - sgrid.rename, + .sgrid.rename( _COMODO_TO_2D_SGRID, ) ), @@ -274,8 +273,7 @@ def _unrolled_cone_curvilinear_grid(): vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.LOW),), ), ) - .pipe( - sgrid.rename, + .sgrid.rename( _COMODO_TO_2D_SGRID, ) ), diff --git a/src/parcels/_sgrid/__init__.py b/src/parcels/_sgrid/__init__.py index e3a704898..8044fc1c6 100644 --- a/src/parcels/_sgrid/__init__.py +++ b/src/parcels/_sgrid/__init__.py @@ -1,14 +1,12 @@ -from parcels._sgrid.core import ( +from .accessor import SgridAccessor +from .core import ( FaceNodePadding, Padding, SGrid2DMetadata, SGrid3DMetadata, _attach_sgrid_metadata, dump_mappings, - get_grid_topology, load_mappings, - parse_grid_attrs, - rename, xgcm_parse_sgrid, ) @@ -17,11 +15,9 @@ "Padding", "SGrid2DMetadata", "SGrid3DMetadata", + "SgridAccessor", "_attach_sgrid_metadata", "dump_mappings", - "get_grid_topology", "load_mappings", - "parse_grid_attrs", - "rename", "xgcm_parse_sgrid", ] diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index e69de29bb..f3d783325 100644 --- a/src/parcels/_sgrid/accessor.py +++ 
b/src/parcels/_sgrid/accessor.py @@ -0,0 +1,38 @@ +import xarray as xr + +from .core import SGrid2DMetadata, SGrid3DMetadata, parse_grid_attrs + + +@xr.register_dataset_accessor("sgrid") +class SgridAccessor: + def __init__(self, xarray_obj): + self._ds = xarray_obj + + @property + def metadata(self) -> SGrid2DMetadata | SGrid3DMetadata: + grid_da = self._get_grid_topology() + return parse_grid_attrs(grid_da.attrs) + + def rename(self, name_dict: dict[str, str]) -> xr.Dataset: + """Similar to Xarray's rename functionality - but also updates the SGRID metadata attributes.""" + ds = self._ds.copy() + ds = ds.rename(name_dict) + + grid_da_name = self._get_grid_topology().name + ds[grid_da_name].attrs = self.metadata.rename(name_dict).to_attrs() + return ds + + def _get_grid_topology(self) -> xr.DataArray: + grid_da = None + for var_name in self._ds.variables: + if self._ds[var_name].attrs.get("cf_role") == "grid_topology": + grid_da = self._ds[var_name] + + if grid_da is None: + raise ValueError( + "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions https://sgrid.github.io/sgrid/" + ) + return grid_da + + +def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata | SGrid3DMetadata): ... diff --git a/src/parcels/_sgrid/core.py b/src/parcels/_sgrid/core.py index bac4f40c8..cc69f1a08 100644 --- a/src/parcels/_sgrid/core.py +++ b/src/parcels/_sgrid/core.py @@ -445,24 +445,10 @@ def parse_grid_attrs(attrs: dict[str, Hashable]) -> SGrid2DMetadata | SGrid3DMet return grid -def get_grid_topology(ds: xr.Dataset) -> xr.DataArray | None: - """Extracts grid topology DataArray from an xarray Dataset.""" - for var_name in ds.variables: - if ds[var_name].attrs.get("cf_role") == "grid_topology": - return ds[var_name] - return None - - def xgcm_parse_sgrid(ds: xr.Dataset): # Function similar to that provided in `xgcm.metadata_parsers. 
# Might at some point be upstreamed to xgcm directly - try: - grid_topology = get_grid_topology(ds) - assert grid_topology is not None, "No grid_topology variable found in dataset" - grid = parse_grid_attrs(grid_topology.attrs) - - except Exception as e: - raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e + grid = ds.sgrid.metadata if isinstance(grid, SGrid2DMetadata): dimensions = grid.face_dimensions + (grid.vertical_dimensions or ()) @@ -484,21 +470,6 @@ def xgcm_parse_sgrid(ds: xr.Dataset): return (ds, {"coords": xgcm_coords}) -def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset: - grid_da = get_grid_topology(ds) - if grid_da is None: - raise ValueError( - "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions." - ) - - ds = ds.rename(name_dict) - - # Update the metadata - grid = parse_grid_attrs(grid_da.attrs) - ds[grid_da.name].attrs = grid.rename(name_dict).to_attrs() - return ds - - def _get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]: dims = set() dims.update(set(grid.node_dimensions)) diff --git a/src/parcels/_sgrid/index.py b/src/parcels/_sgrid/index.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/parcels/convert.py b/src/parcels/convert.py index 15ccb8c36..9f044f049 100644 --- a/src/parcels/convert.py +++ b/src/parcels/convert.py @@ -376,7 +376,7 @@ def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Da ds["gphif"].attrs["units"] = "degrees" # Update to use lon and lat for internal naming - ds = sgrid.rename(ds, {"gphif": "lat", "glamf": "lon"}) # TODO: Logging message about rename + ds = ds.sgrid.rename({"gphif": "lat", "glamf": "lon"}) # TODO: Logging message about rename return ds diff --git a/tests/datasets/test_strategies.py b/tests/datasets/test_strategies.py index 2079ee86c..6a5924f71 100644 --- 
a/tests/datasets/test_strategies.py +++ b/tests/datasets/test_strategies.py @@ -49,13 +49,13 @@ def test_sgrid_dataset_returns_dataset(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_has_grid_topology(ds): - assert sgrid.get_grid_topology(ds) is not None + ds.sgrid._get_grid_topology() # shouldn't error @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_node_coordinates_present(ds): - meta = sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata assert meta.node_coordinates is not None for coord_name in meta.node_coordinates: assert coord_name in ds.coords @@ -64,7 +64,7 @@ def test_sgrid_dataset_node_coordinates_present(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_coordinate_shapes(ds): - meta = sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata coord_name1, coord_name2 = meta.node_coordinates node_dim1, node_dim2 = meta.node_dimensions coord1 = ds.coords[coord_name1] @@ -85,7 +85,7 @@ def test_sgrid_dataset_has_at_least_one_field(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_field_dims_are_valid(ds): - meta = sgrid.parse_grid_attrs(sgrid.get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata valid_dims = set(meta.node_dimensions) valid_dims.add(meta.face_dimensions[0].face) valid_dims.add(meta.face_dimensions[1].face) diff --git a/tests/sgrid/test_sgrid.py b/tests/sgrid/test_sgrid.py index de2294194..644e8e596 100644 --- a/tests/sgrid/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -8,7 +8,7 @@ import parcels._sgrid as sgrid import parcels._strategies as pst -from parcels._sgrid.core import SGRID_PADDING_TO_XGCM_POSITION, _get_unique_names +from parcels._sgrid.core import SGRID_PADDING_TO_XGCM_POSITION, _get_unique_names, parse_grid_attrs def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool): @@ -245,7 +245,7 @@ def 
test_grid3Dmetadata_standard_names(grid: sgrid.SGrid3DMetadata): @given(pst.sgrid.grid_metadata) def test_parse_grid_attrs(grid): attrs = grid.to_attrs() - parsed = sgrid.parse_grid_attrs(attrs) + parsed = parse_grid_attrs(attrs) assert parsed == grid @@ -357,14 +357,14 @@ def test_rename_errors(): ) def test_rename_dataset(ds): # Check renaming works for coordinates - ds_new = sgrid.rename(ds, {"lon": "lon_updated"}) - grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs) + ds_new = ds.sgrid.rename({"lon": "lon_updated"}) + grid_new = ds_new.sgrid.metadata assert "lon_updated" in ds_new.coords assert "lon_updated" == grid_new.node_coordinates[0] # Check renaming works for dim - ds_new = sgrid.rename(ds, {"XC": "XC_updated"}) - grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs) + ds_new = ds.sgrid.rename({"XC": "XC_updated"}) + grid_new = ds_new.sgrid.metadata assert "XC_updated" in ds_new.dims assert "XC" not in ds_new.dims assert "XC_updated" == grid_new.face_dimensions[0].face diff --git a/tests/test_convert.py b/tests/test_convert.py index 137ad7978..6a05bc960 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -3,7 +3,6 @@ import xarray as xr import parcels -import parcels._sgrid as sgrid import parcels.convert as convert import parcels.tutorial from parcels import FieldSet @@ -35,7 +34,7 @@ def test_nemo_to_sgrid_2d(U, V, coords): # noqa: N803 "vertical_dimensions": "depth_center:depth (padding:high)", } - meta = sgrid.parse_grid_attrs(ds["grid"].attrs) + meta = ds.sgrid.metadata # Assuming that node_dimension1 and node_dimension2 correspond to X and Y respectively # check that U and V are properly defined on the staggered grid @@ -75,7 +74,7 @@ def test_nemo_to_sgrid_with_depth(U, V, depth, coords): # noqa: N803 "vertical_dimensions": "depth_center:depth (padding:high)", } - meta = sgrid.parse_grid_attrs(ds["grid"].attrs) + meta = ds.sgrid.metadata # Assuming that node_dimension1 and node_dimension2 correspond to X and Y respectively # 
check that U and V are properly defined on the staggered grid From c272eed8b8f7f3aaf85507401fa8cb13e44efed7 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 12:21:37 +0200 Subject: [PATCH 10/21] Add function assert_metadata_ds_consistency Alongside helpers get_n_nodes and get_n_faces --- .../_datasets/structured/strategies.py | 15 ++---- src/parcels/_sgrid/__init__.py | 4 ++ src/parcels/_sgrid/accessor.py | 54 ++++++++++++++++++- src/parcels/_sgrid/core.py | 22 ++++++++ tests/datasets/test_strategies.py | 15 +----- tests/sgrid/test_accessor.py | 43 +++++++++++++-- tests/sgrid/test_sgrid.py | 8 +++ 7 files changed, 130 insertions(+), 31 deletions(-) diff --git a/src/parcels/_datasets/structured/strategies.py b/src/parcels/_datasets/structured/strategies.py index a001db76b..c317e3985 100644 --- a/src/parcels/_datasets/structured/strategies.py +++ b/src/parcels/_datasets/structured/strategies.py @@ -7,15 +7,6 @@ import parcels._strategies as pst -def _face_size(node_size: int, padding: sgrid.Padding) -> int: - if padding == sgrid.Padding.NONE: - return node_size - 1 - elif padding in (sgrid.Padding.LOW, sgrid.Padding.HIGH): - return node_size - else: # Padding.BOTH - return node_size + 1 - - @st.composite def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset: """Strategy to create Xarray Sgrid datasets for testing""" @@ -32,14 +23,14 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset node_dim1, node_dim2 = grid.node_dimensions face_dim1 = grid.face_dimensions[0].face face_dim2 = grid.face_dimensions[1].face - N_face = _face_size(N, grid.face_dimensions[0].padding) - M_face = _face_size(M, grid.face_dimensions[1].padding) + N_face = sgrid.get_n_faces(N, grid.face_dimensions[0].padding) + M_face = sgrid.get_n_faces(M, grid.face_dimensions[1].padding) if has_vertical := grid.vertical_dimensions is not None: P = draw(st.integers(min_value=5, 
max_value=20)) vert_node_dim = grid.vertical_dimensions[0].node vert_face_dim = grid.vertical_dimensions[0].face - P_face = _face_size(P, grid.vertical_dimensions[0].padding) + P_face = sgrid.get_n_faces(P, grid.vertical_dimensions[0].padding) has_curvilinear_grid = draw(st.booleans()) coord_name1, coord_name2 = grid.node_coordinates diff --git a/src/parcels/_sgrid/__init__.py b/src/parcels/_sgrid/__init__.py index 8044fc1c6..f06e03830 100644 --- a/src/parcels/_sgrid/__init__.py +++ b/src/parcels/_sgrid/__init__.py @@ -6,6 +6,8 @@ SGrid3DMetadata, _attach_sgrid_metadata, dump_mappings, + get_n_faces, + get_n_nodes, load_mappings, xgcm_parse_sgrid, ) @@ -18,6 +20,8 @@ "SgridAccessor", "_attach_sgrid_metadata", "dump_mappings", + "get_n_faces", + "get_n_nodes", "load_mappings", "xgcm_parse_sgrid", ] diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index f3d783325..c5e5001fa 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -1,6 +1,10 @@ +import itertools +from collections.abc import Mapping +from typing import Any + import xarray as xr -from .core import SGrid2DMetadata, SGrid3DMetadata, parse_grid_attrs +from .core import SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs @xr.register_dataset_accessor("sgrid") @@ -34,5 +38,51 @@ def _get_grid_topology(self) -> xr.DataArray: ) return grid_da + def isel(self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs): + if indexers is None: + indexers = {} + + for k, indexer in itertools.chain(indexers.items(), indexers_kwargs.items()): + if not isinstance(indexer, slice): + raise NotImplementedError( + f"sgrid.isel() only works on `slice` objects for the timebeing. Got indexer {indexer!r} for {k!r}" + ) + + _meta = self.metadata + + ... 
+ + +def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): + vertical_dimensions = metadata.vertical_dimensions or tuple() + + for obj in itertools.chain(metadata.face_dimensions, vertical_dimensions): + face, node, padding = obj.face, obj.node, obj.padding + + try: + n_nodes = ds.dims[node] + except KeyError: + # node dimension is not in this dataset + continue + + try: + n_faces = ds.dims[face] + except KeyError: + # face dimension is not in this dataset + continue + + expected_n_faces = get_n_faces(n_nodes, padding) + + if expected_n_faces != n_faces: + raise SGridDatasetInconsistency( + f"Node dimension {node!r} has size {n_nodes}, and face dimension {face!r} has size of {n_faces}. " + f"Due to dataset padding of {padding!r}, expected face dimension {face} to actually be size {expected_n_faces}." + ) + + # TODO: Also check on coordinates + + +class SGridDatasetInconsistency(Exception): + """Attached metadata is not compatible with Xarray dataset""" -def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata | SGrid3DMetadata): ... + ... 
diff --git a/src/parcels/_sgrid/core.py b/src/parcels/_sgrid/core.py index cc69f1a08..11a853463 100644 --- a/src/parcels/_sgrid/core.py +++ b/src/parcels/_sgrid/core.py @@ -47,6 +47,28 @@ class Padding(enum.Enum): } +def get_n_faces(n_nodes: int, padding: Padding) -> int: + """Get number of faces along a dimension""" + if padding in [Padding.LOW, Padding.HIGH]: + return n_nodes + if padding == Padding.NONE: + return n_nodes - 1 + if padding == Padding.BOTH: + return n_nodes + 1 + raise ValueError(f"Invalid {padding=!r}") + + +def get_n_nodes(n_faces: int, padding: Padding) -> int: + """Get number of nodes along a dimension""" + if padding in [Padding.LOW, Padding.HIGH]: + return n_faces + if padding == Padding.NONE: + return n_faces + 1 + if padding == Padding.BOTH: + return n_faces - 1 + raise ValueError(f"Invalid {padding=!r}") + + class _AttrsSerializable(Protocol): def to_attrs(self) -> dict[str, str | int]: ... diff --git a/tests/datasets/test_strategies.py b/tests/datasets/test_strategies.py index 6a5924f71..32faa1065 100644 --- a/tests/datasets/test_strategies.py +++ b/tests/datasets/test_strategies.py @@ -7,20 +7,7 @@ from hypothesis.errors import NonInteractiveExampleWarning import parcels._sgrid as sgrid -from parcels._datasets.structured.strategies import _face_size, sgrid_dataset - - -@pytest.mark.parametrize( - "n_nodes, padding, n_edges", - [ - (10, sgrid.Padding.NONE, 9), - (10, sgrid.Padding.LOW, 10), - (10, sgrid.Padding.HIGH, 10), - (10, sgrid.Padding.BOTH, 11), - ], -) -def test_face_size(n_nodes, padding, n_edges): - assert _face_size(n_nodes, padding) == n_edges +from parcels._datasets.structured.strategies import sgrid_dataset def test_sgrid_dataset_raises_when_no_node_coordinates(): diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 40344c40e..34f41c68f 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -1,10 +1,12 @@ import hypothesis.strategies as st +import pytest import xarray as xr 
-from hypothesis import given +from hypothesis import assume, given import parcels._strategies as pst from parcels._datasets.structured.strategies import sgrid_dataset from parcels._sgrid import SGrid2DMetadata +from parcels._sgrid.accessor import SGridDatasetInconsistency, assert_metadata_ds_consistency @st.composite @@ -28,8 +30,43 @@ def test_metadata(metadata_ds): assert parsed_metadata == metadata -@given(selection_axis=st.sampled_from(["X", "Y", "Z"]), ds=sgrid_dataset()) -def test_isel(selection_axis, ds): +@given(sgrid_dataset()) +def test_assert_metadata_ds_consistency(ds): + metadata: SGrid2DMetadata = ds.sgrid.metadata + assert_metadata_ds_consistency(ds, metadata) + + +@given(sgrid_dataset()) +def test_assert_metadata_ds_consistency_dropped_dim(ds): + metadata: SGrid2DMetadata = ds.sgrid.metadata + + # dropping one of the SGRID dimensions is fine + first_face_dim = metadata.face_dimensions[0].face + + assume(first_face_dim in ds.dims) + + ds = ds.isel({first_face_dim: 0}) + assert_metadata_ds_consistency(ds, metadata) + + +@given(ds=sgrid_dataset()) +def test_assert_metadata_ds_consistency_failures(ds): + metadata: SGrid2DMetadata = ds.sgrid.metadata + first_face_dim = metadata.face_dimensions[0].face + + assume(first_face_dim in ds.dims) + + ds = ds.isel({metadata.face_dimensions[0].face: slice(None, -1)}) + + with pytest.raises( + SGridDatasetInconsistency, + match="Node dimension .* has size .*, and face dimension .* has size of .* .* expected face dimension .* to actually be size .*", + ): + assert_metadata_ds_consistency(ds, metadata) + + +@given(selection_axis=st.sampled_from(["X", "Y", "Z"]), ds=sgrid_dataset(), slice_=st.slices(4)) +def test_isel(selection_axis, ds, slice_): # TODO: Add skip if Z but no Z dimension in ds # select nodes in axis direction diff --git a/tests/sgrid/test_sgrid.py b/tests/sgrid/test_sgrid.py index 644e8e596..173613336 100644 --- a/tests/sgrid/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -1,5 +1,6 @@ import 
itertools +import hypothesis.strategies as st import numpy as np import pytest import xarray as xr @@ -581,3 +582,10 @@ def test_face_node_padding_to_diagram(face_node_padding: sgrid.FaceNodePadding, actual = face_node_padding.to_diagram() lines = actual.split("\n") assert lines == expected_lines + + +@given(n=st.integers(min_value=1), padding=pst.sgrid.padding) +def test_dim_sizes_roundtrip(n, padding): + n_faces = sgrid.get_n_nodes(n, padding) + n_nodes = sgrid.get_n_faces(n_faces, padding) + assert n_nodes == n From 85d65c86b40f4ca0693b230955255729df6bac09 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 14:29:35 +0200 Subject: [PATCH 11/21] Improve tests --- tests/sgrid/test_accessor.py | 44 ++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 34f41c68f..6ded6cab4 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -36,27 +36,47 @@ def test_assert_metadata_ds_consistency(ds): assert_metadata_ds_consistency(ds, metadata) -@given(sgrid_dataset()) -def test_assert_metadata_ds_consistency_dropped_dim(ds): +@given(ds=sgrid_dataset(), dim=st.sampled_from(['face_dimension1', "face_dimension2", "vertical_dimension"])) +def test_assert_metadata_ds_consistency_dropped_dim(ds, dim): + # dropping one of the SGRID dimensions still results in a consistent dataset metadata: SGrid2DMetadata = ds.sgrid.metadata - # dropping one of the SGRID dimensions is fine - first_face_dim = metadata.face_dimensions[0].face + if dim == "face_dimension1": + fnp = metadata.face_dimensions[0] + elif dim=="face_dimension2": + fnp = metadata.face_dimensions[1] + elif dim=="vertical_dimension": + assume(metadata.vertical_dimensions is not None) + assert metadata.vertical_dimensions is not None + fnp = metadata.vertical_dimensions[0] + else: + raise ValueError("Unexpected value for dim") - 
assume(first_face_dim in ds.dims) + assume(fnp.face in ds.dims) - ds = ds.isel({first_face_dim: 0}) + ds = ds.isel({fnp.face: 0}) assert_metadata_ds_consistency(ds, metadata) -@given(ds=sgrid_dataset()) -def test_assert_metadata_ds_consistency_failures(ds): +@given(ds=sgrid_dataset(), dim=st.sampled_from(['face_dimension1', "face_dimension2", "vertical_dimension"])) +def test_assert_metadata_ds_consistency_failures(ds, dim): metadata: SGrid2DMetadata = ds.sgrid.metadata - first_face_dim = metadata.face_dimensions[0].face - - assume(first_face_dim in ds.dims) - ds = ds.isel({metadata.face_dimensions[0].face: slice(None, -1)}) + if dim == "face_dimension1": + fnp = metadata.face_dimensions[0] + elif dim=="face_dimension2": + fnp = metadata.face_dimensions[1] + elif dim=="vertical_dimension": + assume(metadata.vertical_dimensions is not None) + assert metadata.vertical_dimensions is not None + fnp = metadata.vertical_dimensions[0] + else: + raise ValueError("Unexpected value for dim") + + assume(fnp.node in ds.dims) + assume(fnp.face in ds.dims) + + ds = ds.isel({fnp.face: slice(None, -1)}) with pytest.raises( SGridDatasetInconsistency, From e441cc0beee404a780615c3ba98e92e5e5b70057 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 14:51:20 +0200 Subject: [PATCH 12/21] Add isel method to Xarray accessor --- src/parcels/_python.py | 13 ++++- src/parcels/_sgrid/accessor.py | 96 +++++++++++++++++++++++++++++----- tests/sgrid/test_accessor.py | 48 ++++++++++++----- 3 files changed, 130 insertions(+), 27 deletions(-) diff --git a/src/parcels/_python.py b/src/parcels/_python.py index b7f8b2845..24d1a1b4e 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,6 +1,10 @@ # Generic Python helpers import inspect -from collections.abc import Callable +from collections.abc import Callable, Mapping +from typing import TypeVar + +K = TypeVar("K") +V = TypeVar("V") def isinstance_noimport(obj, 
class_or_tuple): @@ -39,3 +43,10 @@ def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) raise ValueError( f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'" ) + + +def invert_non_unique_mapping(d: Mapping[K, V]) -> Mapping[V, list[K]]: + inv_map = {} + for k, v in d.items(): + inv_map[v] = inv_map.get(v, []) + [k] + return inv_map diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index c5e5001fa..90485cad2 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -1,21 +1,26 @@ import itertools -from collections.abc import Mapping -from typing import Any +from collections.abc import Mapping, Sequence +from typing import Any, Literal import xarray as xr +from parcels._python import invert_non_unique_mapping + from .core import SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs @xr.register_dataset_accessor("sgrid") class SgridAccessor: def __init__(self, xarray_obj): - self._ds = xarray_obj + self._ds: xr.Dataset = xarray_obj @property - def metadata(self) -> SGrid2DMetadata | SGrid3DMetadata: + def metadata(self) -> SGrid2DMetadata: grid_da = self._get_grid_topology() - return parse_grid_attrs(grid_da.attrs) + grid = parse_grid_attrs(grid_da.attrs) + if isinstance(grid, SGrid3DMetadata): + raise NotImplementedError("Support for 3D SGRID metadata not supported.") + return grid def rename(self, name_dict: dict[str, str]) -> xr.Dataset: """Similar to Xarray's rename functionality - but also updates the SGRID metadata attributes.""" @@ -38,19 +43,31 @@ def _get_grid_topology(self) -> xr.DataArray: ) return grid_da - def isel(self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs): - if indexers is None: - indexers = {} + def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): + """TODO: Docstring""" + if indexers_kwargs != {}: + if indexers is not None: + raise ValueError("Cannot provide both 
positional and keyword argument to .isel .") + indexers = indexers_kwargs + + assert indexers is not None - for k, indexer in itertools.chain(indexers.items(), indexers_kwargs.items()): + for k, indexer in indexers.items(): if not isinstance(indexer, slice): raise NotImplementedError( f"sgrid.isel() only works on `slice` objects for the timebeing. Got indexer {indexer!r} for {k!r}" ) - _meta = self.metadata + metadata = self.metadata + + _assert_not_indexing_along_same_axis(indexers, metadata) + _assert_all_isel_along_axis(indexers.keys(), metadata) - ... + indexers = _complete_isel_indexing(indexers, metadata, self._ds.dims.keys()) + + ds = self._ds.isel(indexers=indexers) + assert_metadata_ds_consistency(ds, metadata) + return ds def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): @@ -85,4 +102,59 @@ def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): class SGridDatasetInconsistency(Exception): """Attached metadata is not compatible with Xarray dataset""" - ... 
+ pass + + +def _get_dim_to_axis_mapping(grid: SGrid2DMetadata) -> dict[str, Literal["X", "Y", "Z"]]: + fnp_x = grid.face_dimensions[0] + fnp_y = grid.face_dimensions[1] + fnp_z = grid.vertical_dimensions[0] if grid.vertical_dimensions is not None else None + + d = { + fnp_x.node: "X", + fnp_x.face: "X", + fnp_y.node: "Y", + fnp_y.face: "Y", + } + if fnp_z is not None: + d.update({fnp_z.node: "Z", fnp_z.face: "Z"}) + return d + + +def _assert_not_indexing_along_same_axis(indexers: Mapping[Any, Any], metadata: SGrid2DMetadata) -> None: + dim_to_axis = _get_dim_to_axis_mapping(metadata) + indexer_dim_to_axis = {dim: dim_to_axis.get(dim) for dim in indexers} + + indexer_axis_to_dim = invert_non_unique_mapping(indexer_dim_to_axis) + for axis, dims in indexer_axis_to_dim.items(): + if axis is None: + continue + + if len(dims) > 1: + msg = f"Dims {dims} are on the same axis {axis!r} according to SGRID metadata - cannot simultaneously index along multiple dimensions in the same axis." + raise ValueError(msg) + + +def _assert_all_isel_along_axis(index_dims: Sequence[str], metadata: SGrid2DMetadata): + dim_to_axis = _get_dim_to_axis_mapping(metadata) + for dim in index_dims: + try: + dim_to_axis[dim] + except KeyError as e: + raise ValueError( + f"Cannot use SGRID accessor to .isel non-spatial (/SGRID related) dimension {dim!r}." 
+ ) from e + + +def _complete_isel_indexing( + indexers: Mapping[Any, Any], grid: SGrid2DMetadata, dims_in_dataset: Sequence[str] +) -> Mapping[Any, Any]: + """Copies indexers to the other dataset dimensions defined on the same axis.""" + ret = {} + dim_to_axis = _get_dim_to_axis_mapping(grid) + indexers_by_axis = {dim_to_axis[dim]: indexer for dim, indexer in indexers.items()} + + for dim, axis in dim_to_axis.items(): + if dim in dims_in_dataset and axis in indexers_by_axis: + ret[dim] = indexers_by_axis[axis] + return ret diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 6ded6cab4..dbabf4f66 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -4,6 +4,7 @@ from hypothesis import assume, given import parcels._strategies as pst +from parcels._datasets.structured.generic import datasets_sgrid from parcels._datasets.structured.strategies import sgrid_dataset from parcels._sgrid import SGrid2DMetadata from parcels._sgrid.accessor import SGridDatasetInconsistency, assert_metadata_ds_consistency @@ -36,16 +37,16 @@ def test_assert_metadata_ds_consistency(ds): assert_metadata_ds_consistency(ds, metadata) -@given(ds=sgrid_dataset(), dim=st.sampled_from(['face_dimension1', "face_dimension2", "vertical_dimension"])) +@given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) def test_assert_metadata_ds_consistency_dropped_dim(ds, dim): # dropping one of the SGRID dimensions still results in a consistent dataset metadata: SGrid2DMetadata = ds.sgrid.metadata if dim == "face_dimension1": fnp = metadata.face_dimensions[0] - elif dim=="face_dimension2": + elif dim == "face_dimension2": fnp = metadata.face_dimensions[1] - elif dim=="vertical_dimension": + elif dim == "vertical_dimension": assume(metadata.vertical_dimensions is not None) assert metadata.vertical_dimensions is not None fnp = metadata.vertical_dimensions[0] @@ -58,15 +59,15 @@ def 
test_assert_metadata_ds_consistency_dropped_dim(ds, dim): assert_metadata_ds_consistency(ds, metadata) -@given(ds=sgrid_dataset(), dim=st.sampled_from(['face_dimension1', "face_dimension2", "vertical_dimension"])) +@given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) def test_assert_metadata_ds_consistency_failures(ds, dim): metadata: SGrid2DMetadata = ds.sgrid.metadata if dim == "face_dimension1": fnp = metadata.face_dimensions[0] - elif dim=="face_dimension2": + elif dim == "face_dimension2": fnp = metadata.face_dimensions[1] - elif dim=="vertical_dimension": + elif dim == "vertical_dimension": assume(metadata.vertical_dimensions is not None) assert metadata.vertical_dimensions is not None fnp = metadata.vertical_dimensions[0] @@ -85,13 +86,32 @@ def test_assert_metadata_ds_consistency_failures(ds, dim): assert_metadata_ds_consistency(ds, metadata) -@given(selection_axis=st.sampled_from(["X", "Y", "Z"]), ds=sgrid_dataset(), slice_=st.slices(4)) -def test_isel(selection_axis, ds, slice_): - # TODO: Add skip if Z but no Z dimension in ds +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) +@pytest.mark.parametrize("slice_", [slice(None, None, 3), slice(2, -3)]) +@pytest.mark.parametrize( + "node_dim, face_dim", [("node_dimension1", "face_dimension1"), ("node_dimension2", "face_dimension2")] +) +def test_isel(ds, slice_, node_dim, face_dim): + # TODO: Extend to padding BOTH and NONE by updating datasets_sgrid - # select nodes in axis direction - # assert consistent + assert ds.dims[node_dim] == ds.dims[face_dim] - # select edges - # assert consistent - ... 
+ ds_trimmed = ds.sgrid.isel({node_dim: slice_}) + + assert ds_trimmed.dims[node_dim] == ds_trimmed.dims[face_dim] + + # Assert that other dims haven't been affected + for dim, size_before in ds.dims.items(): + if dim in (node_dim, face_dim): + continue + size_after = ds_trimmed.dims[dim] + assert size_before == size_after + + +@pytest.mark.parametrize("ds", [datasets_sgrid["ds_2d_padded_high"]]) +def test_isel_invalid(ds): + with pytest.raises(ValueError, match="Cannot use SGRID accessor to .isel non-spatial \(/SGRID related\) dimension.*"): + ds.sgrid.isel(time=slice(None)) + + with pytest.raises(ValueError, match="Dims .* are on the same axis .* according to SGRID metadata.*"): + ds.sgrid.isel(node_dimension1=slice(None), face_dimension1=slice(None)) From 889c9adbb548301a72bbff48a099370a2be7dd70 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 16:34:59 +0200 Subject: [PATCH 13/21] Add test for single isel Relax requirement for slice objects in isel --- src/parcels/_sgrid/accessor.py | 6 ------ tests/sgrid/test_accessor.py | 22 +++++++++++++++++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index 90485cad2..0ddde1fcf 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -52,12 +52,6 @@ def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): assert indexers is not None - for k, indexer in indexers.items(): - if not isinstance(indexer, slice): - raise NotImplementedError( - f"sgrid.isel() only works on `slice` objects for the timebeing. 
Got indexer {indexer!r} for {k!r}" - ) - metadata = self.metadata _assert_not_indexing_along_same_axis(indexers, metadata) diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index dbabf4f66..29a5ae3ef 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -87,16 +87,17 @@ def test_assert_metadata_ds_consistency_failures(ds, dim): @pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) -@pytest.mark.parametrize("slice_", [slice(None, None, 3), slice(2, -3)]) +@pytest.mark.parametrize("indexer", [slice(None, None, 3), slice(2, -3), [0]]) @pytest.mark.parametrize( "node_dim, face_dim", [("node_dimension1", "face_dimension1"), ("node_dimension2", "face_dimension2")] ) -def test_isel(ds, slice_, node_dim, face_dim): +def test_isel(ds, indexer, node_dim, face_dim): # TODO: Extend to padding BOTH and NONE by updating datasets_sgrid + # TODO: Expand testing on types of indexers assert ds.dims[node_dim] == ds.dims[face_dim] - ds_trimmed = ds.sgrid.isel({node_dim: slice_}) + ds_trimmed = ds.sgrid.isel({node_dim: indexer}) assert ds_trimmed.dims[node_dim] == ds_trimmed.dims[face_dim] @@ -107,6 +108,21 @@ def test_isel(ds, slice_, node_dim, face_dim): size_after = ds_trimmed.dims[dim] assert size_before == size_after +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) +def test_isel_drop_dim(ds): + ds = ds.copy() + assert ds.dims["node_dimension1"] == ds.dims["face_dimension1"] + + ds_trimmed = ds.sgrid.isel({"node_dimension1": 0}) + + assert "node_dimension1" not in ds_trimmed.dims + assert "face_dimension1" not in ds_trimmed.dims + + # Assert that other dims haven't been affected + for dim, size_after in ds_trimmed.dims.items(): + size_before = ds.dims[dim] + assert size_before == size_after + @pytest.mark.parametrize("ds", [datasets_sgrid["ds_2d_padded_high"]]) def test_isel_invalid(ds): From 3fba06b1c1d5fae5c0ed1b76f1ed198ff8fd8edc Mon Sep 
17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 19 May 2026 16:44:49 +0200 Subject: [PATCH 14/21] Fix typechecking --- src/parcels/_python.py | 2 +- src/parcels/_sgrid/accessor.py | 16 +++---- src/parcels/_sgrid/core.py | 80 +++++++++++++++++----------------- tests/sgrid/test_accessor.py | 5 ++- 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/src/parcels/_python.py b/src/parcels/_python.py index 24d1a1b4e..81db6ade4 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -46,7 +46,7 @@ def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) def invert_non_unique_mapping(d: Mapping[K, V]) -> Mapping[V, list[K]]: - inv_map = {} + inv_map: dict[V, list[K]] = {} for k, v in d.items(): inv_map[v] = inv_map.get(v, []) + [k] return inv_map diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index 0ddde1fcf..19d77b8ef 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -1,12 +1,12 @@ import itertools -from collections.abc import Mapping, Sequence -from typing import Any, Literal +from collections.abc import Hashable, Mapping, Sequence +from typing import Any, Literal, cast import xarray as xr from parcels._python import invert_non_unique_mapping -from .core import SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs +from .core import FaceNodePadding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs @xr.register_dataset_accessor("sgrid") @@ -55,9 +55,9 @@ def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): metadata = self.metadata _assert_not_indexing_along_same_axis(indexers, metadata) - _assert_all_isel_along_axis(indexers.keys(), metadata) + _assert_all_isel_along_axis(list(indexers.keys()), metadata) - indexers = _complete_isel_indexing(indexers, metadata, self._ds.dims.keys()) + indexers = _complete_isel_indexing(indexers, metadata, list(self._ds.dims.keys())) ds = 
self._ds.isel(indexers=indexers) assert_metadata_ds_consistency(ds, metadata) @@ -65,7 +65,7 @@ def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): - vertical_dimensions = metadata.vertical_dimensions or tuple() + vertical_dimensions: tuple[FaceNodePadding, ...] = metadata.vertical_dimensions or tuple() for obj in itertools.chain(metadata.face_dimensions, vertical_dimensions): face, node, padding = obj.face, obj.node, obj.padding @@ -112,7 +112,7 @@ def _get_dim_to_axis_mapping(grid: SGrid2DMetadata) -> dict[str, Literal["X", "Y } if fnp_z is not None: d.update({fnp_z.node: "Z", fnp_z.face: "Z"}) - return d + return cast(dict[str, Literal["X", "Y", "Z"]], d) def _assert_not_indexing_along_same_axis(indexers: Mapping[Any, Any], metadata: SGrid2DMetadata) -> None: @@ -141,7 +141,7 @@ def _assert_all_isel_along_axis(index_dims: Sequence[str], metadata: SGrid2DMeta def _complete_isel_indexing( - indexers: Mapping[Any, Any], grid: SGrid2DMetadata, dims_in_dataset: Sequence[str] + indexers: Mapping[Any, Any], grid: SGrid2DMetadata, dims_in_dataset: Sequence[Hashable] ) -> Mapping[Any, Any]: """Copies indexers to the other dataset dimensions defined on the same axis.""" ret = {} diff --git a/src/parcels/_sgrid/core.py b/src/parcels/_sgrid/core.py index 11a853463..05bec18ad 100644 --- a/src/parcels/_sgrid/core.py +++ b/src/parcels/_sgrid/core.py @@ -76,30 +76,6 @@ def to_attrs(self) -> dict[str, str | int]: ... def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... 
-# Note that - for some optional attributes in the SGRID spec - these IDs are not available -# hence this isn't full coverage -_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[SGrid2DMetadata], Dim | Padding]] = { - "node_dimension1": lambda meta: meta.node_dimensions[0], - "node_dimension2": lambda meta: meta.node_dimensions[1], - "face_dimension1": lambda meta: meta.face_dimensions[0].face, - "face_dimension2": lambda meta: meta.face_dimensions[1].face, - "type1": lambda meta: meta.face_dimensions[0].padding, - "type2": lambda meta: meta.face_dimensions[1].padding, -} - -_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[SGrid3DMetadata], Dim | Padding]] = { - "node_dimension1": lambda meta: meta.node_dimensions[0], - "node_dimension2": lambda meta: meta.node_dimensions[1], - "node_dimension3": lambda meta: meta.node_dimensions[2], - "face_dimension1": lambda meta: meta.volume_dimensions[0].face, - "face_dimension2": lambda meta: meta.volume_dimensions[1].face, - "face_dimension3": lambda meta: meta.volume_dimensions[2].face, - "type1": lambda meta: meta.volume_dimensions[0].padding, - "type2": lambda meta: meta.volume_dimensions[1].padding, - "type3": lambda meta: meta.volume_dimensions[2].padding, -} - - class SGrid2DMetadata(_AttrsSerializable): def __init__( self, @@ -179,7 +155,7 @@ def __eq__(self, other: Any) -> bool: return self.to_attrs() == other.to_attrs() @classmethod - def from_attrs(cls, attrs): # type: ignore[override] + def from_attrs(cls, attrs): try: return cls( cf_role=attrs["cf_role"], @@ -193,12 +169,12 @@ def from_attrs(cls, attrs): # type: ignore[override] raise SGridParsingException(f"Failed to parse Grid2DMetadata from {attrs=!r}") from e def to_attrs(self) -> dict[str, str | int]: - d = dict( - cf_role=self.cf_role, - topology_dimension=self.topology_dimension, - node_dimensions=dump_mappings(self.node_dimensions), - face_dimensions=dump_mappings(self.face_dimensions), - ) + d: dict[str, str | int] = { + "cf_role": self.cf_role, + 
"topology_dimension": self.topology_dimension, + "node_dimensions": dump_mappings(self.node_dimensions), + "face_dimensions": dump_mappings(self.face_dimensions), + } if self.node_coordinates is not None: d["node_coordinates"] = dump_mappings(self.node_coordinates) if self.vertical_dimensions is not None: @@ -295,7 +271,7 @@ def __eq__(self, other: Any) -> bool: return self.to_attrs() == other.to_attrs() @classmethod - def from_attrs(cls, attrs): # type: ignore[override] + def from_attrs(cls, attrs): try: return cls( cf_role=attrs["cf_role"], @@ -310,12 +286,12 @@ def from_attrs(cls, attrs): # type: ignore[override] raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e def to_attrs(self) -> dict[str, str | int]: - d = dict( - cf_role=self.cf_role, - topology_dimension=self.topology_dimension, - node_dimensions=dump_mappings(self.node_dimensions), - volume_dimensions=dump_mappings(self.volume_dimensions), - ) + d: dict[str, str | int] = { + "cf_role": self.cf_role, + "topology_dimension": self.topology_dimension, + "node_dimensions": dump_mappings(self.node_dimensions), + "volume_dimensions": dump_mappings(self.volume_dimensions), + } if self.node_coordinates is not None: d["node_coordinates"] = dump_mappings(self.node_coordinates) return d @@ -337,6 +313,30 @@ def get_value_by_id(self, id: str) -> Dim | Padding: return _ID_FETCHERS_GRID3DMETADATA[id](self) +# Note that - for some optional attributes in the SGRID spec - these IDs are not available +# hence this isn't full coverage +_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[SGrid2DMetadata], Dim | Padding]] = { + "node_dimension1": lambda meta: meta.node_dimensions[0], + "node_dimension2": lambda meta: meta.node_dimensions[1], + "face_dimension1": lambda meta: meta.face_dimensions[0].face, + "face_dimension2": lambda meta: meta.face_dimensions[1].face, + "type1": lambda meta: meta.face_dimensions[0].padding, + "type2": lambda meta: meta.face_dimensions[1].padding, +} + 
+_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[SGrid3DMetadata], Dim | Padding]] = { + "node_dimension1": lambda meta: meta.node_dimensions[0], + "node_dimension2": lambda meta: meta.node_dimensions[1], + "node_dimension3": lambda meta: meta.node_dimensions[2], + "face_dimension1": lambda meta: meta.volume_dimensions[0].face, + "face_dimension2": lambda meta: meta.volume_dimensions[1].face, + "face_dimension3": lambda meta: meta.volume_dimensions[2].face, + "type1": lambda meta: meta.volume_dimensions[0].padding, + "type2": lambda meta: meta.volume_dimensions[1].padding, + "type3": lambda meta: meta.volume_dimensions[2].padding, +} + + @dataclass class FaceNodePadding: """A data class representing a face-node-padding triplet for SGrid metadata. @@ -412,8 +412,8 @@ def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: s_new = s[match.end() :].lstrip() else: # no FaceNodePadding match at start, assume just a Dim until next space - part, *s_new = s.split(" ", 1) - s_new = "".join(s_new) + part, *rest = s.split(" ", 1) + s_new = "".join(rest) assert s != s_new, f"SGrid parsing did not advance, stuck at {s!r}" diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 29a5ae3ef..7b49add31 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -108,6 +108,7 @@ def test_isel(ds, indexer, node_dim, face_dim): size_after = ds_trimmed.dims[dim] assert size_before == size_after + @pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) def test_isel_drop_dim(ds): ds = ds.copy() @@ -126,7 +127,9 @@ def test_isel_drop_dim(ds): @pytest.mark.parametrize("ds", [datasets_sgrid["ds_2d_padded_high"]]) def test_isel_invalid(ds): - with pytest.raises(ValueError, match="Cannot use SGRID accessor to .isel non-spatial \(/SGRID related\) dimension.*"): + with pytest.raises( + ValueError, match=r"Cannot use SGRID accessor to .isel non-spatial \(/SGRID related\) dimension.*" + ): 
ds.sgrid.isel(time=slice(None)) with pytest.raises(ValueError, match="Dims .* are on the same axis .* according to SGRID metadata.*"): From c702f83d1679fba426783fd1e79a8d49be28982c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 20 May 2026 10:25:44 +0200 Subject: [PATCH 15/21] Move rename test --- tests/sgrid/test_accessor.py | 60 ++++++++++++++++++++++++++++++++---- tests/sgrid/test_sgrid.py | 47 ---------------------------- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 7b49add31..3ad6c1c9a 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -1,22 +1,23 @@ import hypothesis.strategies as st +import numpy as np import pytest import xarray as xr from hypothesis import assume, given +import parcels._sgrid as sgrid import parcels._strategies as pst from parcels._datasets.structured.generic import datasets_sgrid from parcels._datasets.structured.strategies import sgrid_dataset -from parcels._sgrid import SGrid2DMetadata from parcels._sgrid.accessor import SGridDatasetInconsistency, assert_metadata_ds_consistency @st.composite -def grid_and_dataset(draw) -> tuple[SGrid2DMetadata, xr.Dataset]: +def grid_and_dataset(draw) -> tuple[sgrid.SGrid2DMetadata, xr.Dataset]: # used only for test_metadata - for all other tests we can simply do `ds.sgrid.metadata` to get the metadata metadata_2d = draw( pst.sgrid.grid_metadata.filter( # parcels can only generate 2D Sgrid datasets, that also have coordinates - lambda meta: isinstance(meta, SGrid2DMetadata) and meta.node_coordinates is not None + lambda meta: isinstance(meta, sgrid.SGrid2DMetadata) and meta.node_coordinates is not None ) ) ds = draw(sgrid_dataset(metadata_2d)) @@ -31,16 +32,63 @@ def test_metadata(metadata_ds): assert parsed_metadata == metadata +@pytest.mark.parametrize( + "ds", + [ + xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], 
np.random.rand(10, 10, 10, 10)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("XG", "YG"), + face_dimensions=( + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), + ), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)), + "lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)), + "depth": (["ZG"], np.arange(10)), + "time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}), + }, + ), + ], +) +def test_rename_dataset(ds): + # Check renaming works for coordinates + ds_new = ds.sgrid.rename({"lon": "lon_updated"}) + grid_new = ds_new.sgrid.metadata + assert "lon_updated" in ds_new.coords + assert "lon_updated" == grid_new.node_coordinates[0] + + # Check renaming works for dim + ds_new = ds.sgrid.rename({"XC": "XC_updated"}) + grid_new = ds_new.sgrid.metadata + assert "XC_updated" in ds_new.dims + assert "XC" not in ds_new.dims + assert "XC_updated" == grid_new.face_dimensions[0].face + + @given(sgrid_dataset()) def test_assert_metadata_ds_consistency(ds): - metadata: SGrid2DMetadata = ds.sgrid.metadata + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata assert_metadata_ds_consistency(ds, metadata) @given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) def test_assert_metadata_ds_consistency_dropped_dim(ds, dim): # dropping one of the SGRID dimensions still results in a consistent dataset - metadata: SGrid2DMetadata = ds.sgrid.metadata + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata if dim == "face_dimension1": fnp = metadata.face_dimensions[0] @@ -61,7 +109,7 @@ def test_assert_metadata_ds_consistency_dropped_dim(ds, dim): 
@given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) def test_assert_metadata_ds_consistency_failures(ds, dim): - metadata: SGrid2DMetadata = ds.sgrid.metadata + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata if dim == "face_dimension1": fnp = metadata.face_dimensions[0] diff --git a/tests/sgrid/test_sgrid.py b/tests/sgrid/test_sgrid.py index 173613336..af9fb22eb 100644 --- a/tests/sgrid/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -324,53 +324,6 @@ def test_rename_errors(): grid.rename(names_dict) -@pytest.mark.parametrize( - "ds", - [ - xr.Dataset( - { - "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(10, 10, 10, 10)), - "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)), - "grid": ( - [], - np.array(0), - sgrid.SGrid2DMetadata( - cf_role="grid_topology", - topology_dimension=2, - node_dimensions=("XG", "YG"), - face_dimensions=( - sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), - sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), - ), - vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), - node_coordinates=("lon", "lat"), - ).to_attrs(), - ), - }, - coords={ - "lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)), - "lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)), - "depth": (["ZG"], np.arange(10)), - "time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}), - }, - ), - ], -) -def test_rename_dataset(ds): - # Check renaming works for coordinates - ds_new = ds.sgrid.rename({"lon": "lon_updated"}) - grid_new = ds_new.sgrid.metadata - assert "lon_updated" in ds_new.coords - assert "lon_updated" == grid_new.node_coordinates[0] - - # Check renaming works for dim - ds_new = ds.sgrid.rename({"XC": "XC_updated"}) - grid_new = ds_new.sgrid.metadata - assert "XC_updated" in ds_new.dims - assert "XC" not in ds_new.dims - assert "XC_updated" == grid_new.face_dimensions[0].face - - @pytest.mark.parametrize( ("metadata, expected"), [ 
From f4c9e91b63657dc1c859dcf6f60072463acddc5c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 20 May 2026 10:40:16 +0200 Subject: [PATCH 16/21] xfail test_parse_sgrid_3d --- tests/sgrid/test_sgrid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/sgrid/test_sgrid.py b/tests/sgrid/test_sgrid.py index af9fb22eb..54dd56ea3 100644 --- a/tests/sgrid/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -274,6 +274,7 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): @given(pst.sgrid.grid3Dmetadata()) +@pytest.mark.xfail(reason="Parcels doesn't have native support for SGRID 3D grids. This metadata checking is superfluous until we have such support.") def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_3d_ds(grid_metadata) From a3a3fb34774d30e3c823556c5d45bd0698cc2544 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 20 May 2026 11:45:11 +0200 Subject: [PATCH 17/21] Extend isel to work with padding NONE and BOTH datasets Refactor to avoid calculating the index verbatim. Define NONE and BOTH isel indexing only for contiguous regions. 
Add more hypothesis testing properties: - P1: consistency invariant (assert_metadata_ds_consistency after any valid isel) - P2: data correctness (values match direct xarray isel) - P3: specification symmetry (isel by node dim == isel by derived face dim) Co-authored-by: Claude --- src/parcels/_datasets/structured/generic.py | 48 ++++++++ src/parcels/_sgrid/accessor.py | 109 ++++++++++++++++-- tests/sgrid/test_accessor.py | 117 ++++++++++++++++++-- tests/sgrid/test_sgrid.py | 4 +- 4 files changed, 255 insertions(+), 23 deletions(-) diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index 98811e293..ff729c8a1 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -277,4 +277,52 @@ def _unrolled_cone_curvilinear_grid(): _COMODO_TO_2D_SGRID, ) ), + "ds_2d_padded_none": xr.Dataset( + { + "data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)), + "data_c": (["face_dimension1", "face_dimension2"], np.random.rand(9, 9)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.NONE), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.NONE), + ), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["node_dimension1"], np.linspace(0, 1, 10)), + "lat": (["node_dimension2"], np.linspace(0, 1, 10)), + }, + ), + "ds_2d_padded_both": xr.Dataset( + { + "data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)), + "data_c": (["face_dimension1", "face_dimension2"], np.random.rand(11, 11)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + 
sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.BOTH), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.BOTH), + ), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["node_dimension1"], np.linspace(0, 1, 10)), + "lat": (["node_dimension2"], np.linspace(0, 1, 10)), + }, + ), } diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index 19d77b8ef..4fdbd8b68 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -6,7 +6,7 @@ from parcels._python import invert_non_unique_mapping -from .core import FaceNodePadding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs +from .core import FaceNodePadding, Padding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs @xr.register_dataset_accessor("sgrid") @@ -50,14 +50,16 @@ def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): raise ValueError("Cannot provide both positional and keyword argument to .isel .") indexers = indexers_kwargs - assert indexers is not None + if indexers is None: + raise ValueError("Must provide indexers either as a positional argument or as keyword arguments.") metadata = self.metadata _assert_not_indexing_along_same_axis(indexers, metadata) _assert_all_isel_along_axis(list(indexers.keys()), metadata) - indexers = _complete_isel_indexing(indexers, metadata, list(self._ds.dims.keys())) + dim_sizes = dict(self._ds.sizes) + indexers = _complete_isel_indexing(indexers, metadata, list(self._ds.dims.keys()), dim_sizes) ds = self._ds.isel(indexers=indexers) assert_metadata_ds_consistency(ds, metadata) @@ -115,6 +117,72 @@ def _get_dim_to_axis_mapping(grid: SGrid2DMetadata) -> dict[str, Literal["X", "Y return cast(dict[str, Literal["X", "Y", "Z"]], d) +def _get_axis_info(grid: SGrid2DMetadata) -> dict[str, tuple[FaceNodePadding, bool]]: + """For each spatial dim, return (FaceNodePadding it belongs to, True if node dim).""" + 
result: dict[str, tuple[FaceNodePadding, bool]] = {} + all_fnps = list(grid.face_dimensions) + list(grid.vertical_dimensions or []) + for fnp in all_fnps: + result[fnp.node] = (fnp, True) + result[fnp.face] = (fnp, False) + return result + + +def _derive_paired_indexer( + indexer: Any, + indexer_is_node: bool, + padding: Padding, + dim_size: int | None = None, +) -> tuple[Any, Any]: + """Given a user's indexer for one side of a face-node pair, return ``(normalized_user_indexer, paired_indexer)``. + + For HIGH/LOW padding, face and node dims are the same size so both the normalised user + indexer and the paired indexer are identical to the original ``user_indexer``. + For NONE/BOTH padding, the slice is first normalised to non-negative absolute indices (via + ``slice.indices``) and then the stop of the paired indexer is adjusted by ±1. + + ``n_user_dim`` is required for NONE/BOTH padding so that negative starts and ``None`` stops + can be resolved to unambiguous absolute positions. + + Scalar (integer) and list indexers raise for NONE/BOTH because there is no unambiguous + paired position. + + Returns + ------- + tuple[Any, Any] + ``(normalized_user_indexer, paired_indexer)`` — the first element is the user's indexer + after normalisation (unchanged for HIGH/LOW), the second is the derived indexer for the + other side of the face-node pair. + """ + if padding in (Padding.HIGH, Padding.LOW): + return indexer, indexer + + # NONE and BOTH: only slices with step in {None, 1} are supported + if not isinstance(indexer, slice): + raise ValueError( + f"Scalar and list indexers are not supported for NONE/BOTH padding. " + f"Got indexer {indexer!r}. Use a slice instead." + ) + if indexer.step not in (None, 1): + raise ValueError(f"Slices with step != 1 are not supported for NONE/BOTH padding. 
Got step={indexer.step!r}.") + if dim_size is None: + raise ValueError("dim_size must be provided for NONE/BOTH padding to correctly handle slices.") + + # Normalise to non-negative absolute indices so the arithmetic below is unambiguous. + abs_start, abs_stop, _ = indexer.indices(dim_size) + normalized_user_indexer = slice(abs_start, abs_stop) + + start, stop = abs_start, abs_stop + + # Adjust stop: positive stops reference from the start of the array, so ±1 is needed. + if stop is not None and stop > 0: + if padding == Padding.NONE: + stop = stop - 1 if indexer_is_node else stop + 1 + else: # BOTH + stop = stop + 1 if indexer_is_node else stop - 1 + + return normalized_user_indexer, slice(start, stop) + + def _assert_not_indexing_along_same_axis(indexers: Mapping[Any, Any], metadata: SGrid2DMetadata) -> None: dim_to_axis = _get_dim_to_axis_mapping(metadata) indexer_dim_to_axis = {dim: dim_to_axis.get(dim) for dim in indexers} @@ -141,14 +209,31 @@ def _assert_all_isel_along_axis(index_dims: Sequence[str], metadata: SGrid2DMeta def _complete_isel_indexing( - indexers: Mapping[Any, Any], grid: SGrid2DMetadata, dims_in_dataset: Sequence[Hashable] + indexers: Mapping[Any, Any], + grid: SGrid2DMetadata, + dims_in_dataset: Sequence[Hashable], + dim_sizes: Mapping[str, int], ) -> Mapping[Any, Any]: - """Copies indexers to the other dataset dimensions defined on the same axis.""" - ret = {} - dim_to_axis = _get_dim_to_axis_mapping(grid) - indexers_by_axis = {dim_to_axis[dim]: indexer for dim, indexer in indexers.items()} - - for dim, axis in dim_to_axis.items(): - if dim in dims_in_dataset and axis in indexers_by_axis: - ret[dim] = indexers_by_axis[axis] + """For each user-supplied (dim, indexer), expand to both the face and node dim on that axis, + deriving the paired indexer according to the padding type. 
+ """ + axis_info = _get_axis_info(grid) + ret: dict[Any, Any] = {} + + for user_dim, user_indexer in indexers.items(): + fnp, user_is_node = axis_info[user_dim] + + n_user_dim = dim_sizes.get(user_dim) + normalized_user, paired_indexer = _derive_paired_indexer( + user_indexer, user_is_node, fnp.padding, dim_size=n_user_dim + ) + + node_indexer = normalized_user if user_is_node else paired_indexer + face_indexer = paired_indexer if user_is_node else normalized_user + + if fnp.node in dims_in_dataset: + ret[fnp.node] = node_indexer + if fnp.face in dims_in_dataset: + ret[fnp.face] = face_indexer + return ret diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py index 3ad6c1c9a..8a9f9b0ce 100644 --- a/tests/sgrid/test_accessor.py +++ b/tests/sgrid/test_accessor.py @@ -134,20 +134,21 @@ def test_assert_metadata_ds_consistency_failures(ds, dim): assert_metadata_ds_consistency(ds, metadata) -@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) -@pytest.mark.parametrize("indexer", [slice(None, None, 3), slice(2, -3), [0]]) +_SYMMETRIC_DATASETS = {k: v for k, v in datasets_sgrid.items() if k in ("ds_2d_padded_high", "ds_2d_padded_low")} + + +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in _SYMMETRIC_DATASETS.items()]) +@pytest.mark.parametrize("indexer", [slice(None, None, 3), [0]]) @pytest.mark.parametrize( "node_dim, face_dim", [("node_dimension1", "face_dimension1"), ("node_dimension2", "face_dimension2")] ) def test_isel(ds, indexer, node_dim, face_dim): - # TODO: Extend to padding BOTH and NONE by updating datasets_sgrid - # TODO: Expand testing on types of indexers - - assert ds.dims[node_dim] == ds.dims[face_dim] + # Covers HIGH/LOW-specific indexer types (lists, step slices) not exercised by property tests. 
+ metadata = ds.sgrid.metadata ds_trimmed = ds.sgrid.isel({node_dim: indexer}) - assert ds_trimmed.dims[node_dim] == ds_trimmed.dims[face_dim] + assert_metadata_ds_consistency(ds_trimmed, metadata) # Assert that other dims haven't been affected for dim, size_before in ds.dims.items(): @@ -157,17 +158,15 @@ def test_isel(ds, indexer, node_dim, face_dim): assert size_before == size_after -@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in datasets_sgrid.items()]) +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in _SYMMETRIC_DATASETS.items()]) def test_isel_drop_dim(ds): ds = ds.copy() - assert ds.dims["node_dimension1"] == ds.dims["face_dimension1"] ds_trimmed = ds.sgrid.isel({"node_dimension1": 0}) assert "node_dimension1" not in ds_trimmed.dims assert "face_dimension1" not in ds_trimmed.dims - # Assert that other dims haven't been affected for dim, size_after in ds_trimmed.dims.items(): size_before = ds.dims[dim] assert size_before == size_after @@ -182,3 +181,101 @@ def test_isel_invalid(ds): with pytest.raises(ValueError, match="Dims .* are on the same axis .* according to SGRID metadata.*"): ds.sgrid.isel(node_dimension1=slice(None), face_dimension1=slice(None)) + + +@st.composite +def valid_isel_slice(draw) -> slice: + """Slice with no step; covers None, positive, and negative start/stop.""" + start = draw(st.one_of(st.none(), st.integers(-25, 25))) + stop = draw(st.one_of(st.none(), st.integers(-25, 25))) + return slice(start, stop) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p1_consistency_invariant(ds, s): + """After any valid isel, the SGRID face-node relationship is preserved.""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim = fnp.node + assume(node_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) # exclude empty selections + + result = ds.sgrid.isel({node_dim: s}) + + assert_metadata_ds_consistency(result, 
metadata) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p2_data_correctness(ds, s): + """Values in the result match those obtained by applying the indexers directly via xarray.""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim, face_dim = fnp.node, fnp.face + assume(node_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) + + result = ds.sgrid.isel({node_dim: s}) + + from parcels._sgrid.accessor import _derive_paired_indexer + + _, face_indexer = _derive_paired_indexer(s, indexer_is_node=True, padding=fnp.padding, dim_size=n_nodes) + + for var in ds.data_vars: + if var == "grid": + continue + var_dims = ds[var].dims + if node_dim in var_dims and face_dim not in var_dims: + xr.testing.assert_equal(result[var], ds[var].isel({node_dim: s})) + elif face_dim in var_dims and node_dim not in var_dims: + if face_dim in ds.dims: + xr.testing.assert_equal(result[var], ds[var].isel({face_dim: face_indexer})) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p3_specification_symmetry(ds, s): + """isel(node_dim=s) produces the same dataset as isel(face_dim=derived_face_slice).""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim, face_dim = fnp.node, fnp.face + assume(node_dim in ds.dims and face_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) + + from parcels._sgrid.accessor import _derive_paired_indexer + + _, face_indexer = _derive_paired_indexer(s, indexer_is_node=True, padding=fnp.padding, dim_size=n_nodes) + + n_faces = ds.dims.get(face_dim) + if n_faces is not None: + assume(len(range(*face_indexer.indices(n_faces))) > 0) + + result_via_node = ds.sgrid.isel({node_dim: s}) + result_via_face = ds.sgrid.isel({face_dim: face_indexer}) + + xr.testing.assert_identical(result_via_node, result_via_face) + + +@pytest.mark.parametrize("ds_name", ["ds_2d_padded_none", 
"ds_2d_padded_both"]) +@pytest.mark.parametrize( + "indexer, match", + [ + (5, "Scalar and list indexers are not supported for NONE/BOTH padding"), + ([0, 1, 2], "Scalar and list indexers are not supported for NONE/BOTH padding"), + (slice(0, 8, 2), "Slices with step != 1 are not supported for NONE/BOTH padding"), + ], +) +def test_isel_asymmetric_padding_invalid_indexers(ds_name, indexer, match): + ds = datasets_sgrid[ds_name] + with pytest.raises(ValueError, match=match): + ds.sgrid.isel({"node_dimension1": indexer}) diff --git a/tests/sgrid/test_sgrid.py b/tests/sgrid/test_sgrid.py index 54dd56ea3..5582821d0 100644 --- a/tests/sgrid/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -274,7 +274,9 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): @given(pst.sgrid.grid3Dmetadata()) -@pytest.mark.xfail(reason="Parcels doesn't have native support for SGRID 3D grids. This metadata checking is superfluous until we have such support.") +@pytest.mark.xfail( + reason="Parcels doesn't have native support for SGRID 3D grids. This metadata checking is superfluous until we have such support." 
+) def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_3d_ds(grid_metadata) From 824348695e5b38b082e6ec71f55e8d401a9a3fe1 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 20 May 2026 13:09:00 +0200 Subject: [PATCH 18/21] Use dataset as argument --- src/parcels/_sgrid/accessor.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index 4fdbd8b68..a0c049736 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -1,5 +1,5 @@ import itertools -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Mapping, Sequence from typing import Any, Literal, cast import xarray as xr @@ -58,8 +58,7 @@ def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): _assert_not_indexing_along_same_axis(indexers, metadata) _assert_all_isel_along_axis(list(indexers.keys()), metadata) - dim_sizes = dict(self._ds.sizes) - indexers = _complete_isel_indexing(indexers, metadata, list(self._ds.dims.keys()), dim_sizes) + indexers = _complete_isel_indexing(self._ds, indexers, metadata) ds = self._ds.isel(indexers=indexers) assert_metadata_ds_consistency(ds, metadata) @@ -209,10 +208,9 @@ def _assert_all_isel_along_axis(index_dims: Sequence[str], metadata: SGrid2DMeta def _complete_isel_indexing( + ds: xr.Dataset, indexers: Mapping[Any, Any], grid: SGrid2DMetadata, - dims_in_dataset: Sequence[Hashable], - dim_sizes: Mapping[str, int], ) -> Mapping[Any, Any]: """For each user-supplied (dim, indexer), expand to both the face and node dim on that axis, deriving the paired indexer according to the padding type. 
@@ -223,7 +221,7 @@ def _complete_isel_indexing( for user_dim, user_indexer in indexers.items(): fnp, user_is_node = axis_info[user_dim] - n_user_dim = dim_sizes.get(user_dim) + n_user_dim = ds.sizes.get(user_dim) normalized_user, paired_indexer = _derive_paired_indexer( user_indexer, user_is_node, fnp.padding, dim_size=n_user_dim ) @@ -231,9 +229,9 @@ def _complete_isel_indexing( node_indexer = normalized_user if user_is_node else paired_indexer face_indexer = paired_indexer if user_is_node else normalized_user - if fnp.node in dims_in_dataset: + if fnp.node in ds.dims: ret[fnp.node] = node_indexer - if fnp.face in dims_in_dataset: + if fnp.face in ds.dims: ret[fnp.face] = face_indexer return ret From e3d8ce392c30f2b9cab6ce0ce82d319a9ba4a421 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 20 May 2026 13:24:23 +0200 Subject: [PATCH 19/21] Self review --- src/parcels/_sgrid/accessor.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index a0c049736..6d61f3f76 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -6,7 +6,7 @@ from parcels._python import invert_non_unique_mapping -from .core import FaceNodePadding, Padding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, parse_grid_attrs +from .core import FaceNodePadding, Padding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, get_n_nodes, parse_grid_attrs @xr.register_dataset_accessor("sgrid") @@ -73,18 +73,14 @@ def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): try: n_nodes = ds.dims[node] - except KeyError: - # node dimension is not in this dataset + except KeyError: # node dimension is not in this dataset continue - try: n_faces = ds.dims[face] - except KeyError: - # face dimension is not in this dataset + except KeyError: # face dimension is not in this dataset continue expected_n_faces = get_n_faces(n_nodes, 
padding) - if expected_n_faces != n_faces: raise SGridDatasetInconsistency( f"Node dimension {node!r} has size {n_nodes}, and face dimension {face!r} has size of {n_faces}. " @@ -174,10 +170,7 @@ def _derive_paired_indexer( # Adjust stop: positive stops reference from the start of the array, so ±1 is needed. if stop is not None and stop > 0: - if padding == Padding.NONE: - stop = stop - 1 if indexer_is_node else stop + 1 - else: # BOTH - stop = stop + 1 if indexer_is_node else stop - 1 + stop = get_n_faces(stop, padding=padding) if indexer_is_node else get_n_nodes(stop, padding=padding) return normalized_user_indexer, slice(start, stop) From 8e8910a4726c42decbdca00554991cf055d9775e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 May 2026 12:07:40 +0000 Subject: [PATCH 20/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci From dd66b21258277e7732b5f81dc2d27225e1b33a60 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 21 May 2026 11:24:17 +0200 Subject: [PATCH 21/21] Update docstrings --- src/parcels/_sgrid/accessor.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py index 6d61f3f76..857a92444 100644 --- a/src/parcels/_sgrid/accessor.py +++ b/src/parcels/_sgrid/accessor.py @@ -23,7 +23,7 @@ def metadata(self) -> SGrid2DMetadata: return grid def rename(self, name_dict: dict[str, str]) -> xr.Dataset: - """Similar to Xarray's rename functionality - but also updates the SGRID metadata attributes.""" + """Similar to Xarray's rename functionality - but also updates attached SGRID metadata.""" ds = self._ds.copy() ds = ds.rename(name_dict) @@ -44,7 +44,33 @@ def _get_grid_topology(self) -> xr.DataArray: return grid_da def isel(self, indexers: Mapping[str, Any] | None = None, 
**indexers_kwargs): - """TODO: Docstring""" + """Index the dataset along SGRID spatial dimensions, keeping face and node dimensions consistent. + + For a provided index, this function derives the paired index from SGRID metadata, applies the indexes, and asserts the resulting dataset still complies with the SGRID metadata. + + Behaviour: + + - Only spatial (SGRID-registered) dimensions may be indexed. + - Simultaneously indexing along two dimensions that belong to the same axis is not allowed. + - For NONE/BOTH padding, only contiguous regions are supported (i.e., ``slice`` + indexers with ``step`` of ``None`` or ``1``). Indexing discontiguous regions is not well defined. + + + Parameters + ---------- + indexers : Mapping[str, Any], optional + A mapping of dimension name to indexer, e.g. ``{"node_dimension1": slice(0, 10)}``. + Mutually exclusive with ``**indexers_kwargs``. + **indexers_kwargs + Dimension-to-indexer pairs as keyword arguments. Mutually exclusive with + ``indexers``. + + Returns + ------- + xr.Dataset + A new dataset indexed along the requested dimensions with all paired face/node + dimensions adjusted accordingly. + """ if indexers_kwargs != {}: if indexers is not None: raise ValueError("Cannot provide both positional and keyword argument to .isel .")