Skip to content

Commit 3ba7c42

Browse files
committed
Closes #5175: overloads for in1d
1 parent dd045c7 commit 3ba7c42

7 files changed

Lines changed: 249 additions & 29 deletions

File tree

arkouda/numpy/pdarraysetops.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Sequence, TypeVar, Union, cast
3+
from typing import TYPE_CHECKING, Literal, Sequence, Tuple, TypeVar, Union, cast, overload
44

55
import numpy as np
66

@@ -19,7 +19,7 @@
1919

2020

2121
if TYPE_CHECKING:
22-
from arkouda.numpy.pdarraycreation import array, zeros, zeros_like
22+
from arkouda.numpy.pdarraycreation import array, zeros_like
2323
from arkouda.numpy.strings import Strings
2424
from arkouda.pandas.categorical import Categorical
2525
else:
@@ -94,6 +94,7 @@ def _in1d_single(
9494
array([False True])
9595
"""
9696
from arkouda.client import generic_msg
97+
from arkouda.numpy.pdarraycreation import zeros
9798
from arkouda.numpy.strings import Strings
9899
from arkouda.pandas.categorical import Categorical as Categorical_
99100

@@ -138,22 +139,46 @@ def _in1d_single(
138139
raise TypeError("Both pda1 and pda2 must be pdarray, Strings, or Categorical")
139140

140141

142+
# Typing overloads for ``in1d``: the return type depends on ``symmetric``.
#
# * ``symmetric`` omitted or ``False`` -> a single boolean mask (``pdarray``).
# * ``symmetric=True``                 -> a pair of masks, one for ``A`` and one for ``B``.
# * ``symmetric`` only known as ``bool`` -> either of the above (fallback).
#
# ``symmetric`` is deliberately *required* in the ``Literal[True]`` overloads.
# Giving it a ``...`` default there would claim that omitting the argument
# means ``True``, whereas the implementation defaults to ``False``.


# Default call shape: no ``symmetric`` argument, or an explicit ``False``.
@overload
def in1d(
    A: groupable,
    B: groupable,
    assume_unique: bool = ...,
    symmetric: Literal[False] = ...,
    invert: bool = ...,
) -> pdarray: ...


# ``symmetric=True`` passed by keyword (``assume_unique`` may be omitted).
@overload
def in1d(
    A: groupable,
    B: groupable,
    assume_unique: bool = ...,
    *,
    symmetric: Literal[True],
    invert: bool = ...,
) -> Tuple[pdarray, pdarray]: ...


# ``symmetric=True`` passed positionally (``assume_unique`` must then be given too).
@overload
def in1d(
    A: groupable,
    B: groupable,
    assume_unique: bool,
    symmetric: Literal[True],
    invert: bool = ...,
) -> Tuple[pdarray, pdarray]: ...


# Fallback: ``symmetric`` is a run-time ``bool`` whose value the type checker
# cannot see, so the result may be either shape.
@overload
def in1d(
    A: groupable,
    B: groupable,
    assume_unique: bool = ...,
    symmetric: bool = ...,
    invert: bool = ...,
) -> Union[pdarray, Tuple[pdarray, pdarray]]: ...
160+
161+
141162
@typechecked
142163
def in1d(
143164
A: groupable,
144165
B: groupable,
145166
assume_unique: bool = False,
146167
symmetric: bool = False,
147168
invert: bool = False,
148-
) -> groupable:
169+
) -> Union[pdarray, Tuple[pdarray, pdarray]]:
149170
"""
150171
Test whether each element of a 1-D array is also present in a second array.
151172
152-
Returns a boolean array the same length as `A` that is True
153-
where an element of `A` is in `B` and False otherwise.
173+
If ``symmetric=False`` (default), returns a boolean pdarray of the same
174+
shape as ``A`` indicating whether each element of ``A`` is in ``B``.
175+
176+
If ``symmetric=True``, returns a tuple ``(maskA, maskB)`` where:
177+
178+
* ``maskA[i]`` is True iff ``A[i]`` is in ``B``
179+
* ``maskB[j]`` is True iff ``B[j]`` is in ``A``
154180
155-
Supports multi-level, i.e. test if rows of a are in the set of rows of b.
156-
But note that multi-dimensional pdarrays are not supported.
181+
If ``invert=True``, the returned mask(s) are logically inverted.
157182
158183
Parameters
159184
----------
@@ -223,7 +248,7 @@ def in1d(
223248
raise TypeError("If A is pdarray, B must also be pdarray")
224249
elif isinstance(B, (pdarray, Strings, Categorical_)):
225250
if symmetric:
226-
return _in1d_single(A, B), _in1d_single(B, A, invert)
251+
return _in1d_single(A, B, invert), _in1d_single(B, A, invert)
227252
return _in1d_single(A, B, invert)
228253
else:
229254
raise TypeError(
@@ -260,18 +285,25 @@ def in1d(
260285
if assume_unique:
261286
# Deinterleave truth into a and b domains
262287
if symmetric:
263-
return truth[isa], truth[~isa] if not invert else ~truth[isa], ~truth[~isa]
288+
aout = truth[isa]
289+
bout = truth[~isa]
290+
if invert:
291+
return ~aout, ~bout
292+
return aout, bout
264293
else:
265-
return truth[isa] if not invert else ~truth[isa]
294+
aout = truth[isa]
295+
return ~aout if invert else aout
266296
else:
267297
# If didn't start unique, first need to deinterleave into ua domain,
268298
# then broadcast to a domain
269299
atruth = ag.broadcast(truth[isa], permute=True)
270300
if symmetric:
271301
btruth = bg.broadcast(truth[~isa], permute=True)
272-
return atruth, btruth if not invert else ~atruth, ~btruth
302+
if invert:
303+
return ~atruth, ~btruth
304+
return atruth, btruth
273305
else:
274-
return atruth if not invert else ~atruth
306+
return ~atruth if invert else atruth
275307

276308

277309
def in1dmulti(a, b, assume_unique=False, symmetric=False):

arkouda/numpy/util.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,9 @@ def map(
991991
TypeError
992992
If `mapping` is not of type `dict` or `Series`.
993993
If `values` is not of type `pdarray`, `Categorical`, or `Strings`.
994+
ValueError
995+
If a mapping with tuple keys has inconsistent lengths, or if a MultiIndex
996+
mapping has a different number of levels than the GroupBy keys.
994997
995998
Examples
996999
--------
@@ -1012,29 +1015,97 @@ def map(
10121015
from arkouda.numpy.pdarraysetops import in1d
10131016
from arkouda.numpy.strings import Strings
10141017
from arkouda.pandas.categorical import Categorical
1018+
from arkouda.pandas.index import MultiIndex
10151019

10161020
keys = values
10171021
gb = GroupBy(keys, dropna=False)
10181022
gb_keys = gb.unique_keys
10191023

1024+
# helper: number of unique keys (works for single key or tuple-of-keys)
1025+
nuniq = gb_keys[0].size if isinstance(gb_keys, tuple) else gb_keys.size
1026+
1027+
# Fast-path: empty mapping => everything is missing
1028+
if (isinstance(mapping, dict) and len(mapping) == 0) or (
1029+
isinstance(mapping, Series) and len(mapping.index) == 0
1030+
):
1031+
if not isinstance(values, (Strings, Categorical)):
1032+
fillvals = full(nuniq, np.nan, values.dtype)
1033+
else:
1034+
fillvals = full(nuniq, "null")
1035+
return broadcast(gb.segments, fillvals, permutation=gb.permutation)
1036+
10201037
if isinstance(mapping, dict):
1021-
mapping = Series([array(list(mapping.keys())), array(list(mapping.values()))])
1038+
# Build mapping as a Series with an Index/MultiIndex (avoid rank>1 arrays)
1039+
m_keys = list(mapping.keys())
1040+
m_vals = list(mapping.values())
1041+
1042+
k0 = m_keys[0]
1043+
if isinstance(k0, tuple):
1044+
# validate tuple keys
1045+
if not all(isinstance(k, tuple) for k in m_keys):
1046+
raise TypeError("Mixed key types in mapping dict (tuple and non-tuple).")
1047+
n = len(k0)
1048+
if not all(len(k) == n for k in m_keys):
1049+
raise ValueError("All tuple keys in mapping dict must have the same length.")
1050+
1051+
cols = list(zip(*m_keys)) # transpose list[tuple] -> list[level]
1052+
idx = MultiIndex([array(col) for col in cols])
1053+
mapping = Series(array(m_vals), index=idx)
1054+
else:
1055+
mapping = Series(array(m_vals), index=array(m_keys))
10221056

10231057
if isinstance(mapping, Series):
1024-
xtra_keys = gb_keys[in1d(gb_keys, mapping.index.values, invert=True)]
1058+
# Normalize mapping index keys into a "groupable" (single array OR tuple-of-arrays)
1059+
mindex = mapping.index
1060+
if isinstance(mindex, MultiIndex):
1061+
mkeys = tuple(mindex.index)
1062+
else:
1063+
mkeys = mindex.values
10251064

1026-
if xtra_keys.size > 0:
1027-
if not isinstance(mapping.values, (Strings, Categorical)):
1028-
nans = full(xtra_keys.size, np.nan, mapping.values.dtype)
1029-
else:
1030-
nans = full(xtra_keys.size, "null")
1065+
if isinstance(gb_keys, tuple) and isinstance(mkeys, tuple):
1066+
if len(gb_keys) != len(mkeys):
1067+
raise ValueError(
1068+
f"Mapping MultiIndex has {len(mkeys)} levels but GroupBy has {len(gb_keys)} keys"
1069+
)
1070+
1071+
mask = in1d(gb_keys, mkeys, invert=True)
1072+
1073+
# Compute extra keys + extra size without mixing tuple/non-tuple assignments
1074+
if isinstance(gb_keys, tuple):
1075+
xtra_keys_t = tuple(k[mask] for k in gb_keys)
1076+
xtra_size = xtra_keys_t[0].size if len(xtra_keys_t) > 0 else 0
1077+
1078+
if xtra_size > 0:
1079+
if not isinstance(mapping.values, (Strings, Categorical)):
1080+
nans = full(xtra_size, np.nan, mapping.values.dtype)
1081+
else:
1082+
nans = full(xtra_size, "null")
1083+
1084+
# Convert any categorical levels to strings, level-by-level
1085+
xtra_keys_t = tuple(
1086+
k.to_strings() if isinstance(k, Categorical) else k for k in xtra_keys_t
1087+
)
1088+
1089+
xtra_series = Series(nans, index=MultiIndex(list(xtra_keys_t)))
1090+
mapping = Series.concat([mapping, xtra_series])
1091+
1092+
else:
1093+
xtra_keys_s = gb_keys[mask]
1094+
xtra_size = xtra_keys_s.size
1095+
1096+
if xtra_size > 0:
1097+
if not isinstance(mapping.values, (Strings, Categorical)):
1098+
nans = full(xtra_size, np.nan, mapping.values.dtype)
1099+
else:
1100+
nans = full(xtra_size, "null")
10311101

1032-
if isinstance(xtra_keys, Categorical):
1033-
xtra_keys = xtra_keys.to_strings()
1102+
if isinstance(xtra_keys_s, Categorical):
1103+
xtra_keys_s = xtra_keys_s.to_strings()
10341104

1035-
xtra_series = Series(nans, index=xtra_keys)
1036-
mapping = Series.concat([mapping, xtra_series])
1105+
xtra_series = Series(nans, index=xtra_keys_s)
1106+
mapping = Series.concat([mapping, xtra_series])
10371107

1108+
# Align mapping to gb_keys
10381109
if isinstance(gb_keys, Categorical):
10391110
mapping = mapping[gb_keys.to_strings()]
10401111
else:

arkouda/pandas/groupbyclass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686

8787
groupable_element_type = Union[pdarray, Strings, "Categorical"]
8888
groupable = Union[groupable_element_type, Sequence[groupable_element_type]]
89+
8990
# Note: we won't be typechecking GroupBy until we can figure out a way to handle
9091
# the circular import with Categorical
9192

arkouda/pandas/join.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from arkouda.numpy.pdarrayclass import create_pdarray, pdarray
1919
from arkouda.numpy.pdarraysetops import concatenate, in1d
2020
from arkouda.pandas.categorical import Categorical
21-
from arkouda.pandas.groupbyclass import GroupBy, broadcast
21+
from arkouda.pandas.groupbyclass import GroupBy, broadcast, groupable_element_type
2222

2323

2424
if TYPE_CHECKING:
@@ -198,8 +198,8 @@ def compute_join_size(a: pdarray, b: pdarray) -> Tuple[int, int]:
198198
ua, asize = bya.size()
199199
byb = GroupBy(b)
200200
ub, bsize = byb.size()
201-
afact = asize[in1d(ua, ub)]
202-
bfact = bsize[in1d(ub, ua)]
201+
afact = asize[in1d(cast(groupable_element_type, ua), cast(groupable_element_type, ub))]
202+
bfact = bsize[in1d(cast(groupable_element_type, ub), cast(groupable_element_type, ua))]
203203
nelem = (afact * bfact).sum()
204204
nbytes = 3 * 8 * nelem
205205
return nelem, nbytes

arkouda/pandas/series.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from arkouda.numpy.pdarrayclass import RegistrationError, any, argmaxk, create_pdarray, pdarray
2020
from arkouda.numpy.pdarraysetops import argsort, concatenate, in1d, indexof1d
2121
from arkouda.numpy.util import get_callback, is_float
22-
from arkouda.pandas.groupbyclass import GroupBy, groupable_element_type
22+
from arkouda.pandas.groupbyclass import GroupBy, groupable, groupable_element_type
2323
from arkouda.pandas.index import Index, MultiIndex
2424

2525

@@ -429,6 +429,7 @@ def __setitem__(
429429
"""
430430
from arkouda.numpy.pdarraycreation import array
431431
from arkouda.numpy.strings import Strings
432+
from arkouda.pandas.categorical import Categorical
432433

433434
val = self.validate_val(val)
434435
key = self.validate_key(key)
@@ -440,7 +441,23 @@ def __setitem__(
440441
if is_supported_scalar(key):
441442
indices = self.index == key
442443
else:
443-
indices = in1d(self.index.values, key)
444+
# mypy: key may be scalar/SegArray/etc, but in1d only accepts groupables
445+
if not isinstance(key, (pdarray, Strings, Categorical, list, tuple)):
446+
raise TypeError(f"Unsupported key type for membership test: {type(key)}")
447+
448+
# If key is a python list/tuple, it will be validated/converted by validate_key in many paths
449+
# but if it slips through, convert here.
450+
if (
451+
isinstance(self.index, MultiIndex)
452+
and isinstance(key, tuple)
453+
and len(key) == self.index.nlevels
454+
):
455+
indices = self.index.lookup(key) # returns boolean mask
456+
else:
457+
if isinstance(key, list):
458+
key = array(key)
459+
indices = in1d(self.index.values, cast(groupable, key))
460+
444461
tf, counts = GroupBy(indices).size()
445462
update_count = counts[1] if len(counts) == 2 else 0
446463
if update_count == 0:
@@ -614,10 +631,28 @@ def isin(self, lst: Union[pdarray, Strings, List]) -> Series:
614631
and False otherwise.
615632
616633
"""
634+
from arkouda.numpy.pdarraycreation import array
635+
from arkouda.numpy.strings import Strings
636+
from arkouda.pandas.categorical import Categorical
637+
617638
if isinstance(lst, list):
618639
lst = array(lst)
619640

620-
boolean = in1d(self.values, lst)
641+
# mypy: lst/self.values can be a wider union (SegArray/Any) at type level.
642+
# At runtime, in1d only supports pdarray/Strings/Categorical (or sequences of those).
643+
if not isinstance(self.values, (pdarray, Strings, Categorical)):
644+
raise TypeError(f"in1d not supported for Series values type: {type(self.values)}")
645+
646+
if not isinstance(lst, (pdarray, Strings, Categorical, list, tuple)):
647+
raise TypeError(f"in1d not supported for list type: {type(lst)}")
648+
649+
if isinstance(lst, (list, tuple)):
650+
lst = array(lst)
651+
652+
boolean = in1d(
653+
cast(groupable_element_type, self.values),
654+
cast(groupable_element_type, lst),
655+
)
621656
return Series(data=boolean, index=self.index)
622657

623658
@typechecked

tests/numpy/setops_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,43 @@ def test_in1d_multiarray_categorical(self, size):
238238
stringsTwo = ak.Categorical(ak.array(["String {}".format(i % 2) for i in range(10)]))
239239
assert [(x % 3) < 2 for x in range(10)] == ak.in1d(stringsOne, stringsTwo).tolist()
240240

241+
@pytest.mark.requires_chapel_module("In1dMsg")
def test_in1d_symmetric(self):
    """
    Check ``ak.in1d(..., symmetric=True)`` against plain Python set membership.

    Two input pairs are exercised, each with ``invert=False`` and ``invert=True``:

    * arrays containing duplicates with ``assume_unique=False``
    * arrays without duplicates with ``assume_unique=True``
    """

    def membership(needles, haystack):
        # Reference answer: one boolean per needle, True when it occurs in haystack.
        lookup = set(haystack)
        return [item in lookup for item in needles]

    # (assume_unique flag, values for A, values for B)
    scenarios = (
        # Duplicates present, so the inputs may not be declared unique.
        (False, [1, 2, 2, 3, 4], [2, 4, 4, 5]),
        # Inputs must be duplicate-free for assume_unique=True to be valid.
        (True, [1, 2, 3, 4], [2, 4, 5]),
    )

    for unique, left_values, right_values in scenarios:
        left = ak.array(left_values)
        right = ak.array(right_values)

        left_list = left.to_ndarray().tolist()
        right_list = right.to_ndarray().tolist()

        left_mask, right_mask = ak.in1d(
            left, right, assume_unique=unique, symmetric=True, invert=False
        )
        assert left_mask.tolist() == membership(left_list, right_list)
        assert right_mask.tolist() == membership(right_list, left_list)

        left_inv, right_inv = ak.in1d(
            left, right, assume_unique=unique, symmetric=True, invert=True
        )
        assert left_inv.tolist() == [not hit for hit in membership(left_list, right_list)]
        assert right_inv.tolist() == [not hit for hit in membership(right_list, left_list)]
277+
241278
@pytest.mark.parametrize("size", pytest.prob_size)
242279
@pytest.mark.parametrize("dtype", INTEGRAL_TYPES)
243280
def test_intersect1d_multiarray_numeric_types(self, size, dtype):

0 commit comments

Comments
 (0)