diff --git a/src/parcels/_core/utils/sgrid.py b/src/parcels/_core/utils/sgrid.py index 2db7149642..9918f986a1 100644 --- a/src/parcels/_core/utils/sgrid.py +++ b/src/parcels/_core/utils/sgrid.py @@ -13,7 +13,7 @@ import enum import re -from collections.abc import Hashable, Iterable +from collections.abc import Callable, Hashable, Iterable from dataclasses import dataclass from textwrap import indent from typing import Any, Literal, Protocol, Self, cast, overload @@ -22,7 +22,7 @@ from parcels._python import repr_from_dunder_dict -RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" +RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)" Dim = str @@ -56,22 +56,22 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ... # Note that - for some optional attributes in the SGRID spec - these IDs are not available # hence this isn't full coverage -_ID_FETCHERS_GRID2DMETADATA = { +_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[Grid2DMetadata], Dim | Padding]] = { "node_dimension1": lambda meta: meta.node_dimensions[0], "node_dimension2": lambda meta: meta.node_dimensions[1], - "face_dimension1": lambda meta: meta.face_dimensions[0].dim1, - "face_dimension2": lambda meta: meta.face_dimensions[1].dim1, + "face_dimension1": lambda meta: meta.face_dimensions[0].face, + "face_dimension2": lambda meta: meta.face_dimensions[1].face, "type1": lambda meta: meta.face_dimensions[0].padding, "type2": lambda meta: meta.face_dimensions[1].padding, } -_ID_FETCHERS_GRID3DMETADATA = { +_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[Grid3DMetadata], Dim | Padding]] = { "node_dimension1": lambda meta: meta.node_dimensions[0], "node_dimension2": lambda meta: meta.node_dimensions[1], "node_dimension3": lambda meta: meta.node_dimensions[2], - "face_dimension1": lambda meta: meta.volume_dimensions[0].dim1, - "face_dimension2": lambda meta: meta.volume_dimensions[1].dim1, - "face_dimension3": lambda meta: meta.volume_dimensions[2].dim1, + "face_dimension1": lambda meta: meta.volume_dimensions[0].face, + "face_dimension2": lambda meta: meta.volume_dimensions[1].face, + "face_dimension3": lambda meta: meta.volume_dimensions[2].face, "type1": lambda meta: meta.volume_dimensions[0].padding, "type2": lambda meta: meta.volume_dimensions[1].padding, "type3": lambda meta: meta.volume_dimensions[2].padding, @@ -84,9 +84,9 @@ def __init__( cf_role: Literal["grid_topology"], topology_dimension: Literal[2], node_dimensions: tuple[Dim, Dim], - face_dimensions: tuple[DimDimPadding, DimDimPadding], + face_dimensions: tuple[FaceNodePadding, FaceNodePadding], node_coordinates: None | tuple[Dim, Dim] = None, - vertical_dimensions: None | tuple[DimDimPadding] = None, + vertical_dimensions: None | tuple[FaceNodePadding] = None, ): if cf_role != "grid_topology": raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}") @@ -104,9 +104,9 @@ def __init__( if not ( isinstance(face_dimensions, tuple) and len(face_dimensions) == 2 - and all(isinstance(fd, DimDimPadding) for fd in face_dimensions) + and all(isinstance(fd, FaceNodePadding) for fd in face_dimensions) ): - raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + raise ValueError("face_dimensions must be a tuple of 2 FaceNodePadding for a 2D grid") if node_coordinates is not None: if not ( @@ -120,9 +120,9 @@ def __init__( if not ( isinstance(vertical_dimensions, tuple) and len(vertical_dimensions) == 1 - and isinstance(vertical_dimensions[0], DimDimPadding) + and isinstance(vertical_dimensions[0], FaceNodePadding) ): - raise ValueError("vertical_dimensions must be a tuple of 1 DimDimPadding for a 2D grid") + raise ValueError("vertical_dimensions must be a tuple of 1 FaceNodePadding for a 2D grid") # Required attributes self.cf_role = cf_role @@ -137,8 +137,8 @@ def __init__( #! Some optional attributes aren't really important to Parcels, can be added later if needed # Optional attributes # # With defaults (set in init) - # edge1_dimensions: tuple[Dim, DimDimPadding] - # edge2_dimensions: tuple[DimDimPadding, Dim] + # edge1_dimensions: tuple[Dim, FaceNodePadding] + # edge2_dimensions: tuple[FaceNodePadding, Dim] # # Without defaults # edge1_coordinates: None | Any = None @@ -163,7 +163,7 @@ def from_attrs(cls, attrs): # type: ignore[override] cf_role=attrs["cf_role"], topology_dimension=attrs["topology_dimension"], node_dimensions=cast(tuple[Dim, Dim], load_mappings(attrs["node_dimensions"])), - face_dimensions=cast(tuple[DimDimPadding, DimDimPadding], load_mappings(attrs["face_dimensions"])), + face_dimensions=cast(tuple[FaceNodePadding, FaceNodePadding], load_mappings(attrs["face_dimensions"])), node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")), ) @@ -186,7 +186,7 @@ def to_attrs(self) -> dict[str, str | int]: def rename(self, names_dict: dict[str, str]) -> Self: return cast(Self, _metadata_rename(self, names_dict)) - def get_value_by_id(self, id: str) -> str: + def get_value_by_id(self, id: str) -> Dim | Padding: """In the SGRID specification for 2D grids, different parts of the spec are identified by different "ID"s. Easily extract the value for a given ID. @@ -206,7 +206,7 @@ def __init__( cf_role: Literal["grid_topology"], topology_dimension: Literal[3], node_dimensions: tuple[Dim, Dim, Dim], - volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding], + volume_dimensions: tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding], node_coordinates: None | tuple[Dim, Dim, Dim] = None, ): if cf_role != "grid_topology": @@ -225,9 +225,9 @@ def __init__( if not ( isinstance(volume_dimensions, tuple) and len(volume_dimensions) == 3 - and all(isinstance(fd, DimDimPadding) for fd in volume_dimensions) + and all(isinstance(fd, FaceNodePadding) for fd in volume_dimensions) ): - raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid") + raise ValueError("face_dimensions must be a tuple of 2 FaceNodePadding for a 2D grid") if node_coordinates is not None: if not ( @@ -249,12 +249,12 @@ def __init__( # ! Some optional attributes aren't really important to Parcels, can be added later if needed # Optional attributes # # With defaults (set in init) - # edge1_dimensions: tuple[DimDimPadding, Dim, Dim] - # edge2_dimensions: tuple[Dim, DimDimPadding, Dim] - # edge3_dimensions: tuple[Dim, Dim, DimDimPadding] - # face1_dimensions: tuple[Dim, DimDimPadding, DimDimPadding] - # face2_dimensions: tuple[DimDimPadding, Dim, DimDimPadding] - # face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim] + # edge1_dimensions: tuple[FaceNodePadding, Dim, Dim] + # edge2_dimensions: tuple[Dim, FaceNodePadding, Dim] + # edge3_dimensions: tuple[Dim, Dim, FaceNodePadding] + # face1_dimensions: tuple[Dim, FaceNodePadding, FaceNodePadding] + # face2_dimensions: tuple[FaceNodePadding, Dim, FaceNodePadding] + # face3_dimensions: tuple[FaceNodePadding, FaceNodePadding, Dim] # # Without defaults # edge *i_coordinates* @@ -280,7 +280,7 @@ def from_attrs(cls, attrs): # type: ignore[override] topology_dimension=attrs["topology_dimension"], node_dimensions=cast(tuple[Dim, Dim, Dim], load_mappings(attrs["node_dimensions"])), volume_dimensions=cast( - tuple[DimDimPadding, DimDimPadding, DimDimPadding], load_mappings(attrs["volume_dimensions"]) + tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding], load_mappings(attrs["volume_dimensions"]) ), node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")), ) @@ -301,7 +301,7 @@ def to_attrs(self) -> dict[str, str | int]: def rename(self, dims_dict: dict[str, str]) -> Self: return cast(Self, _metadata_rename(self, dims_dict)) - def get_value_by_id(self, id: str) -> str: + def get_value_by_id(self, id: str) -> Dim | Padding: """In the SGRID specification for 3D grids, different parts of the spec are identified by different "ID"s. Easily extract the value for a given ID. @@ -316,39 +316,39 @@ def get_value_by_id(self, id: str) -> str: @dataclass -class DimDimPadding: - """A data class representing a dimension-dimension-padding triplet for SGrid metadata. +class FaceNodePadding: + """A data class representing a face-node-padding triplet for SGrid metadata. - This triplet can represent different relations depending on context within the standard - For example - for "face_dimensions" this can show the relation between an edge (dim1) and a node - (dim2). + In the context of a 2D grid, "face" corresponds with an edge. + + Use the .to_diagram() method to visualize the representation. """ - dim1: str - dim2: str + face: Dim + node: Dim padding: Padding def __repr__(self) -> str: - return f"DimDimPadding(dim1={self.dim1!r}, dim2={self.dim2!r}, padding={self.padding!r})" + return f"FaceNodePadding(face={self.face!r}, node={self.node!r}, padding={self.padding!r})" def __str__(self) -> str: - return f"{self.dim1}:{self.dim2} (padding:{self.padding.value})" + return f"{self.face}:{self.node} (padding:{self.padding.value})" @classmethod def load(cls, s: str) -> Self: - match = re.match(RE_DIM_DIM_PADDING, s) + match = re.match(RE_FACE_NODE_PADDING, s) if not match: - raise ValueError(f"String {s!r} does not match expected format for DimDimPadding") - dim1 = match.group(1) - dim2 = match.group(2) + raise ValueError(f"String {s!r} does not match expected format for FaceNodePadding") + face = match.group(1) + node = match.group(2) padding = Padding(match.group(3).lower()) - return cls(dim1, dim2, padding) + return cls(face, node, padding) def to_diagram(self) -> str: return "\n".join(_face_node_padding_to_text(self)) -def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: +def dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: """Takes in a list of edge-node-padding tuples and serializes them into a string according to the SGrid convention. """ @@ -361,7 +361,7 @@ def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: @overload def maybe_dump_mappings(parts: None) -> None: ... @overload -def maybe_dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: ... +def maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ... def maybe_dump_mappings(parts): @@ -370,7 +370,7 @@ def maybe_dump_mappings(parts): return dump_mappings(parts) -def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]: +def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]: """Takes in a string indicating the mappings of dims and dim-dim-padding and returns a tuple with this data destructured. @@ -383,25 +383,25 @@ def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]: ret = [] while s: # find next part - match = re.match(RE_DIM_DIM_PADDING, s) + match = re.match(RE_FACE_NODE_PADDING, s) if match and match.start() == 0: # match found at start, take that as next part part = match.group(0) s_new = s[match.end() :].lstrip() else: - # no DimDimPadding match at start, assume just a Dim until next space + # no FaceNodePadding match at start, assume just a Dim until next space part, *s_new = s.split(" ", 1) s_new = "".join(s_new) assert s != s_new, f"SGrid parsing did not advance, stuck at {s!r}" - parsed: DimDimPadding | Dim + parsed: FaceNodePadding | Dim try: - parsed = DimDimPadding.load(part) + parsed = FaceNodePadding.load(part) except ValueError as e: e.add_note(f"Failed to parse part {part!r} from {s!r} as a dimension dimension padding string") try: - # Not a DimDimPadding, assume it's just a Dim + # Not a FaceNodePadding, assume it's just a Dim assert ":" not in part, f"Part {part!r} from {s!r} not a valid dim (contains ':')" parsed = part except AssertionError as e2: @@ -416,7 +416,7 @@ def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]: @overload def maybe_load_mappings(s: None) -> None: ... @overload -def maybe_load_mappings(s: Hashable) -> tuple[DimDimPadding | Dim, ...]: ... +def maybe_load_mappings(s: Hashable) -> tuple[FaceNodePadding | Dim, ...]: ... def maybe_load_mappings(s): @@ -471,11 +471,11 @@ def parse_sgrid(ds: xr.Dataset): dimensions = grid.volume_dimensions xgcm_coords = {} - for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False): - xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding] + for face_node_padding, axis in zip(dimensions, "XYZ", strict=False): + xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[face_node_padding.padding] coords = {} - for pos, dim in [("center", dim_dim_padding.dim1), (xgcm_position, dim_dim_padding.dim2)]: + for pos, dim in [("center", face_node_padding.face), (xgcm_position, face_node_padding.node)]: # only include dimensions in dataset (ignore dimensions in metadata that may not exist - e.g., due to `.isel`) if dim in ds.dims: coords[pos] = dim @@ -510,16 +510,16 @@ def get_unique_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]: f"Expected sgrid metadata attribute to be represented as a tuple, got {value!r}. This is an internal error to Parcels - please post an issue if you encounter this." ) for item in value: - if isinstance(item, DimDimPadding): - dims.add(item.dim1) - dims.add(item.dim2) + if isinstance(item, FaceNodePadding): + dims.add(item.face) + dims.add(item.node) else: assert isinstance(item, str) dims.add(item) return dims -def _face_node_padding_to_text(obj: DimDimPadding) -> list[str]: +def _face_node_padding_to_text(obj: FaceNodePadding) -> list[str]: """Return ASCII diagram lines showing a face-node padding relationship. Produces a symbolic 5-node diagram like the image below, matching the @@ -566,7 +566,7 @@ def _face_node_padding_to_text(obj: DimDimPadding) -> list[str]: face_count += 1 return [ - f"{obj.dim1}:{obj.dim2} (padding:{padding.value})", + f"{obj.face}:{obj.node} (padding:{padding.value})", f" {bar_rendered}", f" {label.rstrip()}", ] @@ -640,19 +640,19 @@ def _grid2d_to_ascii(grid: Grid2DMetadata) -> str: nd = grid.node_dimensions lines = [ "Grid2DMetadata", - f" X-axis: face={fd[0].dim1!r} node={nd[0]!r} padding={fd[0].padding.value}", - f" Y-axis: face={fd[1].dim1!r} node={nd[1]!r} padding={fd[1].padding.value}", + f" X-axis: face={fd[0].face!r} node={nd[0]!r} padding={fd[0].padding.value}", + f" Y-axis: face={fd[1].face!r} node={nd[1]!r} padding={fd[1].padding.value}", ] if grid.vertical_dimensions: vd = grid.vertical_dimensions[0] - lines.append(f" Z-axis: face={vd.dim1!r} node={vd.dim2!r} padding={vd.padding.value}") + lines.append(f" Z-axis: face={vd.face!r} node={vd.node!r} padding={vd.padding.value}") if grid.node_coordinates: lines.append(f" Coordinates: {grid.node_coordinates[0]}, {grid.node_coordinates[1]}") - format_kwargs = dict(n1=nd[0], n2=nd[1], u=fd[0].dim1, v=fd[1].dim1) + format_kwargs = dict(n1=nd[0], n2=nd[1], u=fd[0].face, v=fd[1].face) if grid.vertical_dimensions: - format_kwargs["w"] = grid.vertical_dimensions[0].dim2 + format_kwargs["w"] = grid.vertical_dimensions[0].node lines += indent(_TEXT_GRID2D_WITH_Z, " ").format(**format_kwargs).split("\n") else: lines += indent(_TEXT_GRID2D_WITHOUT_Z, " ").format(**format_kwargs).split("\n") @@ -672,16 +672,16 @@ def _grid3d_to_ascii(grid: Grid3DMetadata) -> str: nd = grid.node_dimensions lines = [ "Grid3DMetadata", - f" X-axis: face={vd[0].dim1!r} node={nd[0]!r} padding={vd[0].padding.value}", - f" Y-axis: face={vd[1].dim1!r} node={nd[1]!r} padding={vd[1].padding.value}", - f" Z-axis: face={vd[2].dim1!r} node={nd[2]!r} padding={vd[2].padding.value}", + f" X-axis: face={vd[0].face!r} node={nd[0]!r} padding={vd[0].padding.value}", + f" Y-axis: face={vd[1].face!r} node={nd[1]!r} padding={vd[1].padding.value}", + f" Z-axis: face={vd[2].face!r} node={nd[2]!r} padding={vd[2].padding.value}", ] if grid.node_coordinates: lines.append(f" Coordinates: {', '.join(grid.node_coordinates)}") lines += ( indent(_TEXT_GRID3D, " ") - .format(n1=nd[0], n2=nd[1], n3=nd[2], u=vd[0].dim1, v=vd[1].dim1, w=vd[2].dim1) + .format(n1=nd[0], n2=nd[1], n3=nd[2], u=vd[0].face, v=vd[1].face, w=vd[2].face) .split("\n") ) @@ -738,10 +738,10 @@ def _metadata_rename(grid, names_dict): if isinstance(value, tuple): new_value = [] for item in value: - if isinstance(item, DimDimPadding): - new_item = DimDimPadding( - dim1=names_dict[item.dim1], - dim2=names_dict[item.dim2], + if isinstance(item, FaceNodePadding): + new_item = FaceNodePadding( + face=names_dict[item.face], + node=names_dict[item.node], padding=item.padding, ) new_value.append(new_item) diff --git a/src/parcels/_datasets/structured/generated.py b/src/parcels/_datasets/structured/generated.py index 1ff9e0ea00..8a4eaed988 100644 --- a/src/parcels/_datasets/structured/generated.py +++ b/src/parcels/_datasets/structured/generated.py @@ -4,7 +4,7 @@ import xarray as xr from parcels._core.utils.sgrid import ( - DimDimPadding, + FaceNodePadding, Grid2DMetadata, Padding, _attach_sgrid_metadata, @@ -35,10 +35,10 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"): topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - DimDimPadding("XC", "XG", Padding.LOW), - DimDimPadding("YC", "YG", Padding.LOW), + FaceNodePadding("XC", "XG", Padding.LOW), + FaceNodePadding("YC", "YG", Padding.LOW), ), - vertical_dimensions=(DimDimPadding("ZC", "depth", Padding.BOTH),), + vertical_dimensions=(FaceNodePadding("ZC", "depth", Padding.BOTH),), ), ) diff --git a/src/parcels/_datasets/structured/generic.py b/src/parcels/_datasets/structured/generic.py index 008f785a05..3725e544bd 100644 --- a/src/parcels/_datasets/structured/generic.py +++ b/src/parcels/_datasets/structured/generic.py @@ -2,7 +2,7 @@ import xarray as xr from parcels._core.utils.sgrid import ( - DimDimPadding, + FaceNodePadding, Grid2DMetadata, Padding, _attach_sgrid_metadata, @@ -254,11 +254,11 @@ def _unrolled_cone_curvilinear_grid(): topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - DimDimPadding("XC", "XG", Padding.HIGH), - DimDimPadding("YC", "YG", Padding.HIGH), + FaceNodePadding("XC", "XG", Padding.HIGH), + FaceNodePadding("YC", "YG", Padding.HIGH), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.HIGH),), + vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.HIGH),), ), ) .pipe( @@ -275,11 +275,11 @@ def _unrolled_cone_curvilinear_grid(): topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - DimDimPadding("XC", "XG", Padding.LOW), - DimDimPadding("YC", "YG", Padding.LOW), + FaceNodePadding("XC", "XG", Padding.LOW), + FaceNodePadding("YC", "YG", Padding.LOW), ), node_coordinates=("lon", "lat"), - vertical_dimensions=(DimDimPadding("ZC", "ZG", Padding.LOW),), + vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.LOW),), ), ) .pipe( diff --git a/src/parcels/convert.py b/src/parcels/convert.py index 44ccb6200f..32f81d9009 100644 --- a/src/parcels/convert.py +++ b/src/parcels/convert.py @@ -329,10 +329,10 @@ def nemo_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.Da node_dimensions=("x", "y"), node_coordinates=("glamf", "gphif"), face_dimensions=( - sgrid.DimDimPadding("x_center", "x", sgrid.Padding.LOW), - sgrid.DimDimPadding("y_center", "y", sgrid.Padding.LOW), + sgrid.FaceNodePadding("x_center", "x", sgrid.Padding.LOW), + sgrid.FaceNodePadding("y_center", "y", sgrid.Padding.LOW), ), - vertical_dimensions=(sgrid.DimDimPadding("z_center", "depth", sgrid.Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("z_center", "depth", sgrid.Padding.HIGH),), ).to_attrs(), ) @@ -397,10 +397,10 @@ def mitgcm_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr. node_dimensions=("lon", "lat"), node_coordinates=("lon", "lat"), face_dimensions=( - sgrid.DimDimPadding("XC", "lon", sgrid.Padding.HIGH), - sgrid.DimDimPadding("YC", "lat", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("XC", "lon", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "lat", sgrid.Padding.HIGH), ), - vertical_dimensions=(sgrid.DimDimPadding("depth", "depth", sgrid.Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("depth", "depth", sgrid.Padding.HIGH),), ).to_attrs(), ) @@ -458,10 +458,10 @@ def croco_to_sgrid(*, fields: dict[str, xr.Dataset | xr.DataArray], coords: xr.D node_dimensions=("lon", "lat"), node_coordinates=("lon", "lat"), face_dimensions=( - sgrid.DimDimPadding("xi_u", "xi_rho", sgrid.Padding.HIGH), - sgrid.DimDimPadding("eta_v", "eta_rho", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("xi_u", "xi_rho", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("eta_v", "eta_rho", sgrid.Padding.HIGH), ), - vertical_dimensions=(sgrid.DimDimPadding("s_rho", "depth", sgrid.Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("s_rho", "depth", sgrid.Padding.HIGH),), ).to_attrs(), ) @@ -522,10 +522,10 @@ def copernicusmarine_to_sgrid( node_dimensions=("lon", "lat"), node_coordinates=("lon", "lat"), face_dimensions=( - sgrid.DimDimPadding("x_center", "lon", sgrid.Padding.LOW), - sgrid.DimDimPadding("y_center", "lat", sgrid.Padding.LOW), + sgrid.FaceNodePadding("x_center", "lon", sgrid.Padding.LOW), + sgrid.FaceNodePadding("y_center", "lat", sgrid.Padding.LOW), ), - vertical_dimensions=(sgrid.DimDimPadding("z_center", "depth", sgrid.Padding.LOW),), + vertical_dimensions=(sgrid.FaceNodePadding("z_center", "depth", sgrid.Padding.LOW),), ).to_attrs(), ) diff --git a/tests/strategies/sgrid.py b/tests/strategies/sgrid.py index ccfc100a02..a613d4e31d 100644 --- a/tests/strategies/sgrid.py +++ b/tests/strategies/sgrid.py @@ -20,7 +20,7 @@ dim_dim_padding = ( st.tuples(dimension_name, dimension_name, padding) .filter(lambda t: t[0] != t[1]) - .map(lambda t: sgrid.DimDimPadding(*t)) + .map(lambda t: sgrid.FaceNodePadding(*t)) ) mappings = st.lists(dim_dim_padding | dimension_name).map(tuple) @@ -45,8 +45,8 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: node_coordinates_var2 = names[5] has_node_coordinates = draw(st.booleans()) - vertical_dimensions_dim1 = names[6] - vertical_dimensions_dim2 = names[7] + vertical_dimensions_face = names[6] + vertical_dimensions_node = names[7] vertical_dimensions_padding = draw(padding) has_vertical_dimensions = draw(st.booleans()) @@ -57,7 +57,7 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: if has_vertical_dimensions: vertical_dimensions = ( - sgrid.DimDimPadding(vertical_dimensions_dim1, vertical_dimensions_dim2, vertical_dimensions_padding), + sgrid.FaceNodePadding(vertical_dimensions_face, vertical_dimensions_node, vertical_dimensions_padding), ) else: vertical_dimensions = None @@ -67,8 +67,8 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata: topology_dimension=2, node_dimensions=(node_dimension1, node_dimension2), face_dimensions=( - sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1), - sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), + sgrid.FaceNodePadding(face_dimension1, node_dimension1, padding_type1), + sgrid.FaceNodePadding(face_dimension2, node_dimension2, padding_type2), ), node_coordinates=node_coordinates, vertical_dimensions=vertical_dimensions, @@ -108,9 +108,9 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata: topology_dimension=3, node_dimensions=(node_dimension1, node_dimension2, node_dimension3), volume_dimensions=( - sgrid.DimDimPadding(face_dimension1, node_dimension1, padding_type1), - sgrid.DimDimPadding(face_dimension2, node_dimension2, padding_type2), - sgrid.DimDimPadding(face_dimension3, node_dimension3, padding_type3), + sgrid.FaceNodePadding(face_dimension1, node_dimension1, padding_type1), + sgrid.FaceNodePadding(face_dimension2, node_dimension2, padding_type2), + sgrid.FaceNodePadding(face_dimension3, node_dimension3, padding_type3), ), node_coordinates=node_coordinates, ) diff --git a/tests/utils/test_sgrid.py b/tests/utils/test_sgrid.py index 43588a987c..78339948fa 100644 --- a/tests/utils/test_sgrid.py +++ b/tests/utils/test_sgrid.py @@ -12,7 +12,7 @@ def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coordinates: bool): vertical_dimensions = ( - (sgrid.DimDimPadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),) + (sgrid.FaceNodePadding("vertical_dimensions_dim1", "vertical_dimensions_dim2", sgrid.Padding.LOW),) if with_vertical_dimensions else None ) @@ -23,8 +23,8 @@ def create_example_grid2dmetadata(with_vertical_dimensions: bool, with_node_coor topology_dimension=2, node_dimensions=("node_dimension1", "node_dimension2"), face_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), ), node_coordinates=node_coordinates, vertical_dimensions=vertical_dimensions, @@ -40,9 +40,9 @@ def create_example_grid3dmetadata(with_node_coordinates: bool): topology_dimension=3, node_dimensions=("node_dimension1", "node_dimension2", "node_dimension3"), volume_dimensions=( - sgrid.DimDimPadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), - sgrid.DimDimPadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face_dimension3", "node_dimension3", sgrid.Padding.LOW), ), node_coordinates=node_coordinates, ) @@ -95,13 +95,13 @@ def dummy_sgrid_2d_ds(grid: sgrid.Grid2DMetadata) -> xr.Dataset: if grid.vertical_dimensions is None: ds = ds.isel(ZC=0, ZG=0) else: - renamings.update({"ZC": grid.vertical_dimensions[0].dim2, "ZG": grid.vertical_dimensions[0].dim1}) + renamings.update({"ZC": grid.vertical_dimensions[0].face, "ZG": grid.vertical_dimensions[0].node}) for old, new in zip(["XG", "YG"], grid.node_dimensions, strict=True): renamings[old] = new - for old, dim_dim_padding in zip(["XC", "YC"], grid.face_dimensions, strict=True): - renamings[old] = dim_dim_padding.dim1 + for old, face_node_padding in zip(["XC", "YC"], grid.face_dimensions, strict=True): + renamings[old] = face_node_padding.face ds = ds.rename_dims(renamings) @@ -120,8 +120,8 @@ def dummy_sgrid_3d_ds(grid: sgrid.Grid3DMetadata) -> xr.Dataset: for old, new in zip(["XG", "YG", "ZG"], grid.node_dimensions, strict=True): renamings[old] = new - for old, dim_dim_padding in zip(["XC", "YC", "ZC"], grid.volume_dimensions, strict=True): - renamings[old] = dim_dim_padding.dim1 + for old, face_node_padding in zip(["XC", "YC", "ZC"], grid.volume_dimensions, strict=True): + renamings[old] = face_node_padding.face ds = ds.rename_dims(renamings) @@ -179,8 +179,8 @@ def dummy_comodo_3d_ds() -> xr.Dataset: @example( edge_node_padding=( - sgrid.DimDimPadding("edge1", "node1", sgrid.Padding.NONE), - sgrid.DimDimPadding("edge2", "node2", sgrid.Padding.LOW), + sgrid.FaceNodePadding("edge1", "node1", sgrid.Padding.NONE), + sgrid.FaceNodePadding("edge2", "node2", sgrid.Padding.LOW), ) ) @given(sgrid_strategies.mappings) @@ -195,7 +195,7 @@ def test_edge_node_mapping_metadata_roundtrip(edge_node_padding): [ ( "edge1: node1(padding: none)", - (sgrid.DimDimPadding("edge1", "node1", sgrid.Padding.NONE),), + (sgrid.FaceNodePadding("edge1", "node1", sgrid.Padding.NONE),), ), ], ) @@ -235,20 +235,18 @@ def test_parse_sgrid_2d(grid_metadata: sgrid.Grid2DMetadata): _, xgcm_kwargs = sgrid.parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) - for ddp, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): - dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding + for obj, axis in zip(grid_metadata.face_dimensions, ["X", "Y"], strict=True): coords = grid.axes[axis].coords - assert coords["center"] == dim_edge - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + assert coords["center"] == obj.face + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node if grid_metadata.vertical_dimensions is None: assert "Z" not in grid.axes else: - ddp = grid_metadata.vertical_dimensions[0] - dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding + obj = grid_metadata.vertical_dimensions[0] coords = grid.axes["Z"].coords - assert coords["center"] == dim_edge - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + assert coords["center"] == obj.face + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @given(sgrid_strategies.grid3Dmetadata()) @@ -259,11 +257,10 @@ def test_parse_sgrid_3d(grid_metadata: sgrid.Grid3DMetadata): ds, xgcm_kwargs = sgrid.parse_sgrid(ds) grid = xgcm.Grid(ds, autoparse_metadata=False, **xgcm_kwargs) - for ddp, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): - dim_edge, dim_node, padding = ddp.dim1, ddp.dim2, ddp.padding + for obj, axis in zip(grid_metadata.volume_dimensions, ["X", "Y", "Z"], strict=True): coords = grid.axes[axis].coords - assert coords["center"] == dim_edge - assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[padding]] == dim_node + assert coords["center"] == obj.face + assert coords[sgrid.SGRID_PADDING_TO_XGCM_POSITION[obj.padding]] == obj.node @pytest.mark.parametrize( @@ -318,10 +315,10 @@ def test_rename_errors(): topology_dimension=2, node_dimensions=("XG", "YG"), face_dimensions=( - sgrid.DimDimPadding("XC", "XG", sgrid.Padding.HIGH), - sgrid.DimDimPadding("YC", "YG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH), ), - vertical_dimensions=(sgrid.DimDimPadding("ZC", "ZG", sgrid.Padding.HIGH),), + vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),), node_coordinates=("lon", "lat"), ).to_attrs(), ), @@ -347,7 +344,7 @@ def test_rename_dataset(ds): grid_new = sgrid.parse_grid_attrs(ds_new["grid"].attrs) assert "XC_updated" in ds_new.dims assert "XC" not in ds_new.dims - assert "XC_updated" == grid_new.face_dimensions[0].dim1 + assert "XC_updated" == grid_new.face_dimensions[0].face @pytest.mark.parametrize( @@ -524,7 +521,7 @@ def test_grid_str(metadata, expected): ("face_node_padding", "expected_lines"), [ ( - sgrid.DimDimPadding("face", "node", sgrid.Padding.LOW), + sgrid.FaceNodePadding("face", "node", sgrid.Padding.LOW), [ "face:node (padding:low)", " ─────●─────●─────●─────●─────●", @@ -532,7 +529,7 @@ def test_grid_str(metadata, expected): ], ), ( - sgrid.DimDimPadding("face", "node", sgrid.Padding.HIGH), + sgrid.FaceNodePadding("face", "node", sgrid.Padding.HIGH), [ "face:node (padding:high)", " ●─────●─────●─────●─────●─────", @@ -540,7 +537,7 @@ def test_grid_str(metadata, expected): ], ), ( - sgrid.DimDimPadding("face", "node", sgrid.Padding.BOTH), + sgrid.FaceNodePadding("face", "node", sgrid.Padding.BOTH), [ "face:node (padding:both)", " ─────●─────●─────●─────●─────●─────", @@ -548,7 +545,7 @@ def test_grid_str(metadata, expected): ], ), ( - sgrid.DimDimPadding("face", "node", sgrid.Padding.NONE), + sgrid.FaceNodePadding("face", "node", sgrid.Padding.NONE), [ "face:node (padding:none)", " ●─────●─────●─────●─────●", @@ -557,7 +554,7 @@ def test_grid_str(metadata, expected): ), ], ) -def test_face_node_padding_to_diagram(face_node_padding: sgrid.DimDimPadding, expected_lines: list[str]): +def test_face_node_padding_to_diagram(face_node_padding: sgrid.FaceNodePadding, expected_lines: list[str]): actual = face_node_padding.to_diagram() lines = actual.split("\n") assert lines == expected_lines