Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import xarray as xr
import xgcm

import parcels._sgrid as sgrid
from parcels._core.field import Field, VectorField
from parcels._core.utils import sgrid
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import get_datetime_type_calendar
from parcels._core.utils.time import is_compatible as datetime_is_compatible
Expand Down Expand Up @@ -280,7 +280,7 @@ def from_sgrid_conventions(
ds = ds.rename({time_dim: "time"})

# Parse SGRID metadata and get xgcm kwargs
_, xgcm_kwargs = sgrid.parse_sgrid(ds)
_, xgcm_kwargs = sgrid.xgcm_parse_sgrid(ds)

# Add time axis to xgcm_kwargs if present
if "time" in ds.dims:
Expand Down Expand Up @@ -475,11 +475,7 @@ def _select_uxinterpolator(da: ux.UxDataArray):
# TODO: Refactor later into something like `parcels._metadata.discover(dataset)` helper that can be used to discover important metadata like this. I think this whole metadata handling should be refactored into its own module.
def _get_mesh_type_from_sgrid_dataset(ds_sgrid: xr.Dataset) -> Mesh:
"""Small helper to inspect SGRID metadata and dataset metadata to determine mesh type."""
grid_da = sgrid.get_grid_topology(ds_sgrid)
if grid_da is None:
raise ValueError("Dataset does not contain SGRID grid topology metadata (cf_role='grid_topology').")

sgrid_metadata = sgrid.parse_grid_attrs(grid_da.attrs)
sgrid_metadata = ds_sgrid.sgrid.metadata

fpoint_x, fpoint_y = sgrid_metadata.node_coordinates

Expand Down
18 changes: 7 additions & 11 deletions src/parcels/_datasets/structured/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,7 @@
import numpy as np
import xarray as xr

from parcels._core.utils.sgrid import (
FaceNodePadding,
Padding,
SGrid2DMetadata,
_attach_sgrid_metadata,
)
import parcels._sgrid as sgrid
from parcels._core.utils.time import timedelta_to_float


Expand All @@ -29,16 +24,17 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
"lon": (["XG"], np.linspace(-max_lon, max_lon, dims[3]), {"axis": "X", "c_grid_axis_shift": -0.5}),
},
).pipe(
_attach_sgrid_metadata,
SGrid2DMetadata(
sgrid._attach_sgrid_metadata,
sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
node_coordinates=("lon", "lat"),
face_dimensions=(
FaceNodePadding("XC", "XG", Padding.LOW),
FaceNodePadding("YC", "YG", Padding.LOW),
sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW),
sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW),
),
vertical_dimensions=(FaceNodePadding("ZC", "depth", Padding.BOTH),),
vertical_dimensions=(sgrid.FaceNodePadding("ZC", "depth", sgrid.Padding.BOTH),),
),
)

Expand Down
84 changes: 61 additions & 23 deletions src/parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import numpy as np
import xarray as xr

from parcels._core.utils.sgrid import (
FaceNodePadding,
Padding,
SGrid2DMetadata,
_attach_sgrid_metadata,
)
from parcels._core.utils.sgrid import (
rename as sgrid_rename,
)
import parcels._sgrid as sgrid

from . import T, X, Y, Z

Expand Down Expand Up @@ -248,43 +240,89 @@ def _unrolled_cone_curvilinear_grid():
"ds_2d_padded_high": (
datasets["ds_2d_left"]
.pipe(
_attach_sgrid_metadata,
SGrid2DMetadata(
sgrid._attach_sgrid_metadata,
sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
face_dimensions=(
FaceNodePadding("XC", "XG", Padding.HIGH),
FaceNodePadding("YC", "YG", Padding.HIGH),
sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.HIGH),
sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.HIGH),
),
node_coordinates=("lon", "lat"),
vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.HIGH),),
vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.HIGH),),
),
)
.pipe(
sgrid_rename,
.sgrid.rename(
_COMODO_TO_2D_SGRID,
)
),
"ds_2d_padded_low": (
datasets["ds_2d_right"]
.pipe(
_attach_sgrid_metadata,
SGrid2DMetadata(
sgrid._attach_sgrid_metadata,
sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
face_dimensions=(
FaceNodePadding("XC", "XG", Padding.LOW),
FaceNodePadding("YC", "YG", Padding.LOW),
sgrid.FaceNodePadding("XC", "XG", sgrid.Padding.LOW),
sgrid.FaceNodePadding("YC", "YG", sgrid.Padding.LOW),
),
node_coordinates=("lon", "lat"),
vertical_dimensions=(FaceNodePadding("ZC", "ZG", Padding.LOW),),
vertical_dimensions=(sgrid.FaceNodePadding("ZC", "ZG", sgrid.Padding.LOW),),
),
)
.pipe(
sgrid_rename,
.sgrid.rename(
_COMODO_TO_2D_SGRID,
)
),
"ds_2d_padded_none": xr.Dataset(
{
"data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)),
"data_c": (["face_dimension1", "face_dimension2"], np.random.rand(9, 9)),
"grid": (
[],
np.array(0),
sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.NONE),
sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.NONE),
),
node_coordinates=("lon", "lat"),
).to_attrs(),
),
},
coords={
"lon": (["node_dimension1"], np.linspace(0, 1, 10)),
"lat": (["node_dimension2"], np.linspace(0, 1, 10)),
},
),
"ds_2d_padded_both": xr.Dataset(
{
"data_g": (["node_dimension1", "node_dimension2"], np.random.rand(10, 10)),
"data_c": (["face_dimension1", "face_dimension2"], np.random.rand(11, 11)),
"grid": (
[],
np.array(0),
sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("node_dimension1", "node_dimension2"),
face_dimensions=(
sgrid.FaceNodePadding("face_dimension1", "node_dimension1", sgrid.Padding.BOTH),
sgrid.FaceNodePadding("face_dimension2", "node_dimension2", sgrid.Padding.BOTH),
),
node_coordinates=("lon", "lat"),
).to_attrs(),
),
},
coords={
"lon": (["node_dimension1"], np.linspace(0, 1, 10)),
"lat": (["node_dimension2"], np.linspace(0, 1, 10)),
},
),
}
20 changes: 5 additions & 15 deletions src/parcels/_datasets/structured/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,8 @@
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays as np_arrays

