diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index a44472e00..f83df364f 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -10,8 +10,8 @@ import xarray as xr import xgcm +import parcels._sgrid as sgrid from parcels._core.field import Field, VectorField -from parcels._core.utils import sgrid from parcels._core.utils.string import _assert_str_and_python_varname from parcels._core.utils.time import get_datetime_type_calendar from parcels._core.utils.time import is_compatible as datetime_is_compatible @@ -280,7 +280,7 @@ def from_sgrid_conventions( ds = ds.rename({time_dim: "time"}) # Parse SGRID metadata and get xgcm kwargs - _, xgcm_kwargs = sgrid.parse_sgrid(ds) + _, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) # Add time axis to xgcm_kwargs if present if "time" in ds.dims: @@ -475,11 +475,7 @@ def _select_uxinterpolator(da: ux.UxDataArray): # TODO: Refactor later into something like `parcels._metadata.discover(dataset)` helper that can be used to discover important metadata like this. I think this whole metadata handling should be refactored into its own module. def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh: """Small helper to inspect SGRID metadata and dataset metadata to determine mesh type.""" - grid_da = sgrid.get_grid_topology(ds_sgrid) - if grid_da is None: - raise ValueError("Dataset does not contain SGRID grid topology metadata (cf_role='grid_topology').") - - sgrid_metadata = sgrid.parse_grid_attrs(grid_da.attrs) + sgrid_metadata = ds_sgrid.sgrid.metadata fpoint_x, fpoint_y = sgrid_metadata.node_coordinates diff --git a/src/parcels/_datasets/structured/generated.py b/src/parcels/_datasets/structured/generated.py index c0477828b..665bcaf0f 100644 --- a/src/parcels/_datasets/structured/generated.py +++ b/src/parcels/_datasets/structured/generated.py @@ -3,12 +3,7 @@ import numpy as np import xarray as xr -from parcels._core.utils.sgrid import ( - FaceNodePadding, - Padding, - SGrid2DMetadata, - _attach_sgrid_metadata, -) +import parcels._sgrid as sgrid from parcels._core.utils.time import timedelta_to_float @@ -29,16 +24,17 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"): "lon": (["XG"], np.linspace(-max_lon, max_lon, dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}), }, ).pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), + node_coordinates=("lon", "lat"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.LOW), - FaceNodePadding("YC", "YG", Padding.LOW), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW), ), - vertical_dimensions=(FaceNodePadding("ZC", "depth", Padding.BOTH),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "depth", sgrid.Padding.BOTH),), ), ) diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index c26d8e9ce..ff729c8a1 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -1,15 +1,7 @@ import numpy as np import xarray as xr -from parcels._core.utils.sgrid import ( - FaceNodePadding, - Padding, - SGrid2DMetadata, - _attach_sgrid_metadata, -) -from parcels._core.utils.sgrid import ( - rename as sgrid_rename, -) +import parcels._sgrid as sgrid from . import T, X, Y, Z @@ -248,43 +240,89 @@ def _unrolled_cone_curvilinear_grid(): "ds_2d_padded_high": ( datasets["ds_2d_left"] .pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.HIGH), - FaceNodePadding("YC", "YG", Padding.HIGH), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), ), ) - .pipe( - sgrid_rename, + .sgrid.rename( _COMODO_TO_2D_SGRID, ) ), "ds_2d_padded_low": ( datasets["ds_2d_right"] .pipe( - _attach_sgrid_metadata, - SGrid2DMetadata( + sgrid._attach_sgrid_metadata, + sgrid.SGrid2DMetadata( cf_role="grid_topology", topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - FaceNodePadding("XC", "XG", Padding.LOW), - FaceNodePadding("YC", "YG", Padding.LOW), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.LOW),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.LOW),), ), ) - .pipe( - sgrid_rename, + .sgrid.rename( _COMODO_TO_2D_SGRID, ) ), + "ds_2d_padded_none": xr.Dataset( + { + "data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)), + "data_c": (["face_dimension1", "face_dimension2"], np.random.rand(9, 9)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.NONE), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.NONE), + ), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["node_dimension1"], np.linspace(0, 1, 10)), + "lat": (["node_dimension2"], np.linspace(0, 1, 10)), + }, + ), + "ds_2d_padded_both": xr.Dataset( + { + "data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)), + "data_c": (["face_dimension1", "face_dimension2"], np.random.rand(11, 11)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("node_dimension1", "node_dimension2"), + face_dimensions=( + sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.BOTH), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.BOTH), + ), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["node_dimension1"], np.linspace(0, 1, 10)), + "lat": (["node_dimension2"], np.linspace(0, 1, 10)), + }, + ), } diff --git a/src/parcels/_datasets/structured/strategies.py b/src/parcels/_datasets/structured/strategies.py index c66dd0916..c317e3985 100644 --- a/src/parcels/_datasets/structured/strategies.py +++ b/src/parcels/_datasets/structured/strategies.py @@ -3,18 +3,8 @@ from hypothesis import strategies as st from hypothesis.extra.numpy import arrays as np_arrays +import parcels._sgrid as sgrid import parcels._strategies as pst -from parcels._core.utils import sgrid -from parcels._core.utils.sgrid import _attach_sgrid_metadata - - -def _face_size(node_size: int, padding: sgrid.Padding) -> int: - if padding == sgrid.Padding.NONE: - return node_size - 1 - elif padding in (sgrid.Padding.LOW, sgrid.Padding.HIGH): - return node_size - else: # Padding.BOTH - return node_size + 1 @st.composite @@ -33,14 +23,14 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset node_dim1, node_dim2 = grid.node_dimensions face_dim1 = grid.face_dimensions[0].face face_dim2 = grid.face_dimensions[1].face - N_face = _face_size(N, grid.face_dimensions[0].padding) - M_face = _face_size(M, grid.face_dimensions[1].padding) + N_face = sgrid.get_n_faces(N, grid.face_dimensions[0].padding) + M_face = sgrid.get_n_faces(M, grid.face_dimensions[1].padding) if has_vertical := grid.vertical_dimensions is not None: P = draw(st.integers(min_value=5, max_value=20)) vert_node_dim = grid.vertical_dimensions[0].node vert_face_dim = grid.vertical_dimensions[0].face - P_face = _face_size(P, grid.vertical_dimensions[0].padding) + P_face = sgrid.get_n_faces(P, grid.vertical_dimensions[0].padding) has_curvilinear_grid = draw(st.booleans()) coord_name1, coord_name2 = grid.node_coordinates @@ -90,4 +80,4 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset } ds = xr.Dataset(data_vars=data_vars, coords=coords) - return _attach_sgrid_metadata(ds, grid) + return sgrid._attach_sgrid_metadata(ds, grid) diff --git a/src/parcels/_python.py b/src/parcels/_python.py index b7f8b2845..81db6ade4 100644 --- a/src/parcels/_python.py +++ b/src/parcels/_python.py @@ -1,6 +1,10 @@ # Generic Python helpers import inspect -from collections.abc import Callable +from collections.abc import Callable, Mapping +from typing import TypeVar + +K = TypeVar("K") +V = TypeVar("V") def isinstance_noimport(obj, class_or_tuple): @@ -39,3 +43,10 @@ def assert_same_function_signature(f: Callable, *, ref: Callable, context: str) raise ValueError( f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'" ) + + +def invert_non_unique_mapping(d: Mapping[K, V]) -> Mapping[V, list[K]]: + inv_map: dict[V, list[K]] = {} + for k, v in d.items(): + inv_map[v] = inv_map.get(v, []) + [k] + return inv_map diff --git a/src/parcels/_sgrid/__init__.py b/src/parcels/_sgrid/__init__.py new file mode 100644 index 000000000..f06e03830 --- /dev/null +++ b/src/parcels/_sgrid/__init__.py @@ -0,0 +1,27 @@ +from .accessor import SgridAccessor +from .core import ( + FaceNodePadding, + Padding, + SGrid2DMetadata, + SGrid3DMetadata, + _attach_sgrid_metadata, + dump_mappings, + get_n_faces, + get_n_nodes, + load_mappings, + xgcm_parse_sgrid, +) + +__all__ = [ + "FaceNodePadding", + "Padding", + "SGrid2DMetadata", + "SGrid3DMetadata", + "SgridAccessor", + "_attach_sgrid_metadata", + "dump_mappings", + "get_n_faces", + "get_n_nodes", + "load_mappings", + "xgcm_parse_sgrid", +] diff --git a/src/parcels/_sgrid/accessor.py b/src/parcels/_sgrid/accessor.py new file mode 100644 index 000000000..857a92444 --- /dev/null +++ b/src/parcels/_sgrid/accessor.py @@ -0,0 +1,256 @@ +import itertools +from collections.abc import Mapping, Sequence +from typing import Any, Literal, cast + +import xarray as xr + +from parcels._python import invert_non_unique_mapping + +from .core import FaceNodePadding, Padding, SGrid2DMetadata, SGrid3DMetadata, get_n_faces, get_n_nodes, parse_grid_attrs + + +@xr.register_dataset_accessor("sgrid") +class SgridAccessor: + def __init__(self, xarray_obj): + self._ds: xr.Dataset = xarray_obj + + @property + def metadata(self) -> SGrid2DMetadata: + grid_da = self._get_grid_topology() + grid = parse_grid_attrs(grid_da.attrs) + if isinstance(grid, SGrid3DMetadata): + raise NotImplementedError("Support for 3D SGRID metadata not supported.") + return grid + + def rename(self, name_dict: dict[str, str]) -> xr.Dataset: + """Similar to Xarray's rename functionality - but also updates attached SGRID metadata.""" + ds = self._ds.copy() + ds = ds.rename(name_dict) + + grid_da_name = self._get_grid_topology().name + ds[grid_da_name].attrs = self.metadata.rename(name_dict).to_attrs() + return ds + + def _get_grid_topology(self) -> xr.DataArray: + grid_da = None + for var_name in self._ds.variables: + if self._ds[var_name].attrs.get("cf_role") == "grid_topology": + grid_da = self._ds[var_name] + + if grid_da is None: + raise ValueError( + "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions https://sgrid.github.io/sgrid/" + ) + return grid_da + + def isel(self, indexers: Mapping[str, Any] | None = None, **indexers_kwargs): + """Index the dataset along SGRID spatial dimensions, keeping face and node dimensions consistent. + + For a provided index, this function derives the paired index from SGRID metadata, applies the indexes, and asserts the restulting dataset still complies with the SGRID metadata. + + Behaviour: + + - Only spatial (SGRID-registered) dimensions may be indexed. + - Simultaneously indexing along two dimensions that belong to the same axis is not allowed. + - For NONE/BOTH padding, only contiguous regions are supported (i.e., ``slice`` + indexers with ``step`` of ``None`` or ``1``). Indexing discontiguous regions are not well defined. + + + Parameters + ---------- + indexers : Mapping[str, Any], optional + A mapping of dimension name to indexer, e.g. ``{"node_dimension1": slice(0, 10)}``. + Mutually exclusive with ``**indexers_kwargs``. + **indexers_kwargs + Dimension-to-indexer pairs as keyword arguments. Mutually exclusive with + ``indexers``. + + Returns + ------- + xr.Dataset + A new dataset indexed along the requested dimensions with all paired face/node + dimensions adjusted accordingly. + """ + if indexers_kwargs != {}: + if indexers is not None: + raise ValueError("Cannot provide both positional and keyword argument to .isel .") + indexers = indexers_kwargs + + if indexers is None: + raise ValueError("Must provide indexers either as a positional argument or as keyword arguments.") + + metadata = self.metadata + + _assert_not_indexing_along_same_axis(indexers, metadata) + _assert_all_isel_along_axis(list(indexers.keys()), metadata) + + indexers = _complete_isel_indexing(self._ds, indexers, metadata) + + ds = self._ds.isel(indexers=indexers) + assert_metadata_ds_consistency(ds, metadata) + return ds + + +def assert_metadata_ds_consistency(ds: xr.Dataset, metadata: SGrid2DMetadata): + vertical_dimensions: tuple[FaceNodePadding, ...] = metadata.vertical_dimensions or tuple() + + for obj in itertools.chain(metadata.face_dimensions, vertical_dimensions): + face, node, padding = obj.face, obj.node, obj.padding + + try: + n_nodes = ds.dims[node] + except KeyError: # node dimension is not in this dataset + continue + try: + n_faces = ds.dims[face] + except KeyError: # face dimension is not in this dataset + continue + + expected_n_faces = get_n_faces(n_nodes, padding) + if expected_n_faces != n_faces: + raise SGridDatasetInconsistency( + f"Node dimension {node!r} has size {n_nodes}, and face dimension {face!r} has size of {n_faces}. " + f"Due to dataset padding of {padding!r}, expected face dimension {face} to actually be size {expected_n_faces}." + ) + + # TODO: Also check on coordinates + + +class SGridDatasetInconsistency(Exception): + """Attached metadata is not compatible with Xarray dataset""" + + pass + + +def _get_dim_to_axis_mapping(grid: SGrid2DMetadata) -> dict[str, Literal["X", "Y", "Z"]]: + fnp_x = grid.face_dimensions[0] + fnp_y = grid.face_dimensions[1] + fnp_z = grid.vertical_dimensions[0] if grid.vertical_dimensions is not None else None + + d = { + fnp_x.node: "X", + fnp_x.face: "X", + fnp_y.node: "Y", + fnp_y.face: "Y", + } + if fnp_z is not None: + d.update({fnp_z.node: "Z", fnp_z.face: "Z"}) + return cast(dict[str, Literal["X", "Y", "Z"]], d) + + +def _get_axis_info(grid: SGrid2DMetadata) -> dict[str, tuple[FaceNodePadding, bool]]: + """For each spatial dim, return (FaceNodePadding it belongs to, True if node dim).""" + result: dict[str, tuple[FaceNodePadding, bool]] = {} + all_fnps = list(grid.face_dimensions) + list(grid.vertical_dimensions or []) + for fnp in all_fnps: + result[fnp.node] = (fnp, True) + result[fnp.face] = (fnp, False) + return result + + +def _derive_paired_indexer( + indexer: Any, + indexer_is_node: bool, + padding: Padding, + dim_size: int | None = None, +) -> tuple[Any, Any]: + """Given a user's indexer for one side of a face-node pair, return ``(normalized_user_indexer, paired_indexer)``. + + For HIGH/LOW padding, face and node dims are the same size so both the normalised user + indexer and the paired indexer are identical to the original ``user_indexer``. + For NONE/BOTH padding, the slice is first normalised to non-negative absolute indices (via + ``slice.indices``) and then the stop of the paired indexer is adjusted by ±1. + + ``n_user_dim`` is required for NONE/BOTH padding so that negative starts and ``None`` stops + can be resolved to unambiguous absolute positions. + + Scalar (integer) and list indexers raise for NONE/BOTH because there is no unambiguous + paired position. + + Returns + ------- + tuple[Any, Any] + ``(normalized_user_indexer, paired_indexer)`` — the first element is the user's indexer + after normalisation (unchanged for HIGH/LOW), the second is the derived indexer for the + other side of the face-node pair. + """ + if padding in (Padding.HIGH, Padding.LOW): + return indexer, indexer + + # NONE and BOTH: only slices with step in {None, 1} are supported + if not isinstance(indexer, slice): + raise ValueError( + f"Scalar and list indexers are not supported for NONE/BOTH padding. " + f"Got indexer {indexer!r}. Use a slice instead." + ) + if indexer.step not in (None, 1): + raise ValueError(f"Slices with step != 1 are not supported for NONE/BOTH padding. Got step={indexer.step!r}.") + if dim_size is None: + raise ValueError("dim_size must be provided for NONE/BOTH padding to correctly handle slices.") + + # Normalise to non-negative absolute indices so the arithmetic below is unambiguous. + abs_start, abs_stop, _ = indexer.indices(dim_size) + normalized_user_indexer = slice(abs_start, abs_stop) + + start, stop = abs_start, abs_stop + + # Adjust stop: positive stops reference from the start of the array, so ±1 is needed. + if stop is not None and stop > 0: + stop = get_n_faces(stop, padding=padding) if indexer_is_node else get_n_nodes(stop, padding=padding) + + return normalized_user_indexer, slice(start, stop) + + +def _assert_not_indexing_along_same_axis(indexers: Mapping[Any, Any], metadata: SGrid2DMetadata) -> None: + dim_to_axis = _get_dim_to_axis_mapping(metadata) + indexer_dim_to_axis = {dim: dim_to_axis.get(dim) for dim in indexers} + + indexer_axis_to_dim = invert_non_unique_mapping(indexer_dim_to_axis) + for axis, dims in indexer_axis_to_dim.items(): + if axis is None: + continue + + if len(dims) > 1: + msg = f"Dims {dims} are on the same axis {axis!r} according to SGRID metadata - cannot simultaneously index along multiple dimensions in the same axis." + raise ValueError(msg) + + +def _assert_all_isel_along_axis(index_dims: Sequence[str], metadata: SGrid2DMetadata): + dim_to_axis = _get_dim_to_axis_mapping(metadata) + for dim in index_dims: + try: + dim_to_axis[dim] + except KeyError as e: + raise ValueError( + f"Cannot use SGRID accessor to .isel non-spatial (/SGRID related) dimension {dim!r}." + ) from e + + +def _complete_isel_indexing( + ds: xr.Dataset, + indexers: Mapping[Any, Any], + grid: SGrid2DMetadata, +) -> Mapping[Any, Any]: + """For each user-supplied (dim, indexer), expand to both the face and node dim on that axis, + deriving the paired indexer according to the padding type. + """ + axis_info = _get_axis_info(grid) + ret: dict[Any, Any] = {} + + for user_dim, user_indexer in indexers.items(): + fnp, user_is_node = axis_info[user_dim] + + n_user_dim = ds.sizes.get(user_dim) + normalized_user, paired_indexer = _derive_paired_indexer( + user_indexer, user_is_node, fnp.padding, dim_size=n_user_dim + ) + + node_indexer = normalized_user if user_is_node else paired_indexer + face_indexer = paired_indexer if user_is_node else normalized_user + + if fnp.node in ds.dims: + ret[fnp.node] = node_indexer + if fnp.face in ds.dims: + ret[fnp.face] = face_indexer + + return ret diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_sgrid/core.py similarity index 89% rename from src/parcels/_core/utils/sgrid.py rename to src/parcels/_sgrid/core.py index ed8edce60..05bec18ad 100644 --- a/src/parcels/_core/utils/sgrid.py +++ b/src/parcels/_sgrid/core.py @@ -22,7 +22,7 @@ from parcels._python import repr_from_dunder_dict -RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" +_RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" Dim = str @@ -47,38 +47,36 @@ class Padding(enum.Enum): } -class AttrsSerializable(Protocol): - def to_attrs(self) -> dict[str, str | int]: ... +def get_n_faces(n_nodes: int, padding: Padding) -> int: + """Get number of faces along a dimension""" + if padding in [Padding.LOW, Padding.HIGH]: + return n_nodes + if padding == Padding.NONE: + return n_nodes - 1 + if padding == Padding.BOTH: + return n_nodes + 1 + raise ValueError(f"Invalid {padding=!r}") - @classmethod - def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... +def get_n_nodes(n_faces: int, padding: Padding) -> int: + """Get number of nodes along a dimension""" + if padding in [Padding.LOW, Padding.HIGH]: + return n_faces + if padding == Padding.NONE: + return n_faces + 1 + if padding == Padding.BOTH: + return n_faces - 1 + raise ValueError(f"Invalid {padding=!r}") -# Note that - for some optional attributes in the SGRID spec - these IDs are not available -# hence this isn't full coverage -_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[SGrid2DMetadata], Dim | Padding]] = { - "node_dimension1": lambda meta: meta.node_dimensions[0], - "node_dimension2": lambda meta: meta.node_dimensions[1], - "face_dimension1": lambda meta: meta.face_dimensions[0].face, - "face_dimension2": lambda meta: meta.face_dimensions[1].face, - "type1": lambda meta: meta.face_dimensions[0].padding, - "type2": lambda meta: meta.face_dimensions[1].padding, -} -_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[SGrid3DMetadata], Dim | Padding]] = { - "node_dimension1": lambda meta: meta.node_dimensions[0], - "node_dimension2": lambda meta: meta.node_dimensions[1], - "node_dimension3": lambda meta: meta.node_dimensions[2], - "face_dimension1": lambda meta: meta.volume_dimensions[0].face, - "face_dimension2": lambda meta: meta.volume_dimensions[1].face, - "face_dimension3": lambda meta: meta.volume_dimensions[2].face, - "type1": lambda meta: meta.volume_dimensions[0].padding, - "type2": lambda meta: meta.volume_dimensions[1].padding, - "type3": lambda meta: meta.volume_dimensions[2].padding, -} +class _AttrsSerializable(Protocol): + def to_attrs(self) -> dict[str, str | int]: ... + @classmethod + def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... -class SGrid2DMetadata(AttrsSerializable): + +class SGrid2DMetadata(_AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -157,26 +155,26 @@ def __eq__(self, other: Any) -> bool: return self.to_attrs() == other.to_attrs() @classmethod - def from_attrs(cls, attrs): # type: ignore[override] + def from_attrs(cls, attrs): try: return cls( cf_role=attrs["cf_role"], topology_dimension=attrs["topology_dimension"], node_dimensions=cast(tuple[Dim, Dim], load_mappings(attrs["node_dimensions"])), face_dimensions=cast(tuple[FaceNodePadding, FaceNodePadding], load_mappings(attrs["face_dimensions"])), - node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), - vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")), + node_coordinates=_maybe_load_mappings(attrs.get("node_coordinates")), + vertical_dimensions=_maybe_load_mappings(attrs.get("vertical_dimensions")), ) except Exception as e: raise SGridParsingException(f"Failed to parse Grid2DMetadata from {attrs=!r}") from e def to_attrs(self) -> dict[str, str | int]: - d = dict( - cf_role=self.cf_role, - topology_dimension=self.topology_dimension, - node_dimensions=dump_mappings(self.node_dimensions), - face_dimensions=dump_mappings(self.face_dimensions), - ) + d: dict[str, str | int] = { + "cf_role": self.cf_role, + "topology_dimension": self.topology_dimension, + "node_dimensions": dump_mappings(self.node_dimensions), + "face_dimensions": dump_mappings(self.face_dimensions), + } if self.node_coordinates is not None: d["node_coordinates"] = dump_mappings(self.node_coordinates) if self.vertical_dimensions is not None: @@ -200,7 +198,7 @@ def get_value_by_id(self, id: str) -> Dim | Padding: return _ID_FETCHERS_GRID2DMETADATA[id](self) -class SGrid3DMetadata(AttrsSerializable): +class SGrid3DMetadata(_AttrsSerializable): def __init__( self, cf_role: Literal["grid_topology"], @@ -273,7 +271,7 @@ def __eq__(self, other: Any) -> bool: return self.to_attrs() == other.to_attrs() @classmethod - def from_attrs(cls, attrs): # type: ignore[override] + def from_attrs(cls, attrs): try: return cls( cf_role=attrs["cf_role"], @@ -282,18 +280,18 @@ def from_attrs(cls, attrs): # type: ignore[override] volume_dimensions=cast( tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding], load_mappings(attrs["volume_dimensions"]) ), - node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), + node_coordinates=_maybe_load_mappings(attrs.get("node_coordinates")), ) except Exception as e: raise SGridParsingException(f"Failed to parse Grid3DMetadata from {attrs=!r}") from e def to_attrs(self) -> dict[str, str | int]: - d = dict( - cf_role=self.cf_role, - topology_dimension=self.topology_dimension, - node_dimensions=dump_mappings(self.node_dimensions), - volume_dimensions=dump_mappings(self.volume_dimensions), - ) + d: dict[str, str | int] = { + "cf_role": self.cf_role, + "topology_dimension": self.topology_dimension, + "node_dimensions": dump_mappings(self.node_dimensions), + "volume_dimensions": dump_mappings(self.volume_dimensions), + } if self.node_coordinates is not None: d["node_coordinates"] = dump_mappings(self.node_coordinates) return d @@ -315,6 +313,30 @@ def get_value_by_id(self, id: str) -> Dim | Padding: return _ID_FETCHERS_GRID3DMETADATA[id](self) +# Note that - for some optional attributes in the SGRID spec - these IDs are not available +# hence this isn't full coverage +_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[SGrid2DMetadata], Dim | Padding]] = { + "node_dimension1": lambda meta: meta.node_dimensions[0], + "node_dimension2": lambda meta: meta.node_dimensions[1], + "face_dimension1": lambda meta: meta.face_dimensions[0].face, + "face_dimension2": lambda meta: meta.face_dimensions[1].face, + "type1": lambda meta: meta.face_dimensions[0].padding, + "type2": lambda meta: meta.face_dimensions[1].padding, +} + +_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[SGrid3DMetadata], Dim | Padding]] = { + "node_dimension1": lambda meta: meta.node_dimensions[0], + "node_dimension2": lambda meta: meta.node_dimensions[1], + "node_dimension3": lambda meta: meta.node_dimensions[2], + "face_dimension1": lambda meta: meta.volume_dimensions[0].face, + "face_dimension2": lambda meta: meta.volume_dimensions[1].face, + "face_dimension3": lambda meta: meta.volume_dimensions[2].face, + "type1": lambda meta: meta.volume_dimensions[0].padding, + "type2": lambda meta: meta.volume_dimensions[1].padding, + "type3": lambda meta: meta.volume_dimensions[2].padding, +} + + @dataclass class FaceNodePadding: """A data class representing a face-node-padding triplet for SGrid metadata. @@ -336,7 +358,7 @@ def __str__(self) -> str: @classmethod def load(cls, s: str) -> Self: - match = re.match(RE_FACE_NODE_PADDING, s) + match = re.match(_RE_FACE_NODE_PADDING, s) if not match: raise ValueError(f"String {s!r} does not match expected format for FaceNodePadding") face = match.group(1) @@ -359,12 +381,12 @@ def dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: @overload -def maybe_dump_mappings(parts: None) -> None: ... +def _maybe_dump_mappings(parts: None) -> None: ... @overload -def maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ... +def _maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ... -def maybe_dump_mappings(parts): +def _maybe_dump_mappings(parts): if parts is None: return None return dump_mappings(parts) @@ -383,15 +405,15 @@ def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: ret = [] while s: # find next part - match = re.match(RE_FACE_NODE_PADDING, s) + match = re.match(_RE_FACE_NODE_PADDING, s) if match and match.start() == 0: # match found at start, take that as next part part = match.group(0) s_new = s[match.end() :].lstrip() else: # no FaceNodePadding match at start, assume just a Dim until next space - part, *s_new = s.split(" ", 1) - s_new = "".join(s_new) + part, *rest = s.split(" ", 1) + s_new = "".join(rest) assert s != s_new, f"SGrid parsing did not advance, stuck at {s!r}" @@ -414,12 +436,12 @@ def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: @overload -def maybe_load_mappings(s: None) -> None: ... +def _maybe_load_mappings(s: None) -> None: ... @overload -def maybe_load_mappings(s: Hashable) -> tuple[FaceNodePadding | Dim, ...]: ... +def _maybe_load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: ... -def maybe_load_mappings(s): +def _maybe_load_mappings(s): if s is None: return None return load_mappings(s) @@ -445,24 +467,10 @@ def parse_grid_attrs(attrs: dict[str, Hashable]) -> SGrid2DMetadata | SGrid3DMet return grid -def get_grid_topology(ds: xr.Dataset) -> xr.DataArray | None: - """Extracts grid topology DataArray from an xarray Dataset.""" - for var_name in ds.variables: - if ds[var_name].attrs.get("cf_role") == "grid_topology": - return ds[var_name] - return None - - -def parse_sgrid(ds: xr.Dataset): +def xgcm_parse_sgrid(ds: xr.Dataset): # Function similar to that provided in `xgcm.metadata_parsers. # Might at some point be upstreamed to xgcm directly - try: - grid_topology = get_grid_topology(ds) - assert grid_topology is not None, "No grid_topology variable found in dataset" - grid = parse_grid_attrs(grid_topology.attrs) - - except Exception as e: - raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e + grid = ds.sgrid.metadata if isinstance(grid, SGrid2DMetadata): dimensions = grid.face_dimensions + (grid.vertical_dimensions or ()) @@ -484,22 +492,7 @@ def parse_sgrid(ds: xr.Dataset): return (ds, {"coords": xgcm_coords}) -def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset: - grid_da = get_grid_topology(ds) - if grid_da is None: - raise ValueError( - "No variable found in dataset with 'cf_role' attribute set to 'grid_topology'. This doesn't look to be an SGrid dataset - please make your dataset conforms to SGrid conventions." - ) - - ds = ds.rename(name_dict) - - # Update the metadata - grid = parse_grid_attrs(grid_da.attrs) - ds[grid_da.name].attrs = grid.rename(name_dict).to_attrs() - return ds - - -def get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]: +def _get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]: dims = set() dims.update(set(grid.node_dimensions)) @@ -724,7 +717,7 @@ def _metadata_rename(grid, names_dict): names_dict = names_dict.copy() assert len(names_dict) == len(set(names_dict.values())), "names_dict contains duplicate target dimension names" - existing_names = get_unique_names(grid) + existing_names = _get_unique_names(grid) for name in names_dict.keys(): if name not in existing_names: raise ValueError(f"Name {name!r} not found in names defined in SGrid metadata {existing_names!r}") diff --git a/src/parcels/_strategies/__init__.py b/src/parcels/_strategies/__init__.py index 6cc66ee5a..15536e269 100644 --- a/src/parcels/_strategies/__init__.py +++ b/src/parcels/_strategies/__init__.py @@ -4,7 +4,7 @@ import hypothesis # noqa: F401 except ImportError as err: err.add_note( - "To use strategies you must have hypothesis installed. Install it from PyPI, Conda, or using your preffered package manager." + "To use strategies you must have hypothesis installed. Install it from PyPI, Conda, or using your preferred package manager." ) raise err diff --git a/src/parcels/_strategies/sgrid.py b/src/parcels/_strategies/sgrid.py index f413ab18b..010b23ed8 100644 --- a/src/parcels/_strategies/sgrid.py +++ b/src/parcels/_strategies/sgrid.py @@ -11,7 +11,7 @@ import xarray.testing.strategies as xr_st from hypothesis import strategies as st -from parcels._core.utils import sgrid +import parcels._sgrid as sgrid padding = st.sampled_from(sgrid.Padding) dimension_name = xr_st.names().filter( diff --git a/src/parcels/convert.py b/src/parcels/convert.py index 600262bd6..9f044f049 100644 --- a/src/parcels/convert.py +++ b/src/parcels/convert.py @@ -20,7 +20,7 @@ import numpy as np import xarray as xr -from parcels._core.utils import sgrid +import parcels._sgrid as sgrid from parcels._logger import logger if typing.TYPE_CHECKING: @@ -376,7 +376,7 @@ def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Da ds["gphif"].attrs["units"] = "degrees" # Update to use lon and lat for internal naming - ds = sgrid.rename(ds, {"gphif": "lat", "glamf": "lon"}) # TODO: Logging message about rename + ds = ds.sgrid.rename({"gphif": "lat", "glamf": "lon"}) # TODO: Logging message about rename return ds diff --git a/tests/datasets/test_strategies.py b/tests/datasets/test_strategies.py index 38083cb54..32faa1065 100644 --- a/tests/datasets/test_strategies.py +++ b/tests/datasets/test_strategies.py @@ -6,22 +6,8 @@ from hypothesis import given, settings from hypothesis.errors import NonInteractiveExampleWarning -from parcels._core.utils import sgrid -from parcels._core.utils.sgrid import get_grid_topology, parse_grid_attrs -from parcels._datasets.structured.strategies import _face_size, sgrid_dataset - - -@pytest.mark.parametrize( - "n_nodes, padding, n_edges", - [ - (10, sgrid.Padding.NONE, 9), - (10, sgrid.Padding.LOW, 10), - (10, sgrid.Padding.HIGH, 10), - (10, sgrid.Padding.BOTH, 11), - ], -) -def test_face_size(n_nodes, padding, n_edges): - assert _face_size(n_nodes, padding) == n_edges +import parcels._sgrid as sgrid +from parcels._datasets.structured.strategies import sgrid_dataset def test_sgrid_dataset_raises_when_no_node_coordinates(): @@ -50,13 +36,13 @@ def test_sgrid_dataset_returns_dataset(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_has_grid_topology(ds): - assert get_grid_topology(ds) is not None + ds.sgrid._get_grid_topology() # shouldn't error @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_node_coordinates_present(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata assert meta.node_coordinates is not None for coord_name in meta.node_coordinates: assert coord_name in ds.coords @@ -65,7 +51,7 @@ def test_sgrid_dataset_node_coordinates_present(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_coordinate_shapes(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata coord_name1, coord_name2 = meta.node_coordinates node_dim1, node_dim2 = meta.node_dimensions coord1 = ds.coords[coord_name1] @@ -86,7 +72,7 @@ def test_sgrid_dataset_has_at_least_one_field(ds): @given(sgrid_dataset()) @settings(max_examples=20) def test_sgrid_dataset_field_dims_are_valid(ds): - meta = parse_grid_attrs(get_grid_topology(ds).attrs) + meta = ds.sgrid.metadata valid_dims = set(meta.node_dimensions) valid_dims.add(meta.face_dimensions[0].face) valid_dims.add(meta.face_dimensions[1].face) diff --git a/tests/sgrid/test_accessor.py b/tests/sgrid/test_accessor.py new file mode 100644 index 000000000..8a9f9b0ce --- /dev/null +++ b/tests/sgrid/test_accessor.py @@ -0,0 +1,281 @@ +import hypothesis.strategies as st +import numpy as np +import pytest +import xarray as xr +from hypothesis import assume, given + +import parcels._sgrid as sgrid +import parcels._strategies as pst +from parcels._datasets.structured.generic import datasets_sgrid +from parcels._datasets.structured.strategies import sgrid_dataset +from parcels._sgrid.accessor import SGridDatasetInconsistency, assert_metadata_ds_consistency + + +@st.composite +def grid_and_dataset(draw) -> tuple[sgrid.SGrid2DMetadata, xr.Dataset]: + # used only for test_metadata - for all other tests we can simply do `ds.sgrid.metadata` to get the metadata + metadata_2d = draw( + pst.sgrid.grid_metadata.filter( + # parcels can only generate 2D Sgrid datasets, that also have coordinates + lambda meta: isinstance(meta, sgrid.SGrid2DMetadata) and meta.node_coordinates is not None + ) + ) + ds = draw(sgrid_dataset(metadata_2d)) + return metadata_2d, ds + + +@given(grid_and_dataset()) +def test_metadata(metadata_ds): + metadata = metadata_ds[0] + ds = metadata_ds[1] + parsed_metadata = ds.sgrid.metadata + assert parsed_metadata == metadata + + +@pytest.mark.parametrize( + "ds", + [ + xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(10, 10, 10, 10)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)), + "grid": ( + [], + np.array(0), + sgrid.SGrid2DMetadata( + cf_role="grid_topology", + topology_dimension=2, + node_dimensions=("XG", "YG"), + face_dimensions=( + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), + ), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), + node_coordinates=("lon", "lat"), + ).to_attrs(), + ), + }, + coords={ + "lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)), + "lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)), + "depth": (["ZG"], np.arange(10)), + "time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}), + }, + ), + ], +) +def test_rename_dataset(ds): + # Check renaming works for coordinates + ds_new = ds.sgrid.rename({"lon": "lon_updated"}) + grid_new = ds_new.sgrid.metadata + assert "lon_updated" in ds_new.coords + assert "lon_updated" == grid_new.node_coordinates[0] + + # Check renaming works for dim + ds_new = ds.sgrid.rename({"XC": "XC_updated"}) + grid_new = ds_new.sgrid.metadata + assert "XC_updated" in ds_new.dims + assert "XC" not in ds_new.dims + assert "XC_updated" == grid_new.face_dimensions[0].face + + +@given(sgrid_dataset()) +def test_assert_metadata_ds_consistency(ds): + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata + assert_metadata_ds_consistency(ds, metadata) + + +@given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) +def test_assert_metadata_ds_consistency_dropped_dim(ds, dim): + # dropping one of the SGRID dimensions still results in a consistent dataset + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata + + if dim == "face_dimension1": + fnp = metadata.face_dimensions[0] + elif dim == "face_dimension2": + fnp = metadata.face_dimensions[1] + elif dim == "vertical_dimension": + assume(metadata.vertical_dimensions is not None) + assert metadata.vertical_dimensions is not None + fnp = metadata.vertical_dimensions[0] + else: + raise ValueError("Unexpected value for dim") + + assume(fnp.face in ds.dims) + + ds = ds.isel({fnp.face: 0}) + assert_metadata_ds_consistency(ds, metadata) + + +@given(ds=sgrid_dataset(), dim=st.sampled_from(["face_dimension1", "face_dimension2", "vertical_dimension"])) +def test_assert_metadata_ds_consistency_failures(ds, dim): + metadata: sgrid.SGrid2DMetadata = ds.sgrid.metadata + + if dim == "face_dimension1": + fnp = metadata.face_dimensions[0] + elif dim == "face_dimension2": + fnp = metadata.face_dimensions[1] + elif dim == "vertical_dimension": + assume(metadata.vertical_dimensions is not None) + assert metadata.vertical_dimensions is not None + fnp = metadata.vertical_dimensions[0] + else: + raise ValueError("Unexpected value for dim") + + assume(fnp.node in ds.dims) + assume(fnp.face in ds.dims) + + ds = ds.isel({fnp.face: slice(None, -1)}) + + with pytest.raises( + SGridDatasetInconsistency, + match="Node dimension .* has size .*, and face dimension .* has size of .* .* expected face dimension .* to actually be size .*", + ): + assert_metadata_ds_consistency(ds, metadata) + + +_SYMMETRIC_DATASETS = {k: v for k, v in datasets_sgrid.items() if k in ("ds_2d_padded_high", "ds_2d_padded_low")} + + +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in _SYMMETRIC_DATASETS.items()]) +@pytest.mark.parametrize("indexer", [slice(None, None, 3), [0]]) +@pytest.mark.parametrize( + "node_dim, face_dim", [("node_dimension1", "face_dimension1"), ("node_dimension2", "face_dimension2")] +) +def test_isel(ds, indexer, node_dim, face_dim): + # Covers HIGH/LOW-specific indexer types (lists, step slices) not exercised by property tests. + metadata = ds.sgrid.metadata + + ds_trimmed = ds.sgrid.isel({node_dim: indexer}) + + assert_metadata_ds_consistency(ds_trimmed, metadata) + + # Assert that other dims haven't been affected + for dim, size_before in ds.dims.items(): + if dim in (node_dim, face_dim): + continue + size_after = ds_trimmed.dims[dim] + assert size_before == size_after + + +@pytest.mark.parametrize("ds", [pytest.param(ds, id=id_) for id_, ds in _SYMMETRIC_DATASETS.items()]) +def test_isel_drop_dim(ds): + ds = ds.copy() + + ds_trimmed = ds.sgrid.isel({"node_dimension1": 0}) + + assert "node_dimension1" not in ds_trimmed.dims + assert "face_dimension1" not in ds_trimmed.dims + + for dim, size_after in ds_trimmed.dims.items(): + size_before = ds.dims[dim] + assert size_before == size_after + + +@pytest.mark.parametrize("ds", [datasets_sgrid["ds_2d_padded_high"]]) +def test_isel_invalid(ds): + with pytest.raises( + ValueError, match=r"Cannot use SGRID accessor to .isel non-spatial \(/SGRID related\) dimension.*" + ): + ds.sgrid.isel(time=slice(None)) + + with pytest.raises(ValueError, match="Dims .* are on the same axis .* according to SGRID metadata.*"): + ds.sgrid.isel(node_dimension1=slice(None), face_dimension1=slice(None)) + + +@st.composite +def valid_isel_slice(draw) -> slice: + """Slice with no step; covers None, positive, and negative start/stop.""" + start = draw(st.one_of(st.none(), st.integers(-25, 25))) + stop = draw(st.one_of(st.none(), st.integers(-25, 25))) + return slice(start, stop) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p1_consistency_invariant(ds, s): + """After any valid isel, the SGRID face-node relationship is preserved.""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim = fnp.node + assume(node_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) # exclude empty selections + + result = ds.sgrid.isel({node_dim: s}) + + assert_metadata_ds_consistency(result, metadata) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p2_data_correctness(ds, s): + """Values in the result match those obtained by applying the indexers directly via xarray.""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim, face_dim = fnp.node, fnp.face + assume(node_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) + + result = ds.sgrid.isel({node_dim: s}) + + from parcels._sgrid.accessor import _derive_paired_indexer + + _, face_indexer = _derive_paired_indexer(s, indexer_is_node=True, padding=fnp.padding, dim_size=n_nodes) + + for var in ds.data_vars: + if var == "grid": + continue + var_dims = ds[var].dims + if node_dim in var_dims and face_dim not in var_dims: + xr.testing.assert_equal(result[var], ds[var].isel({node_dim: s})) + elif face_dim in var_dims and node_dim not in var_dims: + if face_dim in ds.dims: + xr.testing.assert_equal(result[var], ds[var].isel({face_dim: face_indexer})) + + +@given( + ds=sgrid_dataset(), + s=valid_isel_slice(), +) +def test_isel_p3_specification_symmetry(ds, s): + """isel(node_dim=s) produces the same dataset as isel(face_dim=derived_face_slice).""" + metadata = ds.sgrid.metadata + fnp = metadata.face_dimensions[0] + node_dim, face_dim = fnp.node, fnp.face + assume(node_dim in ds.dims and face_dim in ds.dims) + n_nodes = ds.dims[node_dim] + assume(len(range(*s.indices(n_nodes))) > 0) + + from parcels._sgrid.accessor import _derive_paired_indexer + + _, face_indexer = _derive_paired_indexer(s, indexer_is_node=True, padding=fnp.padding, dim_size=n_nodes) + + n_faces = ds.dims.get(face_dim) + if n_faces is not None: + assume(len(range(*face_indexer.indices(n_faces))) > 0) + + result_via_node = ds.sgrid.isel({node_dim: s}) + result_via_face = ds.sgrid.isel({face_dim: face_indexer}) + + xr.testing.assert_identical(result_via_node, result_via_face) + + +@pytest.mark.parametrize("ds_name", ["ds_2d_padded_none", "ds_2d_padded_both"]) +@pytest.mark.parametrize( + "indexer, match", + [ + (5, "Scalar and list indexers are not supported for NONE/BOTH padding"), + ([0, 1, 2], "Scalar and list indexers are not supported for NONE/BOTH padding"), + (slice(0, 8, 2), "Slices with step != 1 are not supported for NONE/BOTH padding"), + ], +) +def test_isel_asymmetric_padding_invalid_indexers(ds_name, indexer, match): + ds = datasets_sgrid[ds_name] + with pytest.raises(ValueError, match=match): + ds.sgrid.isel({"node_dimension1": indexer}) diff --git a/tests/utils/test_sgrid.py b/tests/sgrid/test_sgrid.py similarity index 87% rename from tests/utils/test_sgrid.py rename to tests/sgrid/test_sgrid.py index 3b8292929..5582821d0 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/sgrid/test_sgrid.py @@ -1,13 +1,15 @@ import itertools +import hypothesis.strategies as st import numpy as np import pytest import xarray as xr import xgcm from hypothesis import assume, example, given +import parcels._sgrid as sgrid import parcels._strategies as pst -from parcels._core.utils import sgrid +from parcels._sgrid.core import SGRID_PADDING_TO_XGCM_POSITION, _get_unique_names, parse_grid_attrs def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool): @@ -89,7 +91,7 @@ def dummy_sgrid_2d_ds(grid: sgrid.SGrid2DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid.get_unique_names(grid) & set(ds.dims) == set()) + assume(_get_unique_names(grid) & set(ds.dims) == set()) renamings = {} if grid.vertical_dimensions is None: @@ -114,7 +116,7 @@ def dummy_sgrid_3d_ds(grid: sgrid.SGrid3DMetadata) -> xr.Dataset: ds = dummy_comodo_3d_ds() # Can't rename dimensions that already exist in the dataset - assume(sgrid.get_unique_names(grid) & set(ds.dims) == set()) + assume(_get_unique_names(grid) & set(ds.dims) == set()) renamings = {} for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True): @@ -242,9 +244,9 @@ def test_grid3Dmetadata_standard_names(grid: sgrid.SGrid3DMetadata): @given(pst.sgrid.grid_metadata) -def test_parse_grid_attrs(grid: sgrid.AttrsSerializable): +def test_parse_grid_attrs(grid): attrs = grid.to_attrs() - parsed = sgrid.parse_grid_attrs(attrs) + parsed = parse_grid_attrs(attrs) assert parsed == grid @@ -254,13 +256,13 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_2d_ds(grid_metadata) - _, xgcm_kwargs = sgrid.parse_sgrid(ds) + _, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) for obj, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): coords = grid.axes[axis].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node if grid_metadata.vertical_dimensions is None: assert "Z" not in grid.axes @@ -268,21 +270,24 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.SGrid2DMetadata): obj = grid_metadata.vertical_dimensions[0] coords = grid.axes["Z"].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @given(pst.sgrid.grid3Dmetadata()) +@pytest.mark.xfail( + reason="Parcels doesn't have native support for SGRID 3D grids. This metadata checking is superfluous until we have such support." +) def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): """Test the ingestion of datasets in XGCM to ensure that it matches the SGRID metadata provided""" ds = dummy_sgrid_3d_ds(grid_metadata) - ds, xgcm_kwargs = sgrid.parse_sgrid(ds) + ds, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) for obj, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): coords = grid.axes[axis].coords assert coords["center"] == obj.face - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node + assert coords[SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @pytest.mark.parametrize( @@ -294,12 +299,12 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.SGrid3DMetadata): + [create_example_grid3dmetadata(with_node_coordinates=i) for i in [False, True]], ) def test_rename(grid): - dims = sgrid.get_unique_names(grid) + dims = _get_unique_names(grid) dims_dict = {dim: f"new_{dim}" for dim in dims} dims_dict_inv = {v: k for k, v in dims_dict.items()} grid_new = grid.rename(dims_dict) - assert dims & set(sgrid.get_unique_names(grid_new)) == set() + assert dims & set(_get_unique_names(grid_new)) == set() assert grid == grid_new.rename(dims_dict_inv) @@ -322,53 +327,6 @@ def test_rename_errors(): grid.rename(names_dict) -@pytest.mark.parametrize( - "ds", - [ - xr.Dataset( - { - "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(10, 10, 10, 10)), - "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(10, 10, 10, 10)), - "grid": ( - [], - np.array(0), - sgrid.SGrid2DMetadata( - cf_role="grid_topology", - topology_dimension=2, - node_dimensions=("XG", "YG"), - face_dimensions=( - sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), - sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), - ), - vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), - node_coordinates=("lon", "lat"), - ).to_attrs(), - ), - }, - coords={ - "lon": (["XG"], 2 * np.pi / 10 * np.arange(0, 10)), - "lat": (["YG"], 2 * np.pi / (10) * np.arange(0, 10)), - "depth": (["ZG"], np.arange(10)), - "time": (["time"], xr.date_range("2000", "2001", 10), {"axis": "T"}), - }, - ), - ], -) -def test_rename_dataset(ds): - # Check renaming works for coordinates - ds_new = sgrid.rename(ds, {"lon": "lon_updated"}) - grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs) - assert "lon_updated" in ds_new.coords - assert "lon_updated" == grid_new.node_coordinates[0] - - # Check renaming works for dim - ds_new = sgrid.rename(ds, {"XC": "XC_updated"}) - grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs) - assert "XC_updated" in ds_new.dims - assert "XC" not in ds_new.dims - assert "XC_updated" == grid_new.face_dimensions[0].face - - @pytest.mark.parametrize( ("metadata, expected"), [ @@ -580,3 +538,10 @@ def test_face_node_padding_to_diagram(face_node_padding: sgrid.FaceNodePadding, actual = face_node_padding.to_diagram() lines = actual.split("\n") assert lines == expected_lines + + +@given(n=st.integers(min_value=1), padding=pst.sgrid.padding) +def test_dim_sizes_roundtrip(n, padding): + n_faces = sgrid.get_n_nodes(n, padding) + n_nodes = sgrid.get_n_faces(n_faces, padding) + assert n_nodes == n diff --git a/tests/test_convert.py b/tests/test_convert.py index 88bac2c24..6a05bc960 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -6,7 +6,6 @@ import parcels.convert as convert import parcels.tutorial from parcels import FieldSet -from parcels._core.utils import sgrid from parcels._datasets.remote import open_remote_dataset from parcels._datasets.structured.circulation_models import datasets as datasets_circulation_models from parcels.interpolators._xinterpolators import _get_offsets_dictionary @@ -35,7 +34,7 @@ def test_nemo_to_sgrid_2d(U, V, coords): # noqa: N803 "vertical_dimensions": "depth_center:depth (padding:high)", } - meta = sgrid.parse_grid_attrs(ds["grid"].attrs) + meta = ds.sgrid.metadata # Assuming that node_dimension1 and node_dimension2 correspond to X and Y respectively # check that U and V are properly defined on the staggered grid @@ -75,7 +74,7 @@ def test_nemo_to_sgrid_with_depth(U, V, depth, coords): # noqa: N803 "vertical_dimensions": "depth_center:depth (padding:high)", } - meta = sgrid.parse_grid_attrs(ds["grid"].attrs) + meta = ds.sgrid.metadata # Assuming that node_dimension1 and node_dimension2 correspond to X and Y respectively # check that U and V are properly defined on the staggered grid