From fbe6b6c6e4e935c584c633d5a5def77290f34b72 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Mon, 1 Jun 2026 11:42:39 -0500 Subject: [PATCH 1/2] Add per-group ref_channel_ids to common_reference (cross-group referencing) With reference="global" + groups, common_reference referenced each group to its OWN channels and ignored ref_channel_ids, despite the docstring stating "a list of channels to be applied to each group is expected". This allows ref_channel_ids to be a list of per-group channel-id lists: the reference subtracted from each group is the operator (median/average) over that group's reference set, which may include channels OUTSIDE the group. This enables cross-group referencing (e.g. each tetrode referenced to the median of all channels on the other tetrodes). ref_channel_ids=None (default) keeps the previous own-group behavior. --- .../preprocessing/common_reference.py | 33 +++++++++++++++++-- .../tests/test_common_reference.py | 25 ++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5a3a9b0043..5a6ffad35c 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -55,7 +55,14 @@ class CommonReferenceRecording(BasePreprocessor): ref_channel_ids : list | str | int | None, default: None If "global" reference, a list of channels to be used as reference. If "single" reference, a list of one channel or a single channel id is expected. - If "groups" is provided, then a list of channels to be applied to each group is expected. + If "groups" is provided with "single" reference, a list with one reference channel id + per group is expected. + If "groups" is provided with "global" reference, a list with one *list* of reference + channel ids per group is expected: the reference subtracted from each group is the + operator (median/average) over that group's reference set. The reference set may contain + channels outside the group, enabling cross-group referencing (e.g. referencing each + tetrode to the median of all channels on the other tetrodes). If None, each group is + referenced to its own channels. local_radius : tuple(int, int), default: (30, 55) Use in the local CAR implementation as the selecting annulus with the following format: @@ -101,6 +108,18 @@ def __init__( if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") + if groups is not None: + # Per-group reference sets: one list of channel ids per group. The reference + # subtracted from each group is the operator over that group's reference set + # (which may be channels outside the group, e.g. for cross-group referencing). + assert len(ref_channel_ids) == len(groups), ( + "With 'global' reference and 'groups', 'ref_channel_ids' must be a list " + "with one channel-id list per group" + ) + assert all(isinstance(r, (list, np.ndarray)) for r in ref_channel_ids), ( + "With 'global' reference and 'groups', each element of 'ref_channel_ids' " + "must itself be a list of channel ids (the reference set for that group)" + ) elif reference == "single": assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'" if groups is not None: @@ -150,7 +169,11 @@ def __init__( else: group_indices = None if ref_channel_ids is not None: - ref_channel_indices = self.ids_to_indices(ref_channel_ids) + if reference == "global" and groups is not None: + # one reference-channel index array per group + ref_channel_indices = [self.ids_to_indices(r) for r in ref_channel_ids] + else: + ref_channel_indices = self.ids_to_indices(ref_channel_ids) else: ref_channel_indices = None @@ -246,7 +269,11 @@ def get_traces(self, start_frame, end_frame, channel_indices): in_group_traces = traces[:, selected_indices_in_group] if self.reference == "global": - shift = self.operator_func(traces[:, all_group_indices], axis=1, keepdims=True) + if self.ref_channel_indices is None: + ref_indices = all_group_indices # reference each group to its own channels + else: + ref_indices = self.ref_channel_indices[group_index] # per-group reference set + shift = self.operator_func(traces[:, ref_indices], axis=1, keepdims=True) re_referenced_traces[:, out_indices] = in_group_traces - shift else: # single (as local is not allowed for groups) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index e19cad59ba..417605bc85 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -170,6 +170,31 @@ def test_common_reference_groups(recording): assert np.allclose(traces[:, 1], 0) +def test_common_reference_groups_cross(recording): + # "global" reference with groups AND a per-group ref_channel_ids: each group is + # referenced to a (possibly external) set of channels -> enables cross-group referencing. + original_traces = recording.get_traces() + groups = [["a", "c"], ["b", "d"]] + ref_channel_ids = [["b", "d"], ["a", "c"]] # reference each group to the OTHER group's channels + + rec_cross = common_reference( + recording, reference="global", operator="median", groups=groups, ref_channel_ids=ref_channel_ids + ) + traces = rec_cross.get_traces(channel_ids=["a", "b", "c", "d"]) + # a, c (group 0) referenced to median of b, d + ref0 = np.median(original_traces[:, [1, 3]], axis=1) + assert np.allclose(traces[:, 0], original_traces[:, 0] - ref0, atol=0.01) + assert np.allclose(traces[:, 2], original_traces[:, 2] - ref0, atol=0.01) + # b, d (group 1) referenced to median of a, c + ref1 = np.median(original_traces[:, [0, 2]], axis=1) + assert np.allclose(traces[:, 1], original_traces[:, 1] - ref1, atol=0.01) + assert np.allclose(traces[:, 3], original_traces[:, 3] - ref1, atol=0.01) + + # mismatched lengths raise + with pytest.raises(AssertionError): + common_reference(recording, reference="global", groups=groups, ref_channel_ids=[["b", "d"]]) + + def test_min_local_radius(): # Test that local radius smaller than the number of channels is handled correctly recording = generate_recording(durations=[1.0], num_channels=32) From 75990833e240e87269c491b69f66ccf72b0b9dd1 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Mon, 1 Jun 2026 11:59:35 -0500 Subject: [PATCH 2/2] Add common_reference(..., ref_channel_ids="complement"). The most principled reference for a tetrode may sometimes be the average/median of all OTHER tetrodes. In other words, each group is referenced to all channels NOT in it -- ref_channel_ids is each group's complement (with "global" reference and "groups"). Adds common_reference(..., ref_channel_ids="complement") as syntactic sugar for this. --- .../preprocessing/common_reference.py | 9 +++++++++ .../tests/test_common_reference.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index 5a6ffad35c..d15301b943 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -63,6 +63,9 @@ class CommonReferenceRecording(BasePreprocessor): channels outside the group, enabling cross-group referencing (e.g. referencing each tetrode to the median of all channels on the other tetrodes). If None, each group is referenced to its own channels. + As a shortcut for that cross-group case, pass the string "complement" (with "global" + reference and "groups"): each group is then referenced to all channels NOT in it, + i.e. ref_channel_ids is auto-built as each group's complement. local_radius : tuple(int, int), default: (30, 55) Use in the local CAR implementation as the selecting annulus with the following format: @@ -105,6 +108,12 @@ def __init__( raise ValueError("'operator' must be either 'median', 'average'") if reference == "global": + if ref_channel_ids == "complement": + # Convenience: reference each group to all channels NOT in it (its complement). + if groups is None: + raise ValueError("ref_channel_ids='complement' requires 'groups' to be set") + all_ids = list(recording.channel_ids) + ref_channel_ids = [[c for c in all_ids if c not in set(group)] for group in groups] if ref_channel_ids is not None: if not isinstance(ref_channel_ids, list): raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list") diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 417605bc85..3551d0dcb6 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -195,6 +195,23 @@ def test_common_reference_groups_cross(recording): common_reference(recording, reference="global", groups=groups, ref_channel_ids=[["b", "d"]]) +def test_common_reference_groups_complement(recording): + # ref_channel_ids="complement" shortcut: reference each group to all channels NOT in it. + groups = [["a", "c"], ["b", "d"]] + # complements of these groups within {a,b,c,d} are exactly [["b","d"], ["a","c"]] + explicit = common_reference( + recording, reference="global", operator="median", groups=groups, ref_channel_ids=[["b", "d"], ["a", "c"]] + ) + sugar = common_reference( + recording, reference="global", operator="median", groups=groups, ref_channel_ids="complement" + ) + assert np.allclose(sugar.get_traces(), explicit.get_traces(), atol=1e-6) + + # "complement" requires groups + with pytest.raises(ValueError): + common_reference(recording, reference="global", ref_channel_ids="complement") + + def test_min_local_radius(): # Test that local radius smaller than the number of channels is handled correctly recording = generate_recording(durations=[1.0], num_channels=32)