From 1518fe6d47fcc9fe5625c62dc921088a7e377e87 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Mar 2026 16:54:21 +0100 Subject: [PATCH 1/6] Implement copy/get_slice/get_global_contact_positions for probegroup --- src/probeinterface/probegroup.py | 89 +++++++++++++++ tests/test_probegroup.py | 190 +++++++++++++++++++++++++++++++ 2 files changed, 279 insertions(+) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 0ece2830..af1f0041 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -50,6 +50,22 @@ def _check_compatible(self, probe: Probe): def ndim(self): return self.probes[0].ndim + def copy(self) -> "ProbeGroup": + """ + Create a copy of the ProbeGroup + + Returns + ------- + copy: ProbeGroup + A copy of the ProbeGroup + """ + copy = ProbeGroup() + for probe in self.probes: + copy.add_probe(probe.copy()) + global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"] + copy.set_global_device_channel_indices(global_device_channel_indices) + return copy + def get_contact_count(self) -> int: """ Total number of channels. @@ -249,6 +265,79 @@ def get_global_contact_ids(self) -> np.ndarray: contact_ids = self.to_numpy(complete=True)["contact_ids"] return contact_ids + def get_global_contact_positions(self) -> np.ndarray: + """ + Gets all contact positions concatenated across probes + + Returns + ------- + contact_positions: np.ndarray + An array of the contact positions across all probes + """ + contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) + return contact_positions + + def get_slice(self, selection: np.ndarray[bool | int]): + """ + Get a copy of the ProbeGroup with a sub selection of contacts. + + Selection can be boolean or by index + + Parameters + ---------- + selection : np.array of bool or int (for index) + Either an np.array of bool or for desired selection of contacts + or the indices of the desired contacts + + Returns + ------- + sliced_probe_group: ProbeGroup + The sliced probe group + + """ + + n = self.get_contact_count() + + selection = np.asarray(selection) + if selection.dtype == "bool": + assert selection.shape == (n,), ( + f"if array of bool given it must be the same size " "as the number of contacts {selection.shape} != {n}" + ) + selection_indices, = np.nonzero(selection) + elif selection.dtype.kind == "i": + assert np.unique(selection).size == selection.size + if len(selection) > 0: + assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" + assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" + selection_indices = selection + else: + selection_indices = [] + else: + raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + + if len(selection_indices) == 0: + return ProbeGroup() + + # Map selection to indices of individual probes + d = self.to_dict(array_as_list=False) + ind = 0 + sliced_probes = [] + for probe in self.probes: + n = probe.get_contact_count() + probe_limits = (ind, ind + n) + ind += n + + probe_selection_indices = selection_indices[(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])] + if len(probe_selection_indices) == 0: + continue + sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) + sliced_probes.append(sliced_probe) + + sliced_probe_group = ProbeGroup() + sliced_probe_group.probes = sliced_probes + + return sliced_probe_group + def check_global_device_wiring_and_ids(self): # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 56bf97d3..416b1829 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -116,6 +116,196 @@ def test_set_contact_ids_rejects_wrong_size(): probe.set_contact_ids(["a", "b", "c"]) +def _make_probegroup(n_probes=3): + """Helper: build a ProbeGroup with device channel indices set.""" + probegroup = ProbeGroup() + nchan = 0 + for i in range(n_probes): + probe = generate_dummy_probe() + probe.move([i * 100, i * 80]) + n = probe.get_contact_count() + probe.set_device_channel_indices(np.arange(n) + nchan) + probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)]) + nchan += n + probegroup.add_probe(probe) + return probegroup + + +def _make_probegroup_full(n_probes=3): + """Helper: build a ProbeGroup where **every** probe is added.""" + probegroup = ProbeGroup() + nchan = 0 + for i in range(n_probes): + probe = generate_dummy_probe() + probe.move([i * 100, i * 80]) + n = probe.get_contact_count() + probe.set_device_channel_indices(np.arange(n) + nchan) + probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)]) + probegroup.add_probe(probe) + nchan += n + return probegroup + + +# ── copy() tests ──────────────────────────────────────────────────────────── + + +def test_copy_returns_new_object(): + pg = _make_probegroup_full(2) + pg_copy = pg.copy() + assert pg_copy is not pg + assert len(pg_copy.probes) == len(pg.probes) + for orig, copied in zip(pg.probes, pg_copy.probes): + assert orig is not copied + + +def test_copy_preserves_positions(): + pg = _make_probegroup_full(2) + pg_copy = pg.copy() + for orig, copied in zip(pg.probes, pg_copy.probes): + np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions) + + +def test_copy_preserves_device_channel_indices(): + pg = _make_probegroup_full(2) + pg_copy = pg.copy() + np.testing.assert_array_equal( + pg.get_global_device_channel_indices(), + pg_copy.get_global_device_channel_indices(), + ) + + +def test_copy_does_not_preserve_contact_ids(): + """Probe.copy() intentionally does not copy contact_ids.""" + pg = _make_probegroup_full(2) + pg_copy = pg.copy() + # All contact_ids should be empty strings after copy + assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) + + +def test_copy_is_independent(): + """Mutating the copy must not affect the original.""" + pg = _make_probegroup_full(2) + original_positions = pg.probes[0].contact_positions.copy() + pg_copy = pg.copy() + pg_copy.probes[0].move([999, 999]) + np.testing.assert_array_equal(pg.probes[0].contact_positions, original_positions) + + +# ── get_slice() tests ─────────────────────────────────────────────────────── + + +def test_get_slice_by_bool(): + pg = _make_probegroup_full(2) + total = pg.get_contact_count() + sel = np.zeros(total, dtype=bool) + sel[:5] = True # first 5 contacts from the first probe + sliced = pg.get_slice(sel) + assert sliced.get_contact_count() == 5 + + +def test_get_slice_by_index(): + pg = _make_probegroup_full(2) + indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes + sliced = pg.get_slice(indices) + assert sliced.get_contact_count() == 5 + + +def test_get_slice_preserves_device_channel_indices(): + pg = _make_probegroup_full(2) + indices = np.array([0, 1, 2]) + sliced = pg.get_slice(indices) + orig_chans = pg.get_global_device_channel_indices()["device_channel_indices"][:3] + sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"] + np.testing.assert_array_equal(sliced_chans, orig_chans) + + +def test_get_slice_preserves_positions(): + pg = _make_probegroup_full(2) + indices = np.array([0, 1, 2]) + sliced = pg.get_slice(indices) + expected = pg.get_global_contact_positions()[indices] + np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) + + +def test_get_slice_empty_selection(): + pg = _make_probegroup_full(2) + sliced = pg.get_slice(np.array([], dtype=int)) + assert sliced.get_contact_count() == 0 + assert len(sliced.probes) == 0 + + +def test_get_slice_wrong_bool_size(): + pg = _make_probegroup_full(2) + with pytest.raises(AssertionError): + pg.get_slice(np.array([True, False])) # wrong size + + +def test_get_slice_out_of_bounds(): + pg = _make_probegroup_full(2) + total = pg.get_contact_count() + with pytest.raises(AssertionError): + pg.get_slice(np.array([total + 10])) + + +def test_get_slice_all_contacts(): + """Slicing with all contacts should give an equivalent ProbeGroup.""" + pg = _make_probegroup_full(2) + total = pg.get_contact_count() + sliced = pg.get_slice(np.arange(total)) + assert sliced.get_contact_count() == total + np.testing.assert_array_equal( + sliced.get_global_contact_positions(), + pg.get_global_contact_positions(), + ) + + +# ── get_global_contact_positions() tests ──────────────────────────────────── + + +def test_get_global_contact_positions_shape(): + pg = _make_probegroup_full(3) + pos = pg.get_global_contact_positions() + assert pos.shape == (pg.get_contact_count(), pg.ndim) + + +def test_get_global_contact_positions_matches_per_probe(): + pg = _make_probegroup_full(3) + pos = pg.get_global_contact_positions() + offset = 0 + for probe in pg.probes: + n = probe.get_contact_count() + np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions) + offset += n + + +def test_get_global_contact_positions_single_probe(): + pg = _make_probegroup_full(1) + pos = pg.get_global_contact_positions() + np.testing.assert_array_equal(pos, pg.probes[0].contact_positions) + + +def test_get_global_contact_positions_3d(): + pg = ProbeGroup() + for i in range(2): + probe = generate_dummy_probe().to_3d() + probe.move([i * 100, i * 80, i * 30]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + assert pos.shape[1] == 3 + assert pos.shape[0] == pg.get_contact_count() + + +def test_get_global_contact_positions_reflects_move(): + """Positions should reflect probe movement.""" + pg = ProbeGroup() + probe = generate_dummy_probe() + original_pos = probe.contact_positions.copy() + probe.move([50, 60]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + np.testing.assert_array_equal(pos, original_pos + np.array([50, 60])) + + if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d() From eb08f124cf1a863a02510b7cfbbdba1b602a8138 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:55:01 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index af1f0041..bcaaa079 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -303,12 +303,16 @@ def get_slice(self, selection: np.ndarray[bool | int]): assert selection.shape == (n,), ( f"if array of bool given it must be the same size " "as the number of contacts {selection.shape} != {n}" ) - selection_indices, = np.nonzero(selection) + (selection_indices,) = np.nonzero(selection) elif selection.dtype.kind == "i": assert np.unique(selection).size == selection.size if len(selection) > 0: - assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" - assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" + assert ( + 0 <= np.min(selection) < n + ), f"An index within your selection is out of bounds {np.min(selection)}" + assert ( + 0 <= np.max(selection) < n + ), f"An index within your selection is out of bounds {np.max(selection)}" selection_indices = selection else: selection_indices = [] @@ -327,7 +331,9 @@ def get_slice(self, selection: np.ndarray[bool | int]): probe_limits = (ind, ind + n) ind += n - probe_selection_indices = selection_indices[(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])] + probe_selection_indices = selection_indices[ + (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) + ] if len(probe_selection_indices) == 0: continue sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) From 9503a804c3811ff45c71940c0cb8145bb7413b65 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 24 Mar 2026 17:08:13 +0100 Subject: [PATCH 3/6] reorder tests --- tests/test_probegroup.py | 222 +++++++++++++++------------------------ 1 file changed, 82 insertions(+), 140 deletions(-) diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 416b1829..49177860 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -5,30 +5,26 @@ import numpy as np - -def test_probegroup(): +@pytest.fixture +def probegroup(): + """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() - nchan = 0 for i in range(3): probe = generate_dummy_probe() probe.move([i * 100, i * 80]) n = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(n)[::-1] + nchan) - shank_ids = np.ones(n) - shank_ids[: n // 2] *= i * 2 - shank_ids[n // 2 :] *= i * 2 + 1 - probe.set_shank_ids(shank_ids) + probe.set_device_channel_indices(np.arange(n) + nchan) probegroup.add_probe(probe) - nchan += n + return probegroup +def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() ids = probegroup.get_global_contact_ids() df = probegroup.to_dataframe() - # ~ print(df['global_contact_ids']) arr = probegroup.to_numpy(complete=False) other = ProbeGroup.from_numpy(arr) @@ -38,12 +34,6 @@ def test_probegroup(): d = probegroup.to_dict() other = ProbeGroup.from_dict(d) - # ~ from probeinterface.plotting import plot_probe_group, plot_probe - # ~ import matplotlib.pyplot as plt - # ~ plot_probe_group(probegroup) - # ~ plot_probe_group(other) - # ~ plt.show() - # checking automatic generation of ids with new dummy probes probegroup.probes = [] for i in range(3): @@ -115,197 +105,149 @@ def test_set_contact_ids_rejects_wrong_size(): with pytest.raises(ValueError, match="do not have the same size"): probe.set_contact_ids(["a", "b", "c"]) +# ── get_global_contact_positions() tests ──────────────────────────────────── -def _make_probegroup(n_probes=3): - """Helper: build a ProbeGroup with device channel indices set.""" - probegroup = ProbeGroup() - nchan = 0 - for i in range(n_probes): - probe = generate_dummy_probe() - probe.move([i * 100, i * 80]) - n = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(n) + nchan) - probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)]) - nchan += n - probegroup.add_probe(probe) - return probegroup +def test_get_global_contact_positions_shape(probegroup): + pos = probegroup.get_global_contact_positions() + assert pos.shape == (probegroup.get_contact_count(), probegroup.ndim) -def _make_probegroup_full(n_probes=3): - """Helper: build a ProbeGroup where **every** probe is added.""" - probegroup = ProbeGroup() - nchan = 0 - for i in range(n_probes): - probe = generate_dummy_probe() - probe.move([i * 100, i * 80]) + +def test_get_global_contact_positions_matches_per_probe(probegroup): + pos = probegroup.get_global_contact_positions() + offset = 0 + for probe in probegroup.probes: n = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(n) + nchan) - probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)]) - probegroup.add_probe(probe) - nchan += n - return probegroup + np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions) + offset += n + + +def test_get_global_contact_positions_single_probe(probegroup): + pos = probegroup.get_global_contact_positions() + np.testing.assert_array_equal(pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions) + + +def test_get_global_contact_positions_3d(): + pg = ProbeGroup() + for i in range(2): + probe = generate_dummy_probe().to_3d() + probe.move([i * 100, i * 80, i * 30]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + assert pos.shape[1] == 3 + assert pos.shape[0] == pg.get_contact_count() + +def test_get_global_contact_positions_reflects_move(): + """Positions should reflect probe movement.""" + pg = ProbeGroup() + probe = generate_dummy_probe() + original_pos = probe.contact_positions.copy() + probe.move([50, 60]) + pg.add_probe(probe) + pos = pg.get_global_contact_positions() + np.testing.assert_array_equal(pos, original_pos + np.array([50, 60])) # ── copy() tests ──────────────────────────────────────────────────────────── -def test_copy_returns_new_object(): - pg = _make_probegroup_full(2) - pg_copy = pg.copy() - assert pg_copy is not pg - assert len(pg_copy.probes) == len(pg.probes) - for orig, copied in zip(pg.probes, pg_copy.probes): +def test_copy_returns_new_object(probegroup): + pg_copy = probegroup.copy() + assert pg_copy is not probegroup + assert len(pg_copy.probes) == len(probegroup.probes) + for orig, copied in zip(probegroup.probes, pg_copy.probes): assert orig is not copied -def test_copy_preserves_positions(): - pg = _make_probegroup_full(2) - pg_copy = pg.copy() - for orig, copied in zip(pg.probes, pg_copy.probes): +def test_copy_preserves_positions(probegroup): + pg_copy = probegroup.copy() + for orig, copied in zip(probegroup.probes, pg_copy.probes): np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions) -def test_copy_preserves_device_channel_indices(): - pg = _make_probegroup_full(2) - pg_copy = pg.copy() +def test_copy_preserves_device_channel_indices(probegroup): + pg_copy = probegroup.copy() np.testing.assert_array_equal( - pg.get_global_device_channel_indices(), + probegroup.get_global_device_channel_indices(), pg_copy.get_global_device_channel_indices(), ) -def test_copy_does_not_preserve_contact_ids(): +def test_copy_does_not_preserve_contact_ids(probegroup): """Probe.copy() intentionally does not copy contact_ids.""" - pg = _make_probegroup_full(2) - pg_copy = pg.copy() + pg_copy = probegroup.copy() # All contact_ids should be empty strings after copy assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) -def test_copy_is_independent(): +def test_copy_is_independent(probegroup): """Mutating the copy must not affect the original.""" - pg = _make_probegroup_full(2) - original_positions = pg.probes[0].contact_positions.copy() - pg_copy = pg.copy() + original_positions = probegroup.probes[0].contact_positions.copy() + pg_copy = probegroup.copy() pg_copy.probes[0].move([999, 999]) - np.testing.assert_array_equal(pg.probes[0].contact_positions, original_positions) + np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions) # ── get_slice() tests ─────────────────────────────────────────────────────── -def test_get_slice_by_bool(): - pg = _make_probegroup_full(2) - total = pg.get_contact_count() +def test_get_slice_by_bool(probegroup): + total = probegroup.get_contact_count() sel = np.zeros(total, dtype=bool) sel[:5] = True # first 5 contacts from the first probe - sliced = pg.get_slice(sel) + sliced = probegroup.get_slice(sel) assert sliced.get_contact_count() == 5 -def test_get_slice_by_index(): - pg = _make_probegroup_full(2) +def test_get_slice_by_index(probegroup): indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes - sliced = pg.get_slice(indices) + sliced = probegroup.get_slice(indices) assert sliced.get_contact_count() == 5 -def test_get_slice_preserves_device_channel_indices(): - pg = _make_probegroup_full(2) +def test_get_slice_preserves_device_channel_indices(probegroup): indices = np.array([0, 1, 2]) - sliced = pg.get_slice(indices) - orig_chans = pg.get_global_device_channel_indices()["device_channel_indices"][:3] + sliced = probegroup.get_slice(indices) + orig_chans = probegroup.get_global_device_channel_indices()["device_channel_indices"][:3] sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"] np.testing.assert_array_equal(sliced_chans, orig_chans) -def test_get_slice_preserves_positions(): - pg = _make_probegroup_full(2) +def test_get_slice_preserves_positions(probegroup): indices = np.array([0, 1, 2]) - sliced = pg.get_slice(indices) - expected = pg.get_global_contact_positions()[indices] + sliced = probegroup.get_slice(indices) + expected = probegroup.get_global_contact_positions()[indices] np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) -def test_get_slice_empty_selection(): - pg = _make_probegroup_full(2) - sliced = pg.get_slice(np.array([], dtype=int)) +def test_get_slice_empty_selection(probegroup): + sliced = probegroup.get_slice(np.array([], dtype=int)) assert sliced.get_contact_count() == 0 assert len(sliced.probes) == 0 -def test_get_slice_wrong_bool_size(): - pg = _make_probegroup_full(2) +def test_get_slice_wrong_bool_size(probegroup): with pytest.raises(AssertionError): - pg.get_slice(np.array([True, False])) # wrong size + probegroup.get_slice(np.array([True, False])) # wrong size -def test_get_slice_out_of_bounds(): - pg = _make_probegroup_full(2) - total = pg.get_contact_count() +def test_get_slice_out_of_bounds(probegroup): + total = probegroup.get_contact_count() with pytest.raises(AssertionError): - pg.get_slice(np.array([total + 10])) + probegroup.get_slice(np.array([total + 10])) -def test_get_slice_all_contacts(): +def test_get_slice_all_contacts(probegroup): """Slicing with all contacts should give an equivalent ProbeGroup.""" - pg = _make_probegroup_full(2) - total = pg.get_contact_count() - sliced = pg.get_slice(np.arange(total)) + total = probegroup.get_contact_count() + sliced = probegroup.get_slice(np.arange(total)) assert sliced.get_contact_count() == total np.testing.assert_array_equal( sliced.get_global_contact_positions(), - pg.get_global_contact_positions(), + probegroup.get_global_contact_positions(), ) -# ── get_global_contact_positions() tests ──────────────────────────────────── - - -def test_get_global_contact_positions_shape(): - pg = _make_probegroup_full(3) - pos = pg.get_global_contact_positions() - assert pos.shape == (pg.get_contact_count(), pg.ndim) - - -def test_get_global_contact_positions_matches_per_probe(): - pg = _make_probegroup_full(3) - pos = pg.get_global_contact_positions() - offset = 0 - for probe in pg.probes: - n = probe.get_contact_count() - np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions) - offset += n - - -def test_get_global_contact_positions_single_probe(): - pg = _make_probegroup_full(1) - pos = pg.get_global_contact_positions() - np.testing.assert_array_equal(pos, pg.probes[0].contact_positions) - - -def test_get_global_contact_positions_3d(): - pg = ProbeGroup() - for i in range(2): - probe = generate_dummy_probe().to_3d() - probe.move([i * 100, i * 80, i * 30]) - pg.add_probe(probe) - pos = pg.get_global_contact_positions() - assert pos.shape[1] == 3 - assert pos.shape[0] == pg.get_contact_count() - - -def test_get_global_contact_positions_reflects_move(): - """Positions should reflect probe movement.""" - pg = ProbeGroup() - probe = generate_dummy_probe() - original_pos = probe.contact_positions.copy() - probe.move([50, 60]) - pg.add_probe(probe) - pos = pg.get_global_contact_positions() - np.testing.assert_array_equal(pos, original_pos + np.array([50, 60])) - - if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d() From 74c31a9f7181d797ab40d7dc4e132ba1843447d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:08:37 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_probegroup.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index 49177860..c9421908 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -5,6 +5,7 @@ import numpy as np + @pytest.fixture def probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" @@ -19,6 +20,7 @@ def probegroup(): nchan += n return probegroup + def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() @@ -105,6 +107,7 @@ def test_set_contact_ids_rejects_wrong_size(): with pytest.raises(ValueError, match="do not have the same size"): probe.set_contact_ids(["a", "b", "c"]) + # ── get_global_contact_positions() tests ──────────────────────────────────── @@ -124,7 +127,9 @@ def test_get_global_contact_positions_matches_per_probe(probegroup): def test_get_global_contact_positions_single_probe(probegroup): pos = probegroup.get_global_contact_positions() - np.testing.assert_array_equal(pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions) + np.testing.assert_array_equal( + pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions + ) def test_get_global_contact_positions_3d(): @@ -148,6 +153,7 @@ def test_get_global_contact_positions_reflects_move(): pos = pg.get_global_contact_positions() np.testing.assert_array_equal(pos, original_pos + np.array([50, 60])) + # ── copy() tests ──────────────────────────────────────────────────────────── From 33fb7fc1ef5baad2bab1f526d491c1573eda975b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 09:44:37 +0100 Subject: [PATCH 5/6] code review from Heberto --- src/probeinterface/probegroup.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index bcaaa079..08104310 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -14,7 +14,7 @@ class ProbeGroup: def __init__(self): self.probes = [] - def add_probe(self, probe: Probe): + def add_probe(self, probe: Probe) -> None: """ Add an additional probe to the ProbeGroup @@ -30,7 +30,7 @@ def add_probe(self, probe: Probe): self.probes.append(probe) probe._probe_group = self - def _check_compatible(self, probe: Probe): + def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: raise ValueError( "This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup" @@ -47,7 +47,7 @@ def _check_compatible(self, probe: Probe): self.probes = self.probes[:-1] @property - def ndim(self): + def ndim(self) -> int: return self.probes[0].ndim def copy(self) -> "ProbeGroup": @@ -163,7 +163,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": df.index = np.arange(df.shape[0], dtype="int64") return df - def to_dict(self, array_as_list: bool = False): + def to_dict(self, array_as_list: bool = False) -> dict: """Create a dictionary of all necessary attributes. Parameters @@ -184,7 +184,7 @@ def to_dict(self, array_as_list: bool = False): return d @staticmethod - def from_dict(d: dict): + def from_dict(d: dict) -> "ProbeGroup": """Instantiate a ProbeGroup from a dictionary Parameters @@ -226,7 +226,7 @@ def get_global_device_channel_indices(self) -> np.ndarray: channels["device_channel_indices"] = arr["device_channel_indices"] return channels - def set_global_device_channel_indices(self, channels: np.ndarray | list): + def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: """ Set global indices for all probes @@ -277,7 +277,7 @@ def get_global_contact_positions(self) -> np.ndarray: contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) return contact_positions - def get_slice(self, selection: np.ndarray[bool | int]): + def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": """ Get a copy of the ProbeGroup with a sub selection of contacts. @@ -301,7 +301,7 @@ def get_slice(self, selection: np.ndarray[bool | int]): selection = np.asarray(selection) if selection.dtype == "bool": assert selection.shape == (n,), ( - f"if array of bool given it must be the same size " "as the number of contacts {selection.shape} != {n}" + f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" ) (selection_indices,) = np.nonzero(selection) elif selection.dtype.kind == "i": @@ -323,7 +323,6 @@ def get_slice(self, selection: np.ndarray[bool | int]): return ProbeGroup() # Map selection to indices of individual probes - d = self.to_dict(array_as_list=False) ind = 0 sliced_probes = [] for probe in self.probes: @@ -340,11 +339,12 @@ def get_slice(self, selection: np.ndarray[bool | int]): sliced_probes.append(sliced_probe) sliced_probe_group = ProbeGroup() - sliced_probe_group.probes = sliced_probes + for probe in sliced_probes: + sliced_probe_group.add_probe(probe) return sliced_probe_group - def check_global_device_wiring_and_ids(self): + def check_global_device_wiring_and_ids(self) -> None: # check unique device_channel_indices for !=-1 chans = self.get_global_device_channel_indices() keep = chans["device_channel_indices"] >= 0 @@ -353,7 +353,7 @@ def check_global_device_wiring_and_ids(self): if valid_chans.size != np.unique(valid_chans).size: raise ValueError("channel device indices are not unique across probes") - def auto_generate_probe_ids(self, *args, **kwargs): + def auto_generate_probe_ids(self, *args, **kwargs) -> None: """ Annotate all probes with unique probe_id values. @@ -377,7 +377,7 @@ def auto_generate_probe_ids(self, *args, **kwargs): for pid, probe in enumerate(self.probes): probe.annotate(probe_id=probe_ids[pid]) - def auto_generate_contact_ids(self, *args, **kwargs): + def auto_generate_contact_ids(self, *args, **kwargs) -> None: """ Annotate all contacts with unique contact_id values. From 43afd8025c7338a9d9ba0ca3dec010763c12feba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:44:58 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/probeinterface/probegroup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index 08104310..d42906a4 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -300,9 +300,9 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": selection = np.asarray(selection) if selection.dtype == "bool": - assert selection.shape == (n,), ( - f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - ) + assert selection.shape == ( + n, + ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" (selection_indices,) = np.nonzero(selection) elif selection.dtype.kind == "i": assert np.unique(selection).size == selection.size