forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_searching_functions.py
More file actions
122 lines (93 loc) · 4.16 KB
/
_searching_functions.py
File metadata and controls
122 lines (93 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import Literal
import numpy as np
from ._array_object import Array
from ._dtypes import _real_numeric_dtypes, _result_type
from ._dtypes import bool as _bool
from ._flags import requires_api_version, requires_data_dependent_shapes, get_array_api_strict_flags
from ._helpers import _maybe_normalize_py_scalars
def argmax(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.argmax <numpy.argmax>`.

    Returns the indices of the maximum values of ``x`` along ``axis``, as
    computed by NumPy. Only real numeric dtypes are accepted; any other
    dtype raises ``TypeError`` before NumPy is called.

    See its docstring for more information.
    """
    if x.dtype in _real_numeric_dtypes:
        indices = np.argmax(x._array, axis=axis, keepdims=keepdims)
        # np.argmax may hand back a NumPy scalar (e.g. for axis=None), so
        # coerce to an ndarray before wrapping.
        return Array._new(np.asarray(indices), device=x.device)
    raise TypeError("Only real numeric dtypes are allowed in argmax")
def argmin(x: Array, /, *, axis: int | None = None, keepdims: bool = False) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.argmin <numpy.argmin>`.

    Returns the indices of the minimum values of ``x`` along ``axis``, as
    computed by NumPy. Only real numeric dtypes are accepted; any other
    dtype raises ``TypeError`` before NumPy is called.

    See its docstring for more information.
    """
    if x.dtype not in _real_numeric_dtypes:
        raise TypeError("Only real numeric dtypes are allowed in argmin")
    raw = x._array
    # Coerce a possible NumPy scalar result into an ndarray before wrapping.
    positions = np.asarray(np.argmin(raw, axis=axis, keepdims=keepdims))
    return Array._new(positions, device=x.device)
@requires_data_dependent_shapes
def nonzero(x: Array, /) -> tuple[Array, ...]:
    """
    Array API compatible wrapper for :py:func:`np.nonzero <numpy.nonzero>`.

    Returns one index Array per dimension of ``x``, each placed on the same
    device as ``x``. 0-dimensional input is rejected with ``ValueError``.

    See its docstring for more information.
    """
    # nonzero is disallowed on 0-dimensional arrays, so fail before
    # delegating to NumPy.
    if x.ndim == 0:
        raise ValueError("nonzero is not allowed on 0-dimensional arrays")
    index_arrays = np.nonzero(x._array)
    return tuple([Array._new(idx, device=x.device) for idx in index_arrays])
@requires_api_version('2024.12')
def count_nonzero(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.count_nonzero <numpy.count_nonzero>`

    Counts the non-zero elements of ``x``, optionally along ``axis``. The
    function is gated behind ``requires_api_version('2024.12')``.

    See its docstring for more information.
    """
    # np.count_nonzero may return a scalar rather than an ndarray (e.g. for a
    # full reduction), so normalize to an ndarray before wrapping.
    counts = np.asarray(
        np.count_nonzero(x._array, axis=axis, keepdims=keepdims)
    )
    return Array._new(counts, device=x.device)
@requires_api_version('2023.12')
def searchsorted(
    x1: Array,
    x2: Array | int | float,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Array | None = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.searchsorted <numpy.searchsorted>`.

    Finds the indices into ``x1`` at which the values of ``x2`` would be
    inserted. ``x2`` may be a Python scalar only when the active API version
    is 2025.12 or newer; it is then promoted against ``x1``.

    Raises
    ------
    TypeError
        If ``x2`` is not an Array (after any scalar promotion), or if either
        ``x1`` or ``x2`` does not have a real numeric dtype.
    ValueError
        If ``x1``, ``x2`` and ``sorter`` (when given) are not all on the
        same device.

    See its docstring for more information.
    """
    flags = get_array_api_strict_flags()
    if flags["api_version"] >= "2025.12":
        # scalar x2 support is new in 2025.12
        if isinstance(x2, bool | int | float | complex):
            x2 = x1._promote_scalar(x2)
    if not isinstance(x2, Array):
        # Without this check a scalar x2 under an older API version (or any
        # other non-Array) fails below with an opaque AttributeError.
        raise TypeError(f"`x2` must be an Array; got {type(x2)}")
    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
        raise TypeError("Only real numeric dtypes are allowed in searchsorted")
    if x1.device != x2.device:
        raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
    # sorter is combined with x1 as well, so it must share x1's device too.
    if sorter is not None and sorter.device != x1.device:
        raise ValueError(f"Arrays from two different devices ({x1.device} and {sorter.device}) can not be combined.")
    np_sorter = sorter._array if sorter is not None else None
    # TODO: The sort order of nans and signed zeros is implementation
    # dependent. Should we error/warn if they are present?
    # x1 must be 1-D, but NumPy already requires this.
    return Array._new(
        np.searchsorted(x1._array, x2._array, side=side, sorter=np_sorter),
        device=x1.device,
    )
def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.where <numpy.where>`.

    Picks elements from ``x1`` where ``condition`` is true and from ``x2``
    elsewhere. ``x1`` and ``x2`` are first passed through
    ``_maybe_normalize_py_scalars`` so Python scalars are accepted.

    Raises ``TypeError`` when ``condition`` is not an Array or not boolean,
    raises on disallowed ``x1``/``x2`` dtype combinations (via
    ``_result_type``), and raises ``ValueError`` when the inputs are on
    different devices.

    See its docstring for more information.
    """
    if not isinstance(condition, Array):
        raise TypeError(f"`condition` must be an Array; got {type(condition)}")

    x1, x2 = _maybe_normalize_py_scalars(x1, x2, "all", "where")

    # Evaluated only so that disallowed dtype combinations raise; the
    # resulting dtype itself is not needed here.
    _result_type(x1.dtype, x2.dtype)

    if condition.dtype != _bool:
        raise TypeError("`condition` must be have a boolean data type")

    distinct_devices = {arr.device for arr in (condition, x1, x2)}
    if len(distinct_devices) > 1:
        raise ValueError("Inputs to `where` must all use the same device")

    left, right = Array._normalize_two_args(x1, x2)
    selected = np.where(condition._array, left._array, right._array)
    return Array._new(selected, device=left.device)