Skip to content

Commit 1c0cc10

Browse files
author
Michael Pegios
committed
add tests for converter
1 parent a89e71e commit 1c0cc10

1 file changed

Lines changed: 150 additions & 0 deletions

File tree

packages/utils/tests/data/test_converter.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515

16+
import numpy as np
17+
import pandas as pd
18+
import pytest
1619
import xarray as xr
1720

1821
from pyearthtools.utils.data import converter
@@ -21,6 +24,24 @@
2124
SIMPLE_DATA_SET = xr.Dataset({"Entry": SIMPLE_DATA_ARRAY})
2225

2326

27+
def test_NumpyConverterTuple():
28+
"""
29+
Checks NumpyConverter using Tuples
30+
"""
31+
dataset_tuple = (SIMPLE_DATA_SET.copy(), SIMPLE_DATA_SET.copy())
32+
nc = converter.NumpyConverter()
33+
np_array = nc.convert_from_xarray(dataset_tuple)
34+
xr_ds = nc.convert_to_xarray(np_array)
35+
36+
assert isinstance(xr_ds, tuple)
37+
assert isinstance(xr_ds[0], xr.Dataset)
38+
assert isinstance(xr_ds[-1], xr.Dataset)
39+
40+
xr_ds1 = xr_ds[0]
41+
assert "Entry" in xr_ds1.data_vars
42+
xr.testing.assert_identical(xr_ds1["Entry"], SIMPLE_DATA_ARRAY)
43+
44+
2445
def test_NumpyConverter():
2546
"""
2647
Checks conversion from xarray to numpy and back
@@ -36,6 +57,49 @@ def test_NumpyConverter():
3657
assert "Entry" in xr_ds
3758
xr.testing.assert_identical(xr_ds["Entry"], SIMPLE_DATA_ARRAY)
3859

60+
# Data that hasn't been converted yet throws runtime error
61+
with pytest.raises(RuntimeError):
62+
nc = converter.NumpyConverter()
63+
_xr_ds = nc.convert_to_xarray(np.array([0, 1, 2]))
64+
65+
# Data with an empty dataset throws runtime error
66+
with pytest.raises(RuntimeError):
67+
nc = converter.NumpyConverter()
68+
_np_array = nc.convert_from_xarray(SIMPLE_DATA_SET)
69+
_xr_ds = nc.convert_to_xarray(np.empty(0))
70+
71+
# String data type throws error
72+
with pytest.raises(TypeError):
73+
nc = converter.NumpyConverter()
74+
_xr_ds = nc.convert_from_xarray(["wrong type"])
75+
76+
nc = converter.NumpyConverter()
77+
78+
# Create a DataArray with a 'time' dimension but no 'x' coordinate variable
79+
data_array = xr.DataArray(np.random.rand(5, 3), dims=["time", "x"], coords={"time": np.random.rand(5)})
80+
81+
# Create a Dataset from this DataArray
82+
ds = xr.Dataset({"my_variable": data_array})
83+
np_array = nc.convert_from_xarray(ds)
84+
85+
# Value error converting back to numpy array due to missing coord
86+
with pytest.raises(ValueError):
87+
_xr_ds = nc.convert_to_xarray(np_array)
88+
89+
# Check currently conversion handling
90+
ds = xr.Dataset(
91+
data_vars={"data_var": ("x", np.array([1, 2, 3]))},
92+
coords={
93+
"x": np.arange(3),
94+
"empty_coord_dim": np.arange(2), # Coordinate defining a new dimension but no data variable
95+
},
96+
)
97+
nc = converter.NumpyConverter()
98+
99+
# Throw runtime error that cannot record coordinate
100+
with pytest.raises(RuntimeError):
101+
_xr_ds = nc.convert_from_xarray(ds)
102+
39103

40104
def test_DaskConverter():
41105
"""
@@ -45,8 +109,94 @@ def test_DaskConverter():
45109
dc = converter.DaskConverter()
46110

