We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e600449 commit 1acba0cCopy full SHA for 1acba0c
3 files changed
array_api_compat/cupy/_aliases.py
@@ -124,6 +124,11 @@ def count_nonzero(
124
return result
125
126
127
+# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128
+def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
129
+ return cp.take_along_axis(x, indices, axis=axis)
130
+
131
132
# These functions are completely new here. If the library already has them
133
# (i.e., numpy 2.0), use the library version instead of our wrapper.
134
if hasattr(cp, 'vecdot'):
@@ -145,6 +150,7 @@ def count_nonzero(
145
150
'acos', 'acosh', 'asin', 'asinh', 'atan',
146
151
'atan2', 'atanh', 'bitwise_left_shift',
147
152
'bitwise_invert', 'bitwise_right_shift',
148
- 'bool', 'concat', 'count_nonzero', 'pow', 'sign']
153
+ 'bool', 'concat', 'count_nonzero', 'pow', 'sign',
154
+ 'take_along_axis']
149
155
156
_all_ignore = ['cp', 'get_xp']
array_api_compat/numpy/_aliases.py
@@ -140,6 +140,11 @@ def count_nonzero(
140
141
142
143
+# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
144
+ return np.take_along_axis(x, indices, axis=axis)
if hasattr(np, "vecdot"):
@@ -175,6 +180,7 @@ def count_nonzero(
175
180
"concat",
176
181
"count_nonzero",
177
182
"pow",
183
+ "take_along_axis"
178
184
]
179
185
__all__ += _aliases.__all__
186
_all_ignore = ["np", "get_xp"]
dask-xfails.txt
@@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace
24
# Shape mismatch
25
array_api_tests/test_indexing_functions.py::test_take
26
27
+# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
28
+array_api_tests/test_indexing_functions.py::test_take_along_axis
29
30
# Array methods and attributes not already on da.Array cannot be wrapped
31
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
32
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
0 commit comments