Skip to content

Commit 1c4c7b8

Browse files
authored
Merge pull request #221 from csiro-coasts/estimate-bounds-1d
Add `emsarray.utils.estimate_bounds_1d()` function
2 parents 6829951 + 1e91722 commit 1c4c7b8

3 files changed

Lines changed: 126 additions & 0 deletions

File tree

docs/releases/development.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@ Next release (in development)
2525
(:pr:`220`).
2626
* Split `tests.test_utils` into multiple `tests.utils.test_component` submodules
2727
(:pr:`220`).
28+
* Add :func:`emsarray.utils.estimate_bounds_1d` function
29+
(:pr:`221`).

src/emsarray/utils.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,36 @@ def find_unused_dimension(
544544
if candidate not in existing_dims)
545545

546546

547+
def find_unused_name(
    dataset: xarray.Dataset,
    candidate: Hashable,
) -> Hashable:
    """
    Pick a variable name that is not yet used in a :class:`xarray.Dataset`.

    Handy when attaching a derived variable, such as a bounds variable,
    to a dataset without clobbering anything that is already there.
    The candidate is returned untouched when it is free.
    Otherwise the numeric suffixes "_0", "_1", "_2", ... are tried in order
    and the first name not present in the dataset is returned.

    Parameters
    ----------
    dataset : xarray.Dataset
        The dataset whose variable names must be avoided.
    candidate : Hashable
        The preferred variable name.

    Returns
    -------
    Hashable
        Either `candidate` itself, or `candidate` with a numeric suffix,
        whichever is the first to not clash with a name in ``dataset.variables``.
    """
    taken = dataset.variables.keys()
    if candidate not in taken:
        return candidate

    # Count upwards from zero until a suffixed name is free.
    suffix = 0
    while True:
        numbered = f'{candidate}_{suffix}'
        if numbered not in taken:
            return numbered
        suffix += 1
575+
576+
547577
def ravel_dimensions(
548578
data_array: xarray.DataArray,
549579
dimensions: list[Hashable],
@@ -935,3 +965,78 @@ def data_array_to_name(dataset: xarray.Dataset, data_array: DataArrayOrName) ->
935965
if data_array not in dataset.variables:
936966
raise ValueError(f"Data array {data_array!r} is not in the dataset")
937967
return data_array
968+
969+
970+
def estimate_bounds_1d(
    dataset: xarray.Dataset,
    coordinate: DataArrayOrName,
    *,
    bounds_name: Hashable | None = None,
    bounds_dimension: Hashable = 'Two',
) -> xarray.Dataset:
    """
    Estimate the bounds of a one dimensional coordinate variable.
    The bounds between two coordinates is the average of the two values,
    while the bounds on each end are the first and last coordinate values.
    This is a crude approach.

    Parameters
    ----------
    dataset : xarray.Dataset
        The dataset containing the coordinate.
    coordinate : xarray.DataArray or str
        The coordinate variable to estimate the bounds of.
    bounds_name : Hashable, optional
        The name of the bounds variable to create.
        Optional, defaults to the name of the coordinate with a "_bounds" suffix.
        When defaulted, a numeric suffix is appended if that name is already taken.
    bounds_dimension : Hashable, default "Two"
        The name of the second dimension of the bounds variable.
        This dimension will have size 2.
        This dimension can be reused by other one-dimensional bounds variables.
        Defaults to "Two".

    Returns
    -------
    xarray.Dataset
        A copy of the original dataset including the new estimated bounds.
        The bounds variable is set as a coordinate of the dataset,
        and the 'bounds' attribute of the coordinate variable is set to its name.

    Raises
    ------
    ValueError
        Raised if the coordinate variable does not have exactly one dimension,
        if the coordinate variable already has a 'bounds' attribute,
        if the dataset already has a dimension named `bounds_dimension`
        with a size other than 2,
        or if `bounds_name` is given and the dataset already has a variable of that name.
    """
    dataset = dataset.copy()
    coordinate_name = data_array_to_name(dataset, coordinate)
    coordinate = dataset[coordinate_name]

    if len(coordinate.dims) != 1:
        raise ValueError(
            f"Coordinate {coordinate_name!r} has {len(coordinate.dims)} dimensions {coordinate.dims}, "
            "expecting one dimension")

    if 'bounds' in coordinate.attrs:
        raise ValueError(f"Coordinate {coordinate_name!r} already has a 'bounds' attribute")

    if bounds_dimension in dataset.dims and dataset.sizes[bounds_dimension] != 2:
        raise ValueError(
            f"Dataset already has a conflicting dimension {bounds_dimension!r} "
            f"of size {dataset.sizes[bounds_dimension]}")

    if bounds_name is None:
        bounds_name = find_unused_name(dataset, f'{coordinate_name}_bounds')
    elif bounds_name in dataset:
        raise ValueError(
            f"Dataset already has a variable named {bounds_name!r}")

    values = coordinate.values
    midpoints = (values[:-1] + values[1:]) / 2
    # The n + 1 cell edges: the first value, every midpoint, then the last value.
    # `numpy.concatenate` is used because the `numpy.concat` alias
    # is only available from numpy 2.0 onwards.
    # Slicing (rather than indexing) the end values keeps this working
    # for an empty coordinate, which yields a bounds array of shape (0, 2).
    edges = numpy.concatenate([values[:1], midpoints, values[-1:]])
    dataset[bounds_name] = xarray.DataArray(
        name=bounds_name,
        # Pair up consecutive edges as (lower, upper) rows.
        data=numpy.c_[edges[:-1], edges[1:]],
        dims=(coordinate.dims[0], bounds_dimension),
    )
    dataset = dataset.set_coords(bounds_name)
    dataset[coordinate_name].attrs['bounds'] = bounds_name
    return dataset

tests/utils/test_xarray.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,25 @@ def test_find_unused_dimension_conflict(dims: list[str], prefix: str, expected:
422422
assert utils.find_unused_dimension(data_array, prefix) == expected
423423

424424

425+
@pytest.mark.parametrize(
    ['candidate', 'expected'],
    [
        ('x', 'x'),
        ('y', 'y_0'),
        ('z', 'z_2'),
    ],
)
def test_find_unused_name(candidate, expected):
    """
    Check the name chosen by ``find_unused_name`` for free and taken candidates.

    'x' is free and comes back unchanged.
    'y' is taken, and 'y0' is present as a near miss that is not a "_N" suffix,
    so 'y_0' is expected.
    'z', 'z_0', and 'z_1' are all taken, so 'z_2' is expected.
    """
    scalar_values = {
        'y': 0,
        'y0': 1,
        'z': 2,
        'z_0': 3,
        'z_1': 4,
    }
    # Every variable is a zero-dimensional scalar; only the names matter here.
    dataset = xarray.Dataset({
        name: ((), value)
        for name, value in scalar_values.items()
    })
    unused_name = utils.find_unused_name(dataset, candidate)
    assert unused_name == expected
442+
443+
425444
def test_ravel_dimensions_exact_dimensions():
426445
data_array = xarray.DataArray(
427446
data=numpy.random.random((3, 5)),

0 commit comments

Comments
 (0)