diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 7b7bfda6..44808ec9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -167,6 +167,15 @@ def searchsorted( return cp.searchsorted(x1, x2, side, sorter) +# CuPy isin does not accept scalars +def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False, **kwds) -> Array: + if isinstance(x1, int): + x1 = cp.asarray(x1) + if isinstance(x2, int): + x2 = cp.asarray(x2) + return cp.isin(x1, x2, invert=invert, **kwds) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -191,7 +200,7 @@ def searchsorted( 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'ceil', 'floor', 'trunc', 'take_along_axis', 'broadcast_arrays', 'meshgrid', - 'searchsorted', + 'searchsorted', 'isin', ]