Skip to content

Commit bc7139f

Browse files
authored
Add an implementation of np.vdot to PytatoPyOpenCLArrayContext (#299)
* Add an implementation of vdot to the PytatoPyOpenCLArrayContext np namespace. * Remove the tests that are just skipped for scalars. * Respond to comments. * Ruff version needed to be updated locally.
1 parent a88f08d commit bc7139f

3 files changed

Lines changed: 9 additions & 7 deletions

File tree

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,7 @@ def amin(self, a, axis=None):
239239
def absolute(self, a):
240240
return self.abs(a)
241241

242+
def vdot(self, a: Array, b: Array):
243+
244+
return rec_multimap_array_container(pt.vdot, a, b)
242245
# }}}

test/test_arraycontext.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,10 @@ def evaluate(np_, *args_):
271271

272272
assert_close_to_numpy_in_containers(actx, evaluate, args)
273273

274-
if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
275-
pytest.skip(f"'{sym_name}' not supported on scalars")
276-
277-
args = [randn(0, dtype)[()] for i in range(n_args)]
278-
assert_close_to_numpy(actx, evaluate, args)
274+
if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
275+
# Scalar arguments are supported.
276+
args = [randn(0, dtype)[()] for i in range(n_args)]
277+
assert_close_to_numpy(actx, evaluate, args)
279278

280279

281280
@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [

test/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
THE SOFTWARE.
2828
"""
2929
import logging
30-
from typing import Optional, cast
30+
from typing import cast
3131

3232
import numpy as np
3333
import pytest
@@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None:
6363
class ArrayContainerWithOptional:
6464
x: np.ndarray
6565
# Deliberately left as Optional to test compatibility.
66-
y: Optional[np.ndarray] # noqa: UP007
66+
y: np.ndarray | None
6767

6868
with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
6969
# NOTE: cannot have wrapped annotations (here by `Optional`)

0 commit comments

Comments
 (0)