Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from re import match as regex_match # type: ignore[no-redef]


from . import parametric, sgrid
from . import parametric, sgrid, ugrid
from .criteria import (
_DSG_ROLES,
_GEOMETRY_TYPES,
Expand Down Expand Up @@ -2235,6 +2235,7 @@ def get_associated_variable_names(
4. "coordinates"
5. "grid_mapping"
6. "grid"
7. "mesh"

Parameters
----------
Expand All @@ -2247,7 +2248,7 @@ def get_associated_variable_names(
-------
names : dict
Dictionary with keys "ancillary_variables", "cell_measures", "coordinates", "bounds",
"grid_mapping", "grid".
"grid_mapping", "grid", "mesh".
"""
keys = [
"ancillary_variables",
Expand All @@ -2256,6 +2257,7 @@ def get_associated_variable_names(
"bounds",
"grid_mapping",
"grid",
"mesh",
"geometry",
]

Expand Down Expand Up @@ -2300,6 +2302,13 @@ def get_associated_variable_names(
if isinstance(self._obj, Dataset):
coords["coordinates"].extend(sgrid.get_topology_coords(self._obj, grid))

if mesh := attrs_or_encoding.get("mesh", None):
coords["mesh"] = [mesh]
if isinstance(self._obj, Dataset):
connectivity, mesh_coords = ugrid.get_mesh_variables(self._obj, mesh)
coords["mesh"].extend(connectivity)
coords["coordinates"].extend(mesh_coords)

if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None):
# Parse grid mapping variables and their coordinates
grid_mapping_dict = _parse_grid_mapping_attribute(grid_mapping_attr)
Expand Down
52 changes: 52 additions & 0 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,58 @@ def test_grid_topology() -> None:
assert "T" in ds.cf.axes


def test_ugrid_includes_topology_variables() -> None:
"""The mesh_topology variable, its connectivity, and its coordinate
variables should be pulled in by ds.cf[[var]] for a UGRID data variable."""
ds = xr.Dataset(
data_vars={
"h": (
"face",
np.zeros(2),
{"mesh": "mesh", "location": "face"},
),
},
coords={
"mesh": (
tuple(),
1,
{
"cf_role": "mesh_topology",
"topology_dimension": 2,
"node_coordinates": "node_lon node_lat",
"face_node_connectivity": "face_nodes",
"edge_node_connectivity": "edge_nodes",
"face_coordinates": "face_lon face_lat",
},
),
"node_lon": ("node", np.zeros(4)),
"node_lat": ("node", np.zeros(4)),
"face_lon": ("face", np.zeros(2)),
"face_lat": ("face", np.zeros(2)),
"face_nodes": (("face", "nvertex"), np.zeros((2, 3))),
"edge_nodes": (("edge", "two"), np.zeros((3, 2))),
},
)

assoc = ds.cf.get_associated_variable_names("h")
assert {"mesh", "face_nodes", "edge_nodes"}.issubset(set(assoc["mesh"]))
assert {"node_lon", "node_lat", "face_lon", "face_lat"}.issubset(
set(assoc["coordinates"])
)

expected = {
"mesh",
"face_nodes",
"edge_nodes",
"node_lon",
"node_lat",
"face_lon",
"face_lat",
}
subset = ds.cf[["h"]]
assert expected.issubset(set(subset.variables))


@requires_scipy
def test_curvefit() -> None:
from cf_xarray.datasets import airds
Expand Down
41 changes: 41 additions & 0 deletions cf_xarray/ugrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Connectivity attributes on a UGRID ``mesh_topology`` variable.
UGRID_CONNECTIVITY_ATTRS = [
"face_node_connectivity",
"edge_node_connectivity",
"face_edge_connectivity",
"face_face_connectivity",
"edge_face_connectivity",
"boundary_node_connectivity",
]

# Coordinate-variable attributes on a UGRID ``mesh_topology`` variable.
UGRID_COORD_ATTRS = [
"node_coordinates",
"edge_coordinates",
"face_coordinates",
]


def get_mesh_variables(ds, mesh_var_name):
"""Return variables referenced by a UGRID ``mesh_topology`` variable.

Reads the connectivity attributes (``face_node_connectivity``,
``edge_node_connectivity``, ...) and the coordinate attributes
(``node_coordinates``, ``edge_coordinates``, ``face_coordinates``) from the
mesh topology variable's attrs.

Returns a ``(connectivity, coordinates)`` tuple of variable-name lists,
each filtered to names actually present in ``ds``.
"""
if mesh_var_name not in ds.variables:
return [], []
mesh_attrs = ds[mesh_var_name].attrs
connectivity: list = []
for attr_name in UGRID_CONNECTIVITY_ATTRS:
if conn_str := mesh_attrs.get(attr_name):
connectivity.extend(n for n in conn_str.split() if n in ds.variables)
coordinates: list = []
for attr_name in UGRID_COORD_ATTRS:
if coord_str := mesh_attrs.get(attr_name):
coordinates.extend(n for n in coord_str.split() if n in ds.variables)
return connectivity, coordinates
16 changes: 16 additions & 0 deletions doc/sgrid_ugrid.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ appear as coords on `u`.

`cf_xarray` supports identifying the `mesh_topology` variable using the `cf_role` attribute.

### Connectivity and coordinate variables

When a data variable references a `mesh_topology` variable through its `mesh`
attribute, `cf_xarray` follows that attribute when subsetting with `.cf`. The
mesh topology variable, its connectivity variables (`face_node_connectivity`,
`edge_node_connectivity`, `face_edge_connectivity`, `face_face_connectivity`,
`edge_face_connectivity`, `boundary_node_connectivity`), and its coordinate
variables (`node_coordinates`, `edge_coordinates`, `face_coordinates`) are
pulled in alongside the data variable:

```python
ds.cf[["h"]] # includes `mesh`, face_nodes, node_lon/node_lat, ...
```

Only names actually present in the dataset are propagated.

## More?

Further support for interpreting the SGRID and UGRID conventions can be added. Contributions are welcome!
Loading