47111
da_array = dc.convert_from_xarray(SIMPLE_DATA_SET)
112+
48113
da_array = da_array.compute()
49114
xr_ds = dc.convert_to_xarray(da_array)
50115
assert isinstance(xr_ds, xr.Dataset)
51116
assert "Entry" in xr_ds
52117
xr.testing.assert_identical(xr_ds["Entry"], SIMPLE_DATA_ARRAY)
118+
119+
dataset_tuple = (SIMPLE_DATA_SET.copy(), SIMPLE_DATA_SET.copy())
120+
da_array = dc.convert_from_xarray(dataset_tuple)
121+
122+
xr_ds = dc.convert_to_xarray(da_array)
123+
assert isinstance(xr_ds, tuple)
124+
assert "Entry" in xr_ds
125+
xr.testing.assert_identical(xr_ds[0]["Entry"], SIMPLE_DATA_ARRAY)
126+
127+
128+
def test_save_and_load_records(tmpdir):
129+
"""
130+
Test save and load records functionality.
131+
"""
132+
tmp_path = tmpdir.mkdir("sub").join("nc.records")
133+
134+
time_index = pd.date_range("2025-01-01", periods=3, freq="D")
135+
136+
data = np.random.rand(3, 3, 3)
137+
138+
coords = {"time": time_index, "x": [np.nan, np.nan, np.nan], "y": np.random.randn(3)}
139+
140+
test_data_array = xr.DataArray(
141+
data,
142+
coords=coords,
143+
dims=["time", "x", "y"],
144+
)
145+
146+
test_dataset = xr.Dataset({"entry": test_data_array})
147+
148+
nc = converter.NumpyConverter()
149+
nc.convert_from_xarray(test_data_array)
150+
nc.save_records(tmp_path)
151+
saved_records = nc.records.copy()
152+
153+
# Add extra record to the numpy converter
154+
nc.convert_from_xarray(test_data_array)
155+
156+
assert len(saved_records) == 1
157+
assert len(nc.records) == 2
158+
assert nc.records != saved_records # Assert saved records are not the same as converted records
159+
160+
assert nc.load_records(tmp_path) # Assert it loads correctly
161+
assert len(nc.records) == 1
162+
163+
loaded_records = nc.records.copy()
164+
loaded_vars = loaded_records[0]["coords"]["x"]
165+
saved_vars = saved_records[0]["coords"]["x"]
166+
167+
# Check if variables are equal
168+
np.testing.assert_equal(saved_vars, loaded_vars)
169+
170+
# Broken path returns False
171+
assert not nc.load_records("/broken/path")
172+
173+
# Trigger datetime instance in save_records.parse
174+
nc = converter.NumpyConverter()
175+
nc.convert_from_xarray(test_data_array)
176+
nc.save_records(tmp_path)
177+
178+
# Trigger np.isnan(v).all() in save_records.parse
179+
nc = converter.NumpyConverter()
180+
da_array = nc.convert_from_xarray(test_dataset)
181+
nc.convert_to_xarray(da_array)
182+
nc.save_records(tmp_path)
183+
184+
185+
def test_non_shared_coordinates_throws_value_error():
186+
"""
187+
Test if coordinates are not shared will throw an appropriate value error
188+
"""
189+
x_coords = np.array([0, 1, 2])
190+
y_coords = np.array([0, 1, 2])
191+
192+
data_1 = np.random.rand(len(x_coords), len(y_coords))
193+
data_2 = np.random.rand(len(x_coords))
194+
missing_coord_ds = xr.Dataset(
195+
{"data_1": (("x", "y"), data_1), "data_2": (("x"), data_2)},
196+
coords={"x": x_coords, "y": y_coords},
197+
)
198+
nc = converter.NumpyConverter()
199+
200+
# Cannot stack variables so will raise a value error
201+
with pytest.raises(ValueError):
202+
nc.convert_from_xarray(missing_coord_ds)

0 commit comments

Comments
 (0)