import parcels._sgrid as sgrid
import parcels._strategies as pst
from parcels._core.utils import sgrid
from parcels._core.utils.sgrid import _attach_sgrid_metadata


def _face_size(node_size: int, padding: sgrid.Padding) -> int:
if padding == sgrid.Padding.NONE:
return node_size - 1
elif padding in (sgrid.Padding.LOW, sgrid.Padding.HIGH):
return node_size
else: # Padding.BOTH
return node_size + 1


@st.composite
Expand All @@ -33,14 +23,14 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset
node_dim1, node_dim2 = grid.node_dimensions
face_dim1 = grid.face_dimensions[0].face
face_dim2 = grid.face_dimensions[1].face
N_face = _face_size(N, grid.face_dimensions[0].padding)
M_face = _face_size(M, grid.face_dimensions[1].padding)
N_face = sgrid.get_n_faces(N, grid.face_dimensions[0].padding)
M_face = sgrid.get_n_faces(M, grid.face_dimensions[1].padding)

if has_vertical := grid.vertical_dimensions is not None:
P = draw(st.integers(min_value=5, max_value=20))
vert_node_dim = grid.vertical_dimensions[0].node
vert_face_dim = grid.vertical_dimensions[0].face
P_face = _face_size(P, grid.vertical_dimensions[0].padding)
P_face = sgrid.get_n_faces(P, grid.vertical_dimensions[0].padding)

has_curvilinear_grid = draw(st.booleans())
coord_name1, coord_name2 = grid.node_coordinates
Expand Down Expand Up @@ -90,4 +80,4 @@ def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset
}

ds = xr.Dataset(data_vars=data_vars, coords=coords)
return _attach_sgrid_metadata(ds, grid)
return sgrid._attach_sgrid_metadata(ds, grid)
13 changes: 12 additions & 1 deletion src/parcels/_python.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Generic Python helpers
import inspect
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import TypeVar

K = TypeVar("K")
V = TypeVar("V")
Comment on lines +6 to +7
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's happening here? What is this for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are generic type variables https://mypy.readthedocs.io/en/stable/generics.html

They allow you to describe (via typing) a relationship between the input and output types of a function, without providing concrete types.

e.g.,

T = TypeVar("T")
def make_list(obj: T) -> list[T]:
    # make a list of length 1 from provided object
    # works on any object type. Returns a list with objects of same type as input
    return [obj]

In our case here we define type variables K and V (for "key" and "value") for

def invert_non_unique_mapping(d: Mapping[K, V]) -> Mapping[V, list[K]]:
    inv_map: dict[V, list[K]] = {}
    for k, v in d.items():
        inv_map[v] = inv_map.get(v, []) + [k]
    return inv_map



def isinstance_noimport(obj, class_or_tuple):
Expand Down Expand Up @@ -39,3 +43,10 @@ def assert_same_function_signature(f: Callable, *, ref: Callable, context: str)
raise ValueError(
f"Parameter '{param2.name}' has incorrect name. Expected '{param1.name}', got '{param2.name}'"
)


def invert_non_unique_mapping(d: Mapping[K, V]) -> Mapping[V, list[K]]:
inv_map: dict[V, list[K]] = {}
for k, v in d.items():
inv_map[v] = inv_map.get(v, []) + [k]
return inv_map
27 changes: 27 additions & 0 deletions src/parcels/_sgrid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .accessor import SgridAccessor
from .core import (
FaceNodePadding,
Padding,
SGrid2DMetadata,
SGrid3DMetadata,
_attach_sgrid_metadata,
dump_mappings,
get_n_faces,
get_n_nodes,
load_mappings,
xgcm_parse_sgrid,
)

__all__ = [
"FaceNodePadding",
"Padding",
"SGrid2DMetadata",
"SGrid3DMetadata",
"SgridAccessor",
"_attach_sgrid_metadata",
"dump_mappings",
"get_n_faces",
"get_n_nodes",
"load_mappings",
"xgcm_parse_sgrid",
]
Loading
Loading