Skip to content

Commit 0b43a9f

Browse files
committed
add dask filter tests
1 parent 7813538 commit 0b43a9f

1 file changed

Lines changed: 109 additions & 0 deletions

File tree

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pyearthtools.pipeline.operations.dask import filters
16+
from pyearthtools.pipeline.exceptions import PipelineFilterException
17+
18+
import numpy as np
19+
import dask.array as da
20+
import pytest
21+
22+
23+
def test_DropAnyNan():
24+
"""Tests DropAnyNan dask filter."""
25+
26+
original = da.ones((2, 2))
27+
28+
# no nans - should succeed quietly
29+
drop = filters.DropAnyNan()
30+
drop.filter(original)
31+
32+
# one nan - should raise exception
33+
original[0, 0] = np.nan
34+
drop = filters.DropAnyNan()
35+
with pytest.raises(PipelineFilterException):
36+
drop.filter(original)
37+
38+
39+
def test_DropAllNan():
40+
"""Tests DropAllNan dask filter."""
41+
42+
original = da.empty((2, 2))
43+
44+
# no nans - should succeed quietly
45+
drop = filters.DropAllNan()
46+
drop.filter(original)
47+
48+
# one nan - should succeed quietly
49+
original[0, 0] = np.nan
50+
drop.filter(original)
51+
52+
# all nans - should raise exception
53+
original[:, :] = np.nan
54+
with pytest.raises(PipelineFilterException):
55+
drop.filter(original)
56+
57+
58+
def test_DropValue():
59+
"""Tests DropValue dask filter."""
60+
61+
original = da.from_array([[0, 0], [1, 2]])
62+
63+
# drop case (num zeros < threshold)
64+
drop = filters.DropValue(0, 75)
65+
with pytest.raises(PipelineFilterException):
66+
drop.filter(original)
67+
68+
# non-drop case (num zeros >= threshold)
69+
drop = filters.DropValue(0, 50)
70+
drop.filter(original)
71+
72+
# drop case (num nans < threshold)
73+
original = da.from_array([[np.nan, np.nan], [1, 2]])
74+
drop = filters.DropValue("nan", 75)
75+
with pytest.raises(PipelineFilterException):
76+
drop.filter(original)
77+
78+
# non-drop case (num nans >= threshold)
79+
drop = filters.DropValue("nan", 50)
80+
drop.filter(original)
81+
82+
83+
def test_Shape():
84+
"""Tests Shape dask filter."""
85+
86+
originals = (da.empty((2, 2)), da.empty((2, 3)))
87+
88+
# check drop case
89+
drop = filters.Shape((2, 3))
90+
with pytest.raises(PipelineFilterException):
91+
drop.filter(originals[0])
92+
93+
# check non-drop case
94+
drop = filters.Shape((2, 2))
95+
drop.filter(originals[0])
96+
97+
# check tuple inputs drop cases
98+
drop = filters.Shape(((2, 3), (2, 3)))
99+
with pytest.raises(PipelineFilterException):
100+
drop.filter(originals)
101+
102+
# check tuple inputs non-drop cases
103+
drop = filters.Shape(((2, 2), (2, 3)))
104+
drop.filter(originals)
105+
106+
# invalid mismatched shape and input
107+
drop = filters.Shape(((2, 2),))
108+
with pytest.raises(RuntimeError):
109+
drop.filter(originals)

0 commit comments

Comments
 (0)