Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 74 additions & 74 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import enum
import re
from collections.abc import Hashable, Iterable
from collections.abc import Callable, Hashable, Iterable
from dataclasses import dataclass
from textwrap import indent
from typing import Any, Literal, Protocol, Self, cast, overload
Expand All @@ -22,7 +22,7 @@

from parcels._python import repr_from_dunder_dict

RE_DIM_DIM_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"
RE_FACE_NODE_PADDING = r"(\w+):(\w+)\s*\(padding:\s*(\w+)\)"

Dim = str

Expand Down Expand Up @@ -56,22 +56,22 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...

# Note that - for some optional attributes in the SGRID spec - these IDs are not available
# hence this isn't full coverage
_ID_FETCHERS_GRID2DMETADATA = {
_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[Grid2DMetadata], Dim | Padding]] = {
"node_dimension1": lambda meta: meta.node_dimensions[0],
"node_dimension2": lambda meta: meta.node_dimensions[1],
"face_dimension1": lambda meta: meta.face_dimensions[0].dim1,
"face_dimension2": lambda meta: meta.face_dimensions[1].dim1,
"face_dimension1": lambda meta: meta.face_dimensions[0].face,
"face_dimension2": lambda meta: meta.face_dimensions[1].face,
"type1": lambda meta: meta.face_dimensions[0].padding,
"type2": lambda meta: meta.face_dimensions[1].padding,
}

