diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 78610310..ed0c216d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -55,7 +55,7 @@ from re import match as regex_match # type: ignore[no-redef] -from . import parametric, sgrid +from . import parametric, sgrid, ugrid from .criteria import ( _DSG_ROLES, _GEOMETRY_TYPES, @@ -2235,6 +2235,7 @@ def get_associated_variable_names( 4. "coordinates" 5. "grid_mapping" 6. "grid" + 7. "mesh" Parameters ---------- @@ -2247,7 +2248,7 @@ def get_associated_variable_names( ------- names : dict Dictionary with keys "ancillary_variables", "cell_measures", "coordinates", "bounds", - "grid_mapping", "grid". + "grid_mapping", "grid", "mesh". """ keys = [ "ancillary_variables", @@ -2256,6 +2257,7 @@ def get_associated_variable_names( "bounds", "grid_mapping", "grid", + "mesh", "geometry", ] @@ -2300,6 +2302,13 @@ def get_associated_variable_names( if isinstance(self._obj, Dataset): coords["coordinates"].extend(sgrid.get_topology_coords(self._obj, grid)) + if mesh := attrs_or_encoding.get("mesh", None): + coords["mesh"] = [mesh] + if isinstance(self._obj, Dataset): + connectivity, mesh_coords = ugrid.get_mesh_variables(self._obj, mesh) + coords["mesh"].extend(connectivity) + coords["coordinates"].extend(mesh_coords) + if grid_mapping_attr := attrs_or_encoding.get("grid_mapping", None): # Parse grid mapping variables and their coordinates grid_mapping_dict = _parse_grid_mapping_attribute(grid_mapping_attr) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index ef7dee10..225bab8c 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -2434,6 +2434,58 @@ def test_grid_topology() -> None: assert "T" in ds.cf.axes +def test_ugrid_includes_topology_variables() -> None: + """The mesh_topology variable, its connectivity, and its coordinate + variables should be pulled in by ds.cf[[var]] for a UGRID data variable.""" + ds = xr.Dataset( + data_vars={ + "h": ( + "face", + np.zeros(2), + {"mesh": "mesh", "location": "face"}, + ), + }, + coords={ + "mesh": ( + tuple(), + 1, + { + "cf_role": "mesh_topology", + "topology_dimension": 2, + "node_coordinates": "node_lon node_lat", + "face_node_connectivity": "face_nodes", + "edge_node_connectivity": "edge_nodes", + "face_coordinates": "face_lon face_lat", + }, + ), + "node_lon": ("node", np.zeros(4)), + "node_lat": ("node", np.zeros(4)), + "face_lon": ("face", np.zeros(2)), + "face_lat": ("face", np.zeros(2)), + "face_nodes": (("face", "nvertex"), np.zeros((2, 3))), + "edge_nodes": (("edge", "two"), np.zeros((3, 2))), + }, + ) + + assoc = ds.cf.get_associated_variable_names("h") + assert {"mesh", "face_nodes", "edge_nodes"}.issubset(set(assoc["mesh"])) + assert {"node_lon", "node_lat", "face_lon", "face_lat"}.issubset( + set(assoc["coordinates"]) + ) + + expected = { + "mesh", + "face_nodes", + "edge_nodes", + "node_lon", + "node_lat", + "face_lon", + "face_lat", + } + subset = ds.cf[["h"]] + assert expected.issubset(set(subset.variables)) + + @requires_scipy def test_curvefit() -> None: from cf_xarray.datasets import airds diff --git a/cf_xarray/ugrid.py b/cf_xarray/ugrid.py new file mode 100644 index 00000000..d0e55f7b --- /dev/null +++ b/cf_xarray/ugrid.py @@ -0,0 +1,41 @@ +# Connectivity attributes on a UGRID ``mesh_topology`` variable. +UGRID_CONNECTIVITY_ATTRS = [ + "face_node_connectivity", + "edge_node_connectivity", + "face_edge_connectivity", + "face_face_connectivity", + "edge_face_connectivity", + "boundary_node_connectivity", +] + +# Coordinate-variable attributes on a UGRID ``mesh_topology`` variable. +UGRID_COORD_ATTRS = [ + "node_coordinates", + "edge_coordinates", + "face_coordinates", +] + + +def get_mesh_variables(ds, mesh_var_name): + """Return variables referenced by a UGRID ``mesh_topology`` variable. + + Reads the connectivity attributes (``face_node_connectivity``, + ``edge_node_connectivity``, ...) and the coordinate attributes + (``node_coordinates``, ``edge_coordinates``, ``face_coordinates``) from the + mesh topology variable's attrs. + + Returns a ``(connectivity, coordinates)`` tuple of variable-name lists, + each filtered to names actually present in ``ds``. + """ + if mesh_var_name not in ds.variables: + return [], [] + mesh_attrs = ds[mesh_var_name].attrs + connectivity: list = [] + for attr_name in UGRID_CONNECTIVITY_ATTRS: + if conn_str := mesh_attrs.get(attr_name): + connectivity.extend(n for n in conn_str.split() if n in ds.variables) + coordinates: list = [] + for attr_name in UGRID_COORD_ATTRS: + if coord_str := mesh_attrs.get(attr_name): + coordinates.extend(n for n in coord_str.split() if n in ds.variables) + return connectivity, coordinates diff --git a/doc/sgrid_ugrid.md b/doc/sgrid_ugrid.md index ff0417e2..496da7c5 100644 --- a/doc/sgrid_ugrid.md +++ b/doc/sgrid_ugrid.md @@ -84,6 +84,22 @@ appear as coords on `u`. `cf_xarray` supports identifying the `mesh_topology` variable using the `cf_role` attribute. +### Connectivity and coordinate variables + +When a data variable references a `mesh_topology` variable through its `mesh` +attribute, `cf_xarray` follows that attribute when subsetting with `.cf`. The +mesh topology variable, its connectivity variables (`face_node_connectivity`, +`edge_node_connectivity`, `face_edge_connectivity`, `face_face_connectivity`, +`edge_face_connectivity`, `boundary_node_connectivity`), and its coordinate +variables (`node_coordinates`, `edge_coordinates`, `face_coordinates`) are +pulled in alongside the data variable: + +```python +ds.cf[["h"]] # includes `mesh`, face_nodes, node_lon/node_lat, ... +``` + +Only names actually present in the dataset are propagated. + ## More? Further support for interpreting the SGRID and UGRID conventions can be added. Contributions are welcome!