_ID_FETCHERS_GRID3DMETADATA = {
_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[Grid3DMetadata], Dim | Padding]] = {
"node_dimension1": lambda meta: meta.node_dimensions[0],
"node_dimension2": lambda meta: meta.node_dimensions[1],
"node_dimension3": lambda meta: meta.node_dimensions[2],
"face_dimension1": lambda meta: meta.volume_dimensions[0].dim1,
"face_dimension2": lambda meta: meta.volume_dimensions[1].dim1,
"face_dimension3": lambda meta: meta.volume_dimensions[2].dim1,
"face_dimension1": lambda meta: meta.volume_dimensions[0].face,
"face_dimension2": lambda meta: meta.volume_dimensions[1].face,
"face_dimension3": lambda meta: meta.volume_dimensions[2].face,
"type1": lambda meta: meta.volume_dimensions[0].padding,
"type2": lambda meta: meta.volume_dimensions[1].padding,
"type3": lambda meta: meta.volume_dimensions[2].padding,
Expand All @@ -84,9 +84,9 @@ def __init__(
cf_role: Literal["grid_topology"],
topology_dimension: Literal[2],
node_dimensions: tuple[Dim, Dim],
face_dimensions: tuple[DimDimPadding, DimDimPadding],
face_dimensions: tuple[FaceNodePadding, FaceNodePadding],
node_coordinates: None | tuple[Dim, Dim] = None,
vertical_dimensions: None | tuple[DimDimPadding] = None,
vertical_dimensions: None | tuple[FaceNodePadding] = None,
):
if cf_role != "grid_topology":
raise ValueError(f"cf_role must be 'grid_topology', got {cf_role!r}")
Expand All @@ -104,9 +104,9 @@ def __init__(
if not (
isinstance(face_dimensions, tuple)
and len(face_dimensions) == 2
and all(isinstance(fd, DimDimPadding) for fd in face_dimensions)
and all(isinstance(fd, FaceNodePadding) for fd in face_dimensions)
):
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")
raise ValueError("face_dimensions must be a tuple of 2 FaceNodePadding for a 2D grid")

if node_coordinates is not None:
if not (
Expand All @@ -120,9 +120,9 @@ def __init__(
if not (
isinstance(vertical_dimensions, tuple)
and len(vertical_dimensions) == 1
and isinstance(vertical_dimensions[0], DimDimPadding)
and isinstance(vertical_dimensions[0], FaceNodePadding)
):
raise ValueError("vertical_dimensions must be a tuple of 1 DimDimPadding for a 2D grid")
raise ValueError("vertical_dimensions must be a tuple of 1 FaceNodePadding for a 2D grid")

# Required attributes
self.cf_role = cf_role
Expand All @@ -137,8 +137,8 @@ def __init__(
#! Some optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
# # With defaults (set in init)
# edge1_dimensions: tuple[Dim, DimDimPadding]
# edge2_dimensions: tuple[DimDimPadding, Dim]
# edge1_dimensions: tuple[Dim, FaceNodePadding]
# edge2_dimensions: tuple[FaceNodePadding, Dim]

# # Without defaults
# edge1_coordinates: None | Any = None
Expand All @@ -163,7 +163,7 @@ def from_attrs(cls, attrs): # type: ignore[override]
cf_role=attrs["cf_role"],
topology_dimension=attrs["topology_dimension"],
node_dimensions=cast(tuple[Dim, Dim], load_mappings(attrs["node_dimensions"])),
face_dimensions=cast(tuple[DimDimPadding, DimDimPadding], load_mappings(attrs["face_dimensions"])),
face_dimensions=cast(tuple[FaceNodePadding, FaceNodePadding], load_mappings(attrs["face_dimensions"])),
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
vertical_dimensions=maybe_load_mappings(attrs.get("vertical_dimensions")),
)
Expand All @@ -186,7 +186,7 @@ def to_attrs(self) -> dict[str, str | int]:
def rename(self, names_dict: dict[str, str]) -> Self:
return cast(Self, _metadata_rename(self, names_dict))

def get_value_by_id(self, id: str) -> str:
def get_value_by_id(self, id: str) -> Dim | Padding:
"""In the SGRID specification for 2D grids, different parts of the spec are identified by different "ID"s.

Easily extract the value for a given ID.
Expand All @@ -206,7 +206,7 @@ def __init__(
cf_role: Literal["grid_topology"],
topology_dimension: Literal[3],
node_dimensions: tuple[Dim, Dim, Dim],
volume_dimensions: tuple[DimDimPadding, DimDimPadding, DimDimPadding],
volume_dimensions: tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding],
node_coordinates: None | tuple[Dim, Dim, Dim] = None,
):
if cf_role != "grid_topology":
Expand All @@ -225,9 +225,9 @@ def __init__(
if not (
isinstance(volume_dimensions, tuple)
and len(volume_dimensions) == 3
and all(isinstance(fd, DimDimPadding) for fd in volume_dimensions)
and all(isinstance(fd, FaceNodePadding) for fd in volume_dimensions)
):
raise ValueError("face_dimensions must be a tuple of 2 DimDimPadding for a 2D grid")
raise ValueError("face_dimensions must be a tuple of 2 FaceNodePadding for a 2D grid")

if node_coordinates is not None:
if not (
Expand All @@ -249,12 +249,12 @@ def __init__(
# ! Some optional attributes aren't really important to Parcels, can be added later if needed
# Optional attributes
# # With defaults (set in init)
# edge1_dimensions: tuple[DimDimPadding, Dim, Dim]
# edge2_dimensions: tuple[Dim, DimDimPadding, Dim]
# edge3_dimensions: tuple[Dim, Dim, DimDimPadding]
# face1_dimensions: tuple[Dim, DimDimPadding, DimDimPadding]
# face2_dimensions: tuple[DimDimPadding, Dim, DimDimPadding]
# face3_dimensions: tuple[DimDimPadding, DimDimPadding, Dim]
# edge1_dimensions: tuple[FaceNodePadding, Dim, Dim]
# edge2_dimensions: tuple[Dim, FaceNodePadding, Dim]
# edge3_dimensions: tuple[Dim, Dim, FaceNodePadding]
# face1_dimensions: tuple[Dim, FaceNodePadding, FaceNodePadding]
# face2_dimensions: tuple[FaceNodePadding, Dim, FaceNodePadding]
# face3_dimensions: tuple[FaceNodePadding, FaceNodePadding, Dim]

# # Without defaults
# edge *i_coordinates*
Expand All @@ -280,7 +280,7 @@ def from_attrs(cls, attrs): # type: ignore[override]
topology_dimension=attrs["topology_dimension"],
node_dimensions=cast(tuple[Dim, Dim, Dim], load_mappings(attrs["node_dimensions"])),
volume_dimensions=cast(
tuple[DimDimPadding, DimDimPadding, DimDimPadding], load_mappings(attrs["volume_dimensions"])
tuple[FaceNodePadding, FaceNodePadding, FaceNodePadding], load_mappings(attrs["volume_dimensions"])
),
node_coordinates=maybe_load_mappings(attrs.get("node_coordinates")),
)
Expand All @@ -301,7 +301,7 @@ def to_attrs(self) -> dict[str, str | int]:
def rename(self, dims_dict: dict[str, str]) -> Self:
return cast(Self, _metadata_rename(self, dims_dict))

def get_value_by_id(self, id: str) -> str:
def get_value_by_id(self, id: str) -> Dim | Padding:
"""In the SGRID specification for 3D grids, different parts of the spec are identified by different "ID"s.

Easily extract the value for a given ID.
Expand All @@ -316,39 +316,39 @@ def get_value_by_id(self, id: str) -> str:


@dataclass
class DimDimPadding:
"""A data class representing a dimension-dimension-padding triplet for SGrid metadata.
class FaceNodePadding:
"""A data class representing a face-node-padding triplet for SGrid metadata.

This triplet can represent different relations depending on context within the standard
For example - for "face_dimensions" this can show the relation between an edge (dim1) and a node
(dim2).
In the context of a 2D grid, "face" corresponds with an edge.

Use the .to_diagram() method to visualize the representation.
"""

dim1: str
dim2: str
face: Dim
node: Dim
padding: Padding
Comment on lines 318 to 329
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test suite for the sgrid parsing etc. is very thorough, so the main thing to review is this block.


def __repr__(self) -> str:
return f"DimDimPadding(dim1={self.dim1!r}, dim2={self.dim2!r}, padding={self.padding!r})"
return f"FaceNodePadding(face={self.face!r}, node={self.node!r}, padding={self.padding!r})"

def __str__(self) -> str:
return f"{self.dim1}:{self.dim2} (padding:{self.padding.value})"
return f"{self.face}:{self.node} (padding:{self.padding.value})"

@classmethod
def load(cls, s: str) -> Self:
match = re.match(RE_DIM_DIM_PADDING, s)
match = re.match(RE_FACE_NODE_PADDING, s)
if not match:
raise ValueError(f"String {s!r} does not match expected format for DimDimPadding")
dim1 = match.group(1)
dim2 = match.group(2)
raise ValueError(f"String {s!r} does not match expected format for FaceNodePadding")
face = match.group(1)
node = match.group(2)
padding = Padding(match.group(3).lower())
return cls(dim1, dim2, padding)
return cls(face, node, padding)

def to_diagram(self) -> str:
return "\n".join(_face_node_padding_to_text(self))


def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str:
def dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str:
"""Takes in a list of edge-node-padding tuples and serializes them into a string
according to the SGrid convention.
"""
Expand All @@ -361,7 +361,7 @@ def dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str:
@overload
def maybe_dump_mappings(parts: None) -> None: ...
@overload
def maybe_dump_mappings(parts: Iterable[DimDimPadding | Dim]) -> str: ...
def maybe_dump_mappings(parts: Iterable[FaceNodePadding | Dim]) -> str: ...


def maybe_dump_mappings(parts):
Expand All @@ -370,7 +370,7 @@ def maybe_dump_mappings(parts):
return dump_mappings(parts)


def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]:
def load_mappings(s: str) -> tuple[FaceNodePadding | Dim, ...]:
"""Takes in a string indicating the mappings of dims and dim-dim-padding
and returns a tuple with this data destructured.

Expand All @@ -383,25 +383,25 @@ def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]:
ret = []
while s:
# find next part
match = re.match(RE_DIM_DIM_PADDING, s)
match = re.match(RE_FACE_NODE_PADDING, s)
if match and match.start() == 0:
# match found at start, take that as next part
part = match.group(0)
s_new = s[match.end() :].lstrip()
else:
# no DimDimPadding match at start, assume just a Dim until next space
# no FaceNodePadding match at start, assume just a Dim until next space
part, *s_new = s.split(" ", 1)
s_new = "".join(s_new)

assert s != s_new, f"SGrid parsing did not advance, stuck at {s!r}"

parsed: DimDimPadding | Dim
parsed: FaceNodePadding | Dim
try:
parsed = DimDimPadding.load(part)
parsed = FaceNodePadding.load(part)
except ValueError as e:
e.add_note(f"Failed to parse part {part!r} from {s!r} as a dimension dimension padding string")
try:
# Not a DimDimPadding, assume it's just a Dim
# Not a FaceNodePadding, assume it's just a Dim
assert ":" not in part, f"Part {part!r} from {s!r} not a valid dim (contains ':')"
parsed = part
except AssertionError as e2:
Expand All @@ -416,7 +416,7 @@ def load_mappings(s: str) -> tuple[DimDimPadding | Dim, ...]:
@overload
def maybe_load_mappings(s: None) -> None: ...
@overload
def maybe_load_mappings(s: Hashable) -> tuple[DimDimPadding | Dim, ...]: ...
def maybe_load_mappings(s: Hashable) -> tuple[FaceNodePadding | Dim, ...]: ...


def maybe_load_mappings(s):
Expand Down Expand Up @@ -471,11 +471,11 @@ def parse_sgrid(ds: xr.Dataset):
dimensions = grid.volume_dimensions

xgcm_coords = {}
for dim_dim_padding, axis in zip(dimensions, "XYZ", strict=False):
xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[dim_dim_padding.padding]
for face_node_padding, axis in zip(dimensions, "XYZ", strict=False):
xgcm_position = SGRID_PADDING_TO_XGCM_POSITION[face_node_padding.padding]

coords = {}
for pos, dim in [("center", dim_dim_padding.dim1), (xgcm_position, dim_dim_padding.dim2)]:
for pos, dim in [("center", face_node_padding.face), (xgcm_position, face_node_padding.node)]:
# only include dimensions in dataset (ignore dimensions in metadata that may not exist - e.g., due to `.isel`)
if dim in ds.dims:
coords[pos] = dim
Expand Down Expand Up @@ -510,16 +510,16 @@ def get_unique_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
f"Expected sgrid metadata attribute to be represented as a tuple, got {value!r}. This is an internal error to Parcels - please post an issue if you encounter this."
)
for item in value:
if isinstance(item, DimDimPadding):
dims.add(item.dim1)
dims.add(item.dim2)
if isinstance(item, FaceNodePadding):
dims.add(item.face)
dims.add(item.node)
else:
assert isinstance(item, str)
dims.add(item)
return dims


def _face_node_padding_to_text(obj: DimDimPadding) -> list[str]:
def _face_node_padding_to_text(obj: FaceNodePadding) -> list[str]:
"""Return ASCII diagram lines showing a face-node padding relationship.

Produces a symbolic 5-node diagram like the image below, matching the
Expand Down Expand Up @@ -566,7 +566,7 @@ def _face_node_padding_to_text(obj: DimDimPadding) -> list[str]:
face_count += 1

return [
f"{obj.dim1}:{obj.dim2} (padding:{padding.value})",
f"{obj.face}:{obj.node} (padding:{padding.value})",
f" {bar_rendered}",
f" {label.rstrip()}",
]
Expand Down Expand Up @@ -640,19 +640,19 @@ def _grid2d_to_ascii(grid: Grid2DMetadata) -> str:
nd = grid.node_dimensions
lines = [
"Grid2DMetadata",
f" X-axis: face={fd[0].dim1!r} node={nd[0]!r} padding={fd[0].padding.value}",
f" Y-axis: face={fd[1].dim1!r} node={nd[1]!r} padding={fd[1].padding.value}",
f" X-axis: face={fd[0].face!r} node={nd[0]!r} padding={fd[0].padding.value}",
f" Y-axis: face={fd[1].face!r} node={nd[1]!r} padding={fd[1].padding.value}",
]
if grid.vertical_dimensions:
vd = grid.vertical_dimensions[0]
lines.append(f" Z-axis: face={vd.dim1!r} node={vd.dim2!r} padding={vd.padding.value}")
lines.append(f" Z-axis: face={vd.face!r} node={vd.node!r} padding={vd.padding.value}")
if grid.node_coordinates:
lines.append(f" Coordinates: {grid.node_coordinates[0]}, {grid.node_coordinates[1]}")

format_kwargs = dict(n1=nd[0], n2=nd[1], u=fd[0].dim1, v=fd[1].dim1)
format_kwargs = dict(n1=nd[0], n2=nd[1], u=fd[0].face, v=fd[1].face)

if grid.vertical_dimensions:
format_kwargs["w"] = grid.vertical_dimensions[0].dim2
format_kwargs["w"] = grid.vertical_dimensions[0].node
lines += indent(_TEXT_GRID2D_WITH_Z, " ").format(**format_kwargs).split("\n")
else:
lines += indent(_TEXT_GRID2D_WITHOUT_Z, " ").format(**format_kwargs).split("\n")
Expand All @@ -672,16 +672,16 @@ def _grid3d_to_ascii(grid: Grid3DMetadata) -> str:
nd = grid.node_dimensions
lines = [
"Grid3DMetadata",
f" X-axis: face={vd[0].dim1!r} node={nd[0]!r} padding={vd[0].padding.value}",
f" Y-axis: face={vd[1].dim1!r} node={nd[1]!r} padding={vd[1].padding.value}",
f" Z-axis: face={vd[2].dim1!r} node={nd[2]!r} padding={vd[2].padding.value}",
f" X-axis: face={vd[0].face!r} node={nd[0]!r} padding={vd[0].padding.value}",
f" Y-axis: face={vd[1].face!r} node={nd[1]!r} padding={vd[1].padding.value}",
f" Z-axis: face={vd[2].face!r} node={nd[2]!r} padding={vd[2].padding.value}",
]
if grid.node_coordinates:
lines.append(f" Coordinates: {', '.join(grid.node_coordinates)}")

lines += (
indent(_TEXT_GRID3D, " ")
.format(n1=nd[0], n2=nd[1], n3=nd[2], u=vd[0].dim1, v=vd[1].dim1, w=vd[2].dim1)
.format(n1=nd[0], n2=nd[1], n3=nd[2], u=vd[0].face, v=vd[1].face, w=vd[2].face)
.split("\n")
)

Expand Down Expand Up @@ -738,10 +738,10 @@ def _metadata_rename(grid, names_dict):
if isinstance(value, tuple):
new_value = []
for item in value:
if isinstance(item, DimDimPadding):
new_item = DimDimPadding(
dim1=names_dict[item.dim1],
dim2=names_dict[item.dim2],
if isinstance(item, FaceNodePadding):
new_item = FaceNodePadding(
face=names_dict[item.face],
node=names_dict[item.node],
padding=item.padding,
)
new_value.append(new_item)
Expand Down
Loading
Loading