From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 1/7] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 61c317aba92608d9f096a3a374bc3d43e27faaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Mar 2026 10:09:46 -0800 Subject: [PATCH 2/7] Fix OpenEphys tests --- .../extractors/neoextractors/openephys.py | 20 ++++++++++++------- .../extractors/tests/test_neoextractors.py | 3 +++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 1c39a1b97c..1d16df534b 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -351,13 +351,19 @@ def __init__( # Ensure device channel index corresponds to channel_ids probe_channel_names = probe.contact_annotations.get("channel_name", None) if probe_channel_names is not None and not np.array_equal(probe_channel_names, self.channel_ids): - device_channel_indices = [] - probe_channel_names = list(probe_channel_names) - 
device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) - for i, ch in enumerate(self.channel_ids): - index_in_probe = probe_channel_names.index(ch) - device_channel_indices[index_in_probe] = i - probe.set_device_channel_indices(device_channel_indices) + if set(probe_channel_names) == set(self.channel_ids): + device_channel_indices = [] + probe_channel_names = list(probe_channel_names) + device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) + for i, ch in enumerate(self.channel_ids): + index_in_probe = probe_channel_names.index(ch) + device_channel_indices[index_in_probe] = i + probe.set_device_channel_indices(device_channel_indices) + else: + warnings.warn( + "Channel names in the probe do not match the channel ids from Neo. " + "Cannot set device channel indices, but this might lead to incorrect probe geometries" + ) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index f80f62ebf0..f40b4d05ab 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,6 +121,9 @@ class OpenEphysBinaryRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "0"}), ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "1"}), ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "0", "block_index": 0}), + # TODO: block_indices 1/2 of v0.6.x_neuropixels_multiexp_multistream have a mismatch in the channel names between + # the settings files (starting with CH0) and structure.oebin (starting at CH1). 
+ # Currently, the extractor will skip remapping to match order in oebin and settings file, raising a warning ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "1", "block_index": 1}), ( "openephysbinary/v0.6.x_neuropixels_multiexp_multistream", From f9de0519bbf735c3554f71ad8599899233df6e47 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 14:56:10 +0100 Subject: [PATCH 3/7] Centralize segment handling to BaseExtractors --- src/spikeinterface/core/base.py | 33 ++++++---- src/spikeinterface/core/baserecording.py | 62 +++++++------------ src/spikeinterface/core/basesorting.py | 28 ++++----- .../core/binaryrecordingextractor.py | 9 ++- .../core/channelsaggregationrecording.py | 2 +- src/spikeinterface/core/channelslice.py | 2 +- .../core/frameslicerecording.py | 2 +- src/spikeinterface/core/frameslicesorting.py | 2 +- src/spikeinterface/core/generate.py | 4 +- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/core/numpyextractors.py | 2 +- src/spikeinterface/core/operatorrecordings.py | 6 +- src/spikeinterface/core/segmentutils.py | 18 +++--- src/spikeinterface/core/sparsity.py | 16 ++--- src/spikeinterface/core/template.py | 5 +- src/spikeinterface/core/template_tools.py | 26 ++++---- .../core/tests/test_recording_tools.py | 2 +- .../core/tests/test_time_handling.py | 6 +- .../core/unitsaggregationsorting.py | 2 +- .../core/unitsselectionsorting.py | 2 +- src/spikeinterface/core/zarrextractors.py | 2 +- .../curation/bombcell_curation.py | 4 +- src/spikeinterface/curation/curation_tools.py | 6 +- .../curation/mergeunitssorting.py | 2 +- .../curation/remove_duplicated_spikes.py | 2 +- .../curation/remove_excess_spikes.py | 2 +- .../curation/splitunitsorting.py | 2 +- .../extractors/mdaextractors.py | 2 +- .../extractors/phykilosortextractors.py | 2 +- src/spikeinterface/generation/drift_tools.py | 4 +- .../postprocessing/alignsorting.py | 2 +- .../postprocessing/localization_tools.py | 9 +-- 
src/spikeinterface/preprocessing/astype.py | 2 +- .../preprocessing/average_across_direction.py | 2 +- src/spikeinterface/preprocessing/clip.py | 4 +- .../preprocessing/common_reference.py | 2 +- src/spikeinterface/preprocessing/decimate.py | 2 +- .../deepinterpolation/deepinterpolation.py | 2 +- .../preprocessing/detect_artifacts.py | 2 - .../preprocessing/directional_derivative.py | 2 +- src/spikeinterface/preprocessing/filter.py | 2 +- .../preprocessing/filter_gaussian.py | 2 +- .../preprocessing/filter_opencl.py | 2 +- .../preprocessing/highpass_spatial_filter.py | 2 +- .../preprocessing/interpolate_bad_channels.py | 2 +- .../preprocessing/normalize_scale.py | 8 +-- .../preprocessing/phase_shift.py | 2 +- src/spikeinterface/preprocessing/rectify.py | 2 +- .../preprocessing/remove_artifacts.py | 2 +- src/spikeinterface/preprocessing/resample.py | 2 +- .../preprocessing/silence_periods.py | 2 +- .../preprocessing/tests/test_clip.py | 2 +- .../tests/test_common_reference.py | 6 +- .../preprocessing/tests/test_decimate.py | 2 +- .../tests/test_detect_bad_channels.py | 4 +- .../preprocessing/tests/test_resample.py | 4 +- .../preprocessing/tests/test_whiten.py | 4 +- .../preprocessing/unsigned_to_signed.py | 2 +- src/spikeinterface/preprocessing/whiten.py | 2 +- .../preprocessing/zero_channel_pad.py | 4 +- .../sortingcomponents/matching/tdc_peeler.py | 2 +- .../motion/motion_interpolation.py | 2 +- .../widgets/bombcell_curation.py | 2 - src/spikeinterface/widgets/unit_labels.py | 2 - 64 files changed, 166 insertions(+), 188 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9dc270d38d..8d149a7c49 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from pathlib import Path import shutil from typing import Any @@ -87,6 +85,8 @@ def __init__(self, main_ids: Sequence) -> None: self._main_ids.dtype.kind in "uiSU" ), f"Main IDs can only be 
integers (signed/unsigned) or strings, not {self._main_ids.dtype}" + self._segments: "list[BaseSegment]" = [] + # dict at object level self._annotations = {} @@ -142,11 +142,18 @@ def name(self, value): # we remove the annotation if it exists _ = self._annotations.pop("name", None) + @property + def segments(self) -> "list[BaseSegment]": + return self._segments + + def add_segment(self, segment: "BaseSegment") -> None: + self._segments.append(segment) + segment.set_parent_extractor(self) + def get_num_segments(self) -> int: - # This is implemented in BaseRecording or BaseSorting - raise NotImplementedError + return len(self._segments) - def get_parent(self) -> BaseExtractor | None: + def get_parent(self) -> "BaseExtractor | None": """Returns parent object if it exists, otherwise None""" return getattr(self, "_parent", None) @@ -381,7 +388,7 @@ def delete_property(self, key) -> None: def copy_metadata( self, - other: BaseExtractor, + other: "BaseExtractor", only_main: bool = False, ids: Iterable | slice | None = None, skip_properties: Iterable[str] | None = None, @@ -570,7 +577,7 @@ def to_dict( return dump_dict @staticmethod - def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> BaseExtractor: + def from_dict(dictionary: dict, base_folder: Path | str | None = None) -> "BaseExtractor": """ Instantiate extractor from dictionary @@ -624,7 +631,7 @@ def save_metadata_to_folder(self, folder_metadata): values = self.get_property(key) np.save(prop_folder / (key + ".npy"), values) - def clone(self) -> BaseExtractor: + def clone(self) -> "BaseExtractor": """ Clones an existing extractor into a new instance. 
""" @@ -816,7 +823,7 @@ def dump_to_pickle( file_path.write_bytes(pickle.dumps(dump_dict)) @staticmethod - def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> BaseExtractor: + def load(file_or_folder_path: str | Path, base_folder: Path | str | bool | None = None) -> "BaseExtractor": """ Load extractor from file path (.json or .pkl) @@ -839,7 +846,7 @@ def __reduce__(self): return (instance_constructor, intialization_args) @staticmethod - def load_from_folder(folder) -> BaseExtractor: + def load_from_folder(folder) -> "BaseExtractor": return BaseExtractor.load(folder) def _save(self, folder, **save_kwargs): @@ -855,7 +862,7 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass - def save(self, **kwargs) -> BaseExtractor: + def save(self, **kwargs) -> "BaseExtractor": """ Save a SpikeInterface object. @@ -891,7 +898,7 @@ def save(self, **kwargs) -> BaseExtractor: save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) - def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor: + def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor": save_kwargs.pop("format", None) cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) @@ -1092,7 +1099,7 @@ def save_to_zarr( return cached -def _load_extractor_from_dict(dic) -> BaseExtractor: +def _load_extractor_from_dict(dic) -> "BaseExtractor": """ Convert a dictionary into an instance of BaseExtractor or its subclass. 
diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 75bd47597b..322e1c7547 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -1,5 +1,5 @@ -from __future__ import annotations import warnings +from typing import Literal from pathlib import Path import numpy as np @@ -43,9 +43,6 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype ) - - self._recording_segments: list[BaseRecordingSegment] = [] - # initialize main annotation and properties self.annotate(is_filtered=False) @@ -171,18 +168,7 @@ def __sub__(self, other): return SubtractRecordings(self, other) - def get_num_segments(self) -> int: - """ - Returns the number of segments. - - Returns - ------- - int - Number of segments in the recording - """ - return len(self._recording_segments) - - def add_recording_segment(self, recording_segment): + def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> None: """Adds a recording segment. 
Parameters @@ -190,9 +176,7 @@ def add_recording_segment(self, recording_segment): recording_segment : BaseRecordingSegment The recording segment to add """ - # todo: check channel count and sampling frequency - self._recording_segments.append(recording_segment) - recording_segment.set_parent_extractor(self) + super().add_segment(recording_segment) def get_num_samples(self, segment_index: int | None = None) -> int: """ @@ -211,7 +195,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int: The number of samples """ segment_index = self._check_segment_index(segment_index) - return int(self._recording_segments[segment_index].get_num_samples()) + return int(self.segments[segment_index].get_num_samples()) get_num_frames = get_num_samples @@ -305,7 +289,7 @@ def get_traces( start_frame: int | None = None, end_frame: int | None = None, channel_ids: list | np.ndarray | tuple | None = None, - order: "C" | "F" | None = None, + order: Literal["C", "F"] | None = None, return_scaled: bool | None = None, return_in_uV: bool = False, ) -> np.ndarray: @@ -343,7 +327,7 @@ def get_traces( """ segment_index = self._check_segment_index(segment_index) channel_indices = self.ids_to_indices(channel_ids, prefer_slice=True) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] start_frame = int(start_frame) if start_frame is not None else 0 num_samples = rs.get_num_samples() end_frame = int(min(end_frame, num_samples)) if end_frame is not None else num_samples @@ -401,7 +385,7 @@ def get_time_info(self, segment_index=None) -> dict: """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] time_kwargs = rs.get_times_kwargs() return time_kwargs @@ -425,7 +409,7 @@ def get_times(self, segment_index=None) -> np.ndarray: The 1d times array """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = 
self.segments[segment_index] times = rs.get_times() return times @@ -443,7 +427,7 @@ def get_start_time(self, segment_index=None) -> float: The start time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_start_time() def get_end_time(self, segment_index=None) -> float: @@ -460,7 +444,7 @@ def get_end_time(self, segment_index=None) -> float: The stop time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.get_end_time() def has_time_vector(self, segment_index: int | None = None): @@ -477,7 +461,7 @@ def has_time_vector(self, segment_index: int | None = None): True if the recording has time vectors, False otherwise """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] d = rs.get_times_kwargs() return d["time_vector"] is not None @@ -494,7 +478,7 @@ def set_times(self, times, segment_index=None, with_warning=True): If True, a warning is printed """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] assert times.ndim == 1, "Time must have ndim=1" assert rs.get_num_samples() == times.shape[0], "times have wrong shape" @@ -517,7 +501,7 @@ def reset_times(self): segment's sampling frequency is set to the recording's sampling frequency. 
""" for segment_index in range(self.get_num_segments()): - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index): rs.time_vector = None rs.t_start = None @@ -545,7 +529,7 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N segments_to_shift = (segment_index,) for segment_index in segments_to_shift: - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] if self.has_time_vector(segment_index=segment_index): rs.time_vector += shift @@ -558,19 +542,19 @@ def sample_index_to_time(self, sample_ind, segment_index=None): Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.sample_index_to_time(sample_ind) def time_to_sample_index(self, time_s, segment_index=None): segment_index = self._check_segment_index(segment_index) - rs = self._recording_segments[segment_index] + rs = self.segments[segment_index] return rs.time_to_sample_index(time_s) def _get_t_starts(self): # handle t_starts t_starts = [] has_time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() t_starts.append(d["t_start"]) @@ -580,7 +564,7 @@ def _get_t_starts(self): def _get_time_vectors(self): time_vectors = [] - for rs in self._recording_segments: + for rs in self.segments: d = rs.get_times_kwargs() time_vectors.append(d["time_vector"]) if all(time_vector is None for time_vector in time_vectors): @@ -668,7 +652,7 @@ def _extra_metadata_from_folder(self, folder): self.set_probegroup(probegroup, in_place=True) # load time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): time_file = folder / f"times_cached_seg{segment_index}.npy" if time_file.is_file(): time_vector = np.load(time_file) @@ -681,7 +665,7 @@ def 
_extra_metadata_to_folder(self, folder): write_probeinterface(folder / "probe.json", probegroup) # save time vector if any - for segment_index, rs in enumerate(self._recording_segments): + for segment_index, rs in enumerate(self.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] if time_vector is not None: @@ -735,7 +719,7 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording - def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording: + def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording": """ Returns a new recording with sliced frames. Note that this operation is not in place. @@ -757,7 +741,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording - def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording: + def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording": """ Returns a new recording object, restricted to the time interval [start_time, end_time]. @@ -815,7 +799,7 @@ def _select_segments(self, segment_indices): def get_channel_locations( self, channel_ids: list | np.ndarray | tuple | None = None, - axes: "xy" | "yz" | "xz" | "xyz" = "xy", + axes: Literal["xy", "yz", "xz", "xyz"] = "xy", ) -> np.ndarray: """ Get the physical locations of specified channels. 
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index d49065e28d..5afa8ac495 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,4 +1,3 @@ -from __future__ import annotations import warnings from copy import deepcopy @@ -17,7 +16,6 @@ class BaseSorting(BaseExtractor): def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) self._sampling_frequency = float(sampling_frequency) - self._sorting_segments: list[BaseSortingSegment] = [] # this weak link is to handle times from a recording object self._recording = None self._sorting_info = None @@ -76,16 +74,12 @@ def get_unit_ids(self) -> list: def get_num_units(self) -> int: return len(self.get_unit_ids()) - def add_sorting_segment(self, sorting_segment): - self._sorting_segments.append(sorting_segment) - sorting_segment.set_parent_extractor(self) + def add_sorting_segment(self, sorting_segment: "BaseSortingSegment") -> None: + super().add_segment(sorting_segment) def get_sampling_frequency(self) -> float: return self._sampling_frequency - def get_num_segments(self) -> int: - return len(self._sorting_segments) - def get_num_samples(self, segment_index=None) -> int: """Returns the number of samples of the associated recording for a segment. 
@@ -200,7 +194,7 @@ def get_unit_spike_train( end = np.searchsorted(spike_frames, end_frame) spike_frames = spike_frames[:end] else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train( unit_id=unit_id, start_frame=start_frame, end_frame=end_frame ).astype("int64") @@ -244,7 +238,7 @@ def get_unit_spike_train_in_seconds( Spike times in seconds """ segment_index = self._check_segment_index(segment_index) - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording # Note that this take into account the segment start time of the recording @@ -497,7 +491,7 @@ def count_total_num_spikes(self) -> int: """ return self.to_spike_vector().size - def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: + def select_units(self, unit_ids, renamed_unit_ids=None) -> "BaseSorting": """ Returns a new sorting object which contains only a selected subset of units. @@ -519,7 +513,7 @@ def select_units(self, unit_ids, renamed_unit_ids=None) -> BaseSorting: sub_sorting = UnitsSelectionSorting(self, unit_ids, renamed_unit_ids=renamed_unit_ids) return sub_sorting - def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: + def rename_units(self, new_unit_ids: np.ndarray | list) -> "BaseSorting": """ Returns a new sorting object with renamed units. @@ -540,7 +534,7 @@ def rename_units(self, new_unit_ids: np.ndarray | list) -> BaseSorting: sub_sorting = UnitsSelectionSorting(self, renamed_unit_ids=new_unit_ids) return sub_sorting - def remove_units(self, remove_unit_ids) -> BaseSorting: + def remove_units(self, remove_unit_ids) -> "BaseSorting": """ Returns a new sorting object with contains only a selected subset of units. 
@@ -613,7 +607,7 @@ def frame_slice(self, start_frame, end_frame, check_spike_frames=True): ) return sub_sorting - def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSorting: + def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseSorting": """ Returns a new sorting object, restricted to the time interval [start_time, end_time]. @@ -705,7 +699,7 @@ def time_to_sample_index(self, time, segment_index=0): if self.has_recording(): sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) @@ -721,7 +715,7 @@ def sample_index_to_time( if self.has_recording(): return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) else: - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start @@ -754,7 +748,7 @@ def _compute_and_cache_spike_vector(self) -> None: sample_indices = [] unit_indices = [] for u, unit_id in enumerate(self.unit_ids): - segment = self._sorting_segments[segment_index] + segment = self.segments[segment_index] spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype( "int64" ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 62de7e8fde..b3eaa099ed 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -159,11 +159,10 @@ def __del__(self): Closes any open file handles in the recording segments. 
""" # Close all recording segments - if hasattr(self, "_recording_segments"): - for segment in self._recording_segments: - # This will trigger the __del__ method of the BinaryRecordingSegment - # which will close the file handle - del segment + for segment in self.segments: + # This will trigger the __del__ method of the BinaryRecordingSegment + # which will close the file handle + del segment BinaryRecordingExtractor.write_recording.__doc__ = BinaryRecordingExtractor.write_recording.__doc__.format( diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 0da4797440..697aab875e 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -127,7 +127,7 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record ch_id += 1 for i_seg in range(num_segments): - parent_segments = [rec._recording_segments[i_seg] for rec in recording_list] + parent_segments = [rec.segments[i_seg] for rec in recording_list] sub_segment = ChannelsAggregationRecordingSegment(channel_map, parent_segments) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 67d25b2925..de693d5c26 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -53,7 +53,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent_channel_indices = parent_recording.ids_to_indices(self._channel_ids) # link recording segment - for parent_segment in parent_recording._recording_segments: + for parent_segment in parent_recording.segments: sub_segment = ChannelSliceRecordingSegment(parent_segment, self._parent_channel_indices) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 
5cc4daa7ed..513c8b3dfb 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -46,7 +46,7 @@ def __init__(self, parent_recording, start_frame=None, end_frame=None): ) # link recording segment - parent_segment = parent_recording._recording_segments[0] + parent_segment = parent_recording.segments[0] sub_segment = FrameSliceRecordingSegment(parent_segment, start_frame=int(start_frame), end_frame=int(end_frame)) self.add_recording_segment(sub_segment) diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index 0d1f307d2e..a337e83707 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -75,7 +75,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike BaseSorting.__init__(self, sampling_frequency=parent_sorting.get_sampling_frequency(), unit_ids=unit_ids) # link sorting segment - parent_segment = parent_sorting._sorting_segments[0] + parent_segment = parent_sorting.segments[0] sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9b40a23dbd..35116a9e4c 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1993,9 +1993,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectTemplatesRecordingSegment( self.sampling_frequency, self.dtype, diff --git a/src/spikeinterface/core/node_pipeline.py 
b/src/spikeinterface/core/node_pipeline.py index 2c38248c1a..43cdd30c87 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -657,7 +657,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c nodes = worker_ctx["nodes"] skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] - recording_segment = recording._recording_segments[segment_index] + recording_segment = recording.segments[segment_index] retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever)) # get peak slices once for all retrievers peak_slice_by_retriever = {} diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 1200612864..31a3a8831d 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -198,7 +198,7 @@ def __init__( } def __del__(self): - self._recording_segments = [] + self._segments = [] for shm in self.shms: shm.close() if self.main_shm_owner: diff --git a/src/spikeinterface/core/operatorrecordings.py b/src/spikeinterface/core/operatorrecordings.py index 6ffb7d9fa3..63332bffa1 100644 --- a/src/spikeinterface/core/operatorrecordings.py +++ b/src/spikeinterface/core/operatorrecordings.py @@ -25,7 +25,7 @@ def __init__(self, recording1, recording2, operator: str): BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) - for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments): + for segment1, segment2 in zip(recording1.segments, recording2.segments): add_segment = OperatorRecordingSegment(segment1, segment2, operator) self.add_recording_segment(add_segment) @@ -35,8 +35,8 @@ def are_times_kwargs_compatible(self, recording1, recording2) -> bool: import numpy as np for segment_index in range(recording1.get_num_segments()): - time_kwargs1 = recording1._recording_segments[segment_index].get_times_kwargs() - time_kwargs2 = 
recording2._recording_segments[segment_index].get_times_kwargs() + time_kwargs1 = recording1.segments[segment_index].get_times_kwargs() + time_kwargs2 = recording2.segments[segment_index].get_times_kwargs() for key in time_kwargs1.keys(): val1 = time_kwargs1[key] val2 = time_kwargs2[key] diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 6b563ff1d7..3d99fd23c4 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -63,7 +63,7 @@ def __init__(self, recording_list, sampling_frequency_max_diff=0): rec0.copy_metadata(self) for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: rec_seg = ProxyAppendRecordingSegment(parent_segment) self.add_recording_segment(rec_seg) @@ -119,7 +119,7 @@ def __init__(self, recording_list, ignore_times=True, sampling_frequency_max_dif parent_segments = [] for rec in recording_list: - for parent_segment in rec._recording_segments: + for parent_segment in rec.segments: time_kwargs = parent_segment.get_times_kwargs() if not ignore_times: assert time_kwargs["time_vector"] is None, ( @@ -240,7 +240,7 @@ def __init__(self, recording: BaseRecording, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - rec_seg = recording._recording_segments[segment_index] + rec_seg = recording.segments[segment_index] self.add_recording_segment(rec_seg) self._parent = recording @@ -302,7 +302,7 @@ def __init__(self, sorting_list, sampling_frequency_max_diff=0): sorting0.copy_metadata(self) for sorting in sorting_list: - for parent_segment in sorting._sorting_segments: + for parent_segment in sorting.segments: sorting_seg = ProxyAppendSortingSegment(parent_segment) self.add_sorting_segment(sorting_seg) @@ -384,7 +384,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam parent_segments = [] 
parent_num_samples = [] for sorting_i, sorting in enumerate(sorting_list): - for segment_i, parent_segment in enumerate(sorting._sorting_segments): + for segment_i, parent_segment in enumerate(sorting.segments): # Check t_start is not assigned segment_t_start = parent_segment._t_start if not ignore_times: @@ -438,7 +438,7 @@ def __init__(self, sorting_list, total_samples_list=None, ignore_times=True, sam def get_num_samples(self, segment_index=None): """Overrides the BaseSorting method, which requires a recording.""" segment_index = self._check_segment_index(segment_index) - n_samples = self._sorting_segments[segment_index].get_num_samples() + n_samples = self.segments[segment_index].get_num_samples() if self.has_recording(): # Sanity check assert n_samples == self._recording.get_num_samples(segment_index) return n_samples @@ -554,7 +554,7 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None num_samples = [0] for recording in recording_list: - for recording_segment in recording._recording_segments: + for recording_segment in recording.segments: num_samples.append(recording_segment.get_num_samples()) cumsum_num_samples = np.cumsum(num_samples) @@ -562,7 +562,7 @@ def __init__(self, parent_sorting: BaseSorting, recording_or_recording_list=None sliced_parent_sorting = parent_sorting.frame_slice( start_frame=cumsum_num_samples[idx], end_frame=cumsum_num_samples[idx + 1] ) - sliced_segment = sliced_parent_sorting._sorting_segments[0] + sliced_segment = sliced_parent_sorting.segments[0] self.add_sorting_segment(sliced_segment) self._parent = parent_sorting @@ -597,7 +597,7 @@ def __init__(self, sorting: BaseSorting, segment_indices: int | list[int]): ), f"'segment_index' must be between 0 and {num_segments - 1}" for segment_index in segment_indices: - sort_seg = sorting._sorting_segments[segment_index] + sort_seg = sorting.segments[segment_index] self.add_sorting_segment(sort_seg) self._kwargs = {"sorting": sorting, "segment_indices": 
[int(s) for s in segment_indices]} diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 91eb7df864..05963520cd 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -619,13 +619,15 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: "radius" | "best_channels" | "closest_channels" | "snr" | "amplitude" | "energy" | "by_property" = "radius", - peak_sign: "neg" | "pos" | "both" = "neg", + method: Literal[ + "radius", "best_channels", "closest_channels", "snr", "amplitude", "energy", "by_property" + ] = "radius", + peak_sign: Literal["neg", "pos", "both"] = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + amplitude_mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", ) -> ChannelSparsity: """ Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. 
@@ -718,12 +720,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "closest_channels" | "amplitude" | "snr" | "by_property" = "radius", - peak_sign: "neg" | "pos" | "both" = "neg", + method: Literal["radius", "best_channels", "closest_channels", "amplitude", "snr", "by_property"] = "radius", + peak_sign: Literal["neg", "pos", "both"] = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, - amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + amplitude_mode: Literal["extremum", "peak_to_peak"] = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, **job_kwargs, diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3e5a517b0a..67ba1179b0 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -1,4 +1,3 @@ -from __future__ import annotations import numpy as np import json from dataclasses import dataclass, field, astuple, replace @@ -140,7 +139,7 @@ def __repr__(self): return repr_str - def select_units(self, unit_ids) -> Templates: + def select_units(self, unit_ids) -> "Templates": """ Return a new Templates object with only the selected units. @@ -162,7 +161,7 @@ def select_units(self, unit_ids) -> Templates: check_for_consistent_sparsity=False, ) - def select_channels(self, channel_ids) -> Templates: + def select_channels(self, channel_ids) -> "Templates": """ Return a new Templates object with only the selected channels. 
This operation can be useful to remove bad channels for hybrid recording diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 0293c23876..b6a6b90bb2 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -62,8 +62,8 @@ def _get_nbefore(one_object): def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", @@ -135,9 +135,9 @@ def get_template_amplitudes( def get_template_extremum_channel( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - outputs: "id" | "index" = "id", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + outputs: Literal["id", "index"] = "id", operator: str = "average", ): """ @@ -202,7 +202,9 @@ def get_template_extremum_channel( def get_template_extremum_channel_peak_shift( - templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", operator: str = "average" + templates_or_sorting_analyzer, + peak_sign: Literal["neg", "pos", "both"] = "neg", + operator: Literal["average", "median"] = "average", ): """ In some situations spike sorters could return a spike index with a small shift related to the waveform peak. 
@@ -228,7 +230,9 @@ def get_template_extremum_channel_peak_shift( channel_ids = templates_or_sorting_analyzer.channel_ids nbefore = _get_nbefore(templates_or_sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign) + extremum_channels_ids = get_template_extremum_channel( + templates_or_sorting_analyzer, peak_sign=peak_sign, operator=operator + ) shifts = {} @@ -265,10 +269,10 @@ def get_template_extremum_channel_peak_shift( def get_template_extremum_amplitude( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "at_index", abs_value: bool = True, - operator: str = "average", + operator: Literal["average", "median"] = "average", ): """ Computes amplitudes on the best channel. diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 02798099ec..405a2ecccf 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -251,7 +251,7 @@ def test_get_noise_levels_output(): def test_get_chunk_with_margin(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0]) - rec_seg = rec._recording_segments[0] + rec_seg = rec.segments[0] length = rec_seg.get_num_samples() #  rec_segment, start_frame, end_frame, channel_indices, sample_margin diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index f22939c33c..e03096ce14 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -64,7 +64,7 @@ def _get_time_vector_recording(self, raw_recording): times_recording.set_times(times=time_vector, segment_index=segment_index) 
assert np.array_equal( - times_recording._recording_segments[segment_index].time_vector, + times_recording.segments[segment_index].time_vector, time_vector, ), "time_vector was not properly set during test setup" @@ -84,7 +84,7 @@ def _get_t_start_recording(self, raw_recording): t_start = (segment_index + 1) * 100 all_t_starts.append(t_start + t_start_recording.get_times(segment_index)) - t_start_recording._recording_segments[segment_index].t_start = t_start + t_start_recording.segments[segment_index].t_start = t_start return (raw_recording, t_start_recording, all_t_starts) @@ -442,6 +442,6 @@ def test_shift_times_with_None_as_t_start(): """Ensures we can shift times even when t_stat is None which is interpeted as zero""" recording = generate_recording(num_channels=4, durations=[10]) - assert recording._recording_segments[0].t_start is None + assert recording.segments[0].t_start is None recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error assert recording.get_start_time() == 1.0 diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py index 84d2c06e59..32040f8f61 100644 --- a/src/spikeinterface/core/unitsaggregationsorting.py +++ b/src/spikeinterface/core/unitsaggregationsorting.py @@ -134,7 +134,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_d # add segments for i_seg in range(num_segments): - parent_segments = [sort._sorting_segments[i_seg] for sort in sorting_list] + parent_segments = [sort.segments[i_seg] for sort in sorting_list] sub_segment = UnitsAggregationSortingSegment(unit_map, parent_segments) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index b0f3b19472..59356db976 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -33,7 +33,7 @@ def __init__(self, 
parent_sorting, unit_ids=None, renamed_unit_ids=None): BaseSorting.__init__(self, sampling_frequency, self._renamed_unit_ids) - for parent_segment in self._parent_sorting._sorting_segments: + for parent_segment in self._parent_sorting.segments: sub_segment = UnitsSelectionSortingSegment(parent_segment, ids_conversion) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index e58ef4ee68..8f266e0123 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -500,7 +500,7 @@ def add_recording_to_zarr_group( # save time vector if any t_starts = np.zeros(recording.get_num_segments(), dtype="float64") * np.nan - for segment_index, rs in enumerate(recording._recording_segments): + for segment_index, rs in enumerate(recording.segments): d = rs.get_times_kwargs() time_vector = d["time_vector"] diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 888f5964ca..9e0b17632e 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -8,8 +8,6 @@ non_soma: Non-somatic units (axonal) """ -from __future__ import annotations - import operator from pathlib import Path import json @@ -87,7 +85,7 @@ def bombcell_label_units( thresholds: dict | str | Path | None = None, label_non_somatic: bool = True, split_non_somatic_good_mua: bool = False, - external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None, + external_metrics: "pd.DataFrame | list[pd.DataFrame] | None" = None, ) -> "pd.DataFrame": """ Label units based on quality metrics and template metrics using Bombcell logic: diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 31ce825c7f..ff2c32d07f 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -1,4 
+1,4 @@ -from __future__ import annotations +from typing import Literal import numpy as np @@ -61,7 +61,7 @@ def _find_duplicated_spikes_numpy( spike_train: np.ndarray, censored_period: int, seed: int | None = None, - method: "keep_first" | "random" | "keep_last" = "keep_first", + method: Literal["keep_first", "random", "keep_last"] = "keep_first", ) -> np.ndarray: (indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period) @@ -138,7 +138,7 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): def find_duplicated_spikes( spike_train, censored_period: int, - method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random", + method: Literal["keep_first", "keep_last", "keep_first_iterative", "keep_last_iterative", "random"] = "random", seed: int | None = None, ) -> np.ndarray: """ diff --git a/src/spikeinterface/curation/mergeunitssorting.py b/src/spikeinterface/curation/mergeunitssorting.py index ff83edaca2..9d9e10e75f 100644 --- a/src/spikeinterface/curation/mergeunitssorting.py +++ b/src/spikeinterface/curation/mergeunitssorting.py @@ -68,7 +68,7 @@ def __init__(self, sorting, units_to_merge, new_unit_ids=None, properties_policy rm_dup_delta = None else: rm_dup_delta = int(delta_time_ms / 1000 * sampling_frequency) - for parent_segment in self._parent_sorting._sorting_segments: + for parent_segment in self._parent_sorting.segments: sub_segment = MergeUnitsSortingSegment(parent_segment, units_to_merge, new_unit_ids, rm_dup_delta) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/curation/remove_duplicated_spikes.py b/src/spikeinterface/curation/remove_duplicated_spikes.py index 2ff3456822..33d342ff14 100644 --- a/src/spikeinterface/curation/remove_duplicated_spikes.py +++ b/src/spikeinterface/curation/remove_duplicated_spikes.py @@ -37,7 +37,7 @@ def __init__(self, sorting: BaseSorting, censored_period_ms: float = 0.3, method censored_period = 
int(round(censored_period_ms * 1e-3 * sorting.get_sampling_frequency())) seed = np.random.randint(low=0, high=np.iinfo(np.int32).max) - for segment in sorting._sorting_segments: + for segment in sorting.segments: self.add_sorting_segment( RemoveDuplicatedSpikesSortingSegment(segment, censored_period, sorting.unit_ids, method, seed) ) diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py index 020037b2b7..04169808f5 100644 --- a/src/spikeinterface/curation/remove_excess_spikes.py +++ b/src/spikeinterface/curation/remove_excess_spikes.py @@ -32,7 +32,7 @@ def __init__(self, sorting: BaseSorting, recording: BaseRecording) -> None: self._parent_sorting = sorting self._num_samples = np.empty(sorting.get_num_segments(), dtype=np.int64) for segment_index in range(sorting.get_num_segments()): - sorting_segment = sorting._sorting_segments[segment_index] + sorting_segment = sorting.segments[segment_index] self._num_samples[segment_index] = recording.get_num_samples(segment_index=segment_index) self.add_sorting_segment( RemoveExcessSpikesSortingSegment(sorting_segment, self._num_samples[segment_index]) diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index f5a548113d..c09f57df5a 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -78,7 +78,7 @@ def __init__(self, sorting, split_unit_id, indices_list, new_unit_ids=None, prop np.isin(unchanged_units, self.unit_ids) ), "new_unit_ids should have a compatible format with the parent ids" - for si, parent_segment in enumerate(self._parent_sorting._sorting_segments): + for si, parent_segment in enumerate(self._parent_sorting.segments): sub_segment = SplitSortingUnitSegment(parent_segment, split_unit_id, indices_zero_based[si], new_unit_ids) self.add_sorting_segment(sub_segment) diff --git a/src/spikeinterface/extractors/mdaextractors.py 
b/src/spikeinterface/extractors/mdaextractors.py index 676b2bceac..7a7bdd45a6 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -210,7 +210,7 @@ def __init__(self, file_path, sampling_frequency): # Every spike assigned to a unit (label) has the same max channel # ref: https://github.com/SpikeInterface/spikeinterface/issues/3695#issuecomment-2663329006 max_channels = [] - segment = self._sorting_segments[0] + segment = self.segments[0] for unit_id in self.unit_ids: label_mask = segment._labels == unit_id # since all max channels are the same, we can just grab the first occurrence for the unit diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index dec65404e9..0e5dd2694d 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -355,7 +355,7 @@ def read_kilosort_as_analyzer(folder_path, unwhiten=True, gain_to_uV=None, offse # kilosort occasionally contains a few spikes just beyond the recording end point, which can lead # to errors later. To avoid this, we pad the recording with an extra second of blank time. 
- duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 + duration = sorting.segments[0]._all_spikes[-1] / sampling_frequency + 1 if (phy_path / "probe.prb").is_file(): probegroup = read_prb(phy_path / "probe.prb") diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 97ac263cc5..1800138dae 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -463,9 +463,7 @@ def __init__( amplitude_vec = amplitude_vector[start:end] if amplitude_vector is not None else None # upsample_vec = upsample_vector[start:end] if upsample_vector is not None else None - parent_recording_segment = ( - None if parent_recording is None else parent_recording._recording_segments[segment_index] - ) + parent_recording_segment = None if parent_recording is None else parent_recording.segments[segment_index] recording_segment = InjectDriftingTemplatesRecordingSegment( self.dtype, self.spike_vector[start:end], diff --git a/src/spikeinterface/postprocessing/alignsorting.py b/src/spikeinterface/postprocessing/alignsorting.py index c2b23ba83e..cf4189a3c7 100644 --- a/src/spikeinterface/postprocessing/alignsorting.py +++ b/src/spikeinterface/postprocessing/alignsorting.py @@ -25,7 +25,7 @@ class AlignSortingExtractor(BaseSorting): def __init__(self, sorting, unit_peak_shifts): super().__init__(sorting.get_sampling_frequency(), sorting.unit_ids) - for segment in sorting._sorting_segments: + for segment in sorting.segments: self.add_sorting_segment(AlignSortingSegment(segment, unit_peak_shifts)) sorting.copy_metadata(self, only_main=False) diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 448be8d055..f8b5e16dc0 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -1,4 +1,4 @@ -from __future__ import 
annotations +from typing import Literal import warnings import importlib.util @@ -661,8 +661,9 @@ def get_convolution_weights( def compute_location_max_channel( templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_sign: Literal["neg", "pos", "both"] = "neg", + mode: Literal["extremum", "at_index", "peak_to_peak"] = "extremum", + operator: Literal["average", "median"] = "average", ) -> np.ndarray: """ Localize a unit using max channel. @@ -690,7 +691,7 @@ def compute_location_max_channel( 2d """ extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index", operator=operator ) contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: diff --git a/src/spikeinterface/preprocessing/astype.py b/src/spikeinterface/preprocessing/astype.py index 41f88ce858..26d45dd711 100644 --- a/src/spikeinterface/preprocessing/astype.py +++ b/src/spikeinterface/preprocessing/astype.py @@ -42,7 +42,7 @@ def __init__( if round is None: round = np.issubdtype(dtype, np.integer) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = AstypeRecordingSegment( parent_segment, dtype, diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 113d1e22f1..8d1c4475cd 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -75,7 +75,7 @@ def __init__( self.parent_recording = parent_recording self.num_channels = n_pos_unique - for segment in parent_recording._recording_segments: + for segment in parent_recording.segments: recording_segment = 
AverageAcrossDirectionRecordingSegment( segment, self.num_channels, diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 30ed2a7b5e..dd7676fd23 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -33,7 +33,7 @@ def __init__(self, recording, a_min=None, a_max=None): value_max = a_max BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) self.add_recording_segment(rec_segment) @@ -130,7 +130,7 @@ def __init__( value_max = fill_value BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/common_reference.py b/src/spikeinterface/preprocessing/common_reference.py index b1469a0250..5a3a9b0043 100644 --- a/src/spikeinterface/preprocessing/common_reference.py +++ b/src/spikeinterface/preprocessing/common_reference.py @@ -154,7 +154,7 @@ def __init__( else: ref_channel_indices = None - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = CommonReferenceRecordingSegment( parent_segment, reference, diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index 1c7566ab20..da66cd9c3f 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -65,7 +65,7 @@ def __init__( BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: self.add_recording_segment( 
DecimateRecordingSegment( parent_segment, diff --git a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py index 90863a8df7..07be76d47d 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/deepinterpolation.py @@ -89,7 +89,7 @@ def __init__( self.model = model # add segment - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = DeepInterpolatedRecordingSegment( segment, self.model, diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index adcd1d80f8..92b07b8f35 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Literal import numpy as np diff --git a/src/spikeinterface/preprocessing/directional_derivative.py b/src/spikeinterface/preprocessing/directional_derivative.py index 124a6e2744..f302708055 100644 --- a/src/spikeinterface/preprocessing/directional_derivative.py +++ b/src/spikeinterface/preprocessing/directional_derivative.py @@ -50,7 +50,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = DirectionalDerivativeRecordingSegment( parent_segment, parent_channel_locations, diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 2eb8d7cdf8..1fc289f937 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -125,7 +125,7 @@ def __init__( f"chunking. Consider increasing the chunk_size or chunk_duration to minimize margin overhead." 
) self.margin_samples = margin - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: self.add_recording_segment( FilterRecordingSegment( parent_segment, diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index b51e9603f5..1cf6873a7a 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ b/src/spikeinterface/preprocessing/filter_gaussian.py @@ -47,7 +47,7 @@ def __init__( if freq_min is None and freq_max is None: raise ValueError("At least one of `freq_min`,`freq_max` should be specified.") - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: # Sampling frequency is taken from recording since segments may not have it set (in case of time_vector) self.add_recording_segment( GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max, margin_sd, self.sampling_frequency) diff --git a/src/spikeinterface/preprocessing/filter_opencl.py b/src/spikeinterface/preprocessing/filter_opencl.py index 6db0c3d642..d0bffda0e0 100644 --- a/src/spikeinterface/preprocessing/filter_opencl.py +++ b/src/spikeinterface/preprocessing/filter_opencl.py @@ -72,7 +72,7 @@ def __init__( dtype = "float32" executor = OpenCLFilterExecutor(coefficients, num_channels, dtype, margin) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: self.add_recording_segment(FilterOpenCLRecordingSegment(parent_segment, executor, margin)) self._kwargs = dict( diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index f64e553980..497bbdd482 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -139,7 +139,7 @@ def __init__( dtype = fix_dtype(recording, dtype) - for parent_segment in recording._recording_segments: + for 
parent_segment in recording.segments: rec_segment = HighPassSpatialFilterSegment( parent_segment, n_channel_pad, diff --git a/src/spikeinterface/preprocessing/interpolate_bad_channels.py b/src/spikeinterface/preprocessing/interpolate_bad_channels.py index 17275f7a23..0e7e5f9950 100644 --- a/src/spikeinterface/preprocessing/interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/interpolate_bad_channels.py @@ -65,7 +65,7 @@ def __init__(self, recording, bad_channel_ids, sigma_um=None, p=1.3, weights=Non locations_bad = locations[self._bad_channel_idxs] weights = preprocessing_tools.get_kriging_channel_weights(locations_good, locations_bad, sigma_um, p) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = InterpolateBadChannelsSegment( parent_segment, self._good_channel_idxs, self._bad_channel_idxs, weights ) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 7319e2994e..641d6af0d9 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -103,7 +103,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) self.add_recording_segment(rec_segment) @@ -166,7 +166,7 @@ def __init__(self, recording, gain=1.0, offset=0.0, dtype="float32"): BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, self._dtype) self.add_recording_segment(rec_segment) @@ -211,7 +211,7 @@ def __init__(self, recording, mode="median", dtype="float32", **random_chunk_kwa BasePreprocessor.__init__(self, recording, dtype=dtype) - 
for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) self.add_recording_segment(rec_segment) @@ -313,7 +313,7 @@ def __init__( self.set_property(key="gain_to_uV", values=np.ones(num_chans, dtype="float32")) self.set_property(key="offset_to_uV", values=np.zeros(num_chans, dtype="float32")) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = ScaleRecordingSegment(parent_segment, gain, offset, dtype=self._dtype) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/phase_shift.py b/src/spikeinterface/preprocessing/phase_shift.py index 4131f912f3..5648d689dd 100644 --- a/src/spikeinterface/preprocessing/phase_shift.py +++ b/src/spikeinterface/preprocessing/phase_shift.py @@ -61,7 +61,7 @@ def __init__(self, recording, margin_ms=40.0, inter_sample_shift=None, dtype=Non tmp_dtype = None BasePreprocessor.__init__(self, recording, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = PhaseShiftRecordingSegment(parent_segment, sample_shifts, margin, dtype, tmp_dtype) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/rectify.py b/src/spikeinterface/preprocessing/rectify.py index 3b622149d1..7bd91a16d9 100644 --- a/src/spikeinterface/preprocessing/rectify.py +++ b/src/spikeinterface/preprocessing/rectify.py @@ -9,7 +9,7 @@ class RectifyRecording(BasePreprocessor): def __init__(self, recording): BasePreprocessor.__init__(self, recording) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = RectifyRecordingSegment(parent_segment) self.add_recording_segment(rec_segment) self._kwargs = dict(recording=recording) diff --git a/src/spikeinterface/preprocessing/remove_artifacts.py 
b/src/spikeinterface/preprocessing/remove_artifacts.py index 0863522fd8..3fc5449ff2 100644 --- a/src/spikeinterface/preprocessing/remove_artifacts.py +++ b/src/spikeinterface/preprocessing/remove_artifacts.py @@ -197,7 +197,7 @@ def __init__( time_pad = None BasePreprocessor.__init__(self, recording) - for seg_index, parent_segment in enumerate(recording._recording_segments): + for seg_index, parent_segment in enumerate(recording.segments): triggers = list_triggers[seg_index] labels = list_labels[seg_index] rec_segment = RemoveArtifactsRecordingSegment( diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 773b68b977..902bd6d176 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -65,7 +65,7 @@ def __init__( margin = int(margin_ms * recording.get_sampling_frequency() / 1000) BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate, dtype=dtype) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: self.add_recording_segment( ResampleRecordingSegment( parent_segment, diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 189b97ec87..393c712919 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -103,7 +103,7 @@ def __init__( BasePreprocessor.__init__(self, recording) seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) - for seg_index, parent_segment in enumerate(recording._recording_segments): + for seg_index, parent_segment in enumerate(recording.segments): i0 = seg_limits[seg_index] i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] diff --git a/src/spikeinterface/preprocessing/tests/test_clip.py b/src/spikeinterface/preprocessing/tests/test_clip.py index cea15722a0..96020692a1 100644 --- 
a/src/spikeinterface/preprocessing/tests/test_clip.py +++ b/src/spikeinterface/preprocessing/tests/test_clip.py @@ -41,7 +41,7 @@ def test_blank_saturation(): traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"]) assert traces1.shape[1] == 1 # use a smaller value to be sure - a_min = rec1._recording_segments[0].a_min + a_min = rec1.segments[0].a_min assert np.all(traces1 >= a_min) diff --git a/src/spikeinterface/preprocessing/tests/test_common_reference.py b/src/spikeinterface/preprocessing/tests/test_common_reference.py index 3fbc260b5f..e19cad59ba 100644 --- a/src/spikeinterface/preprocessing/tests/test_common_reference.py +++ b/src/spikeinterface/preprocessing/tests/test_common_reference.py @@ -96,7 +96,7 @@ def test_common_reference_channel_slicing(recording): start_frame = 0 end_frame = 10 - recording_segment_cmr = recording_cmr._recording_segments[0] + recording_segment_cmr = recording_cmr.segments[0] traces_cmr_all = recording_segment_cmr.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) @@ -106,7 +106,7 @@ def test_common_reference_channel_slicing(recording): assert np.all(traces_cmr_all[:, indices] == traces_cmr_sub) - recording_segment_car = recording_car._recording_segments[0] + recording_segment_car = recording_car.segments[0] traces_car_all = recording_segment_car.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) @@ -116,7 +116,7 @@ def test_common_reference_channel_slicing(recording): assert np.all(traces_car_all[:, indices] == traces_car_sub) - recording_segment_local = recording_local_car._recording_segments[0] + recording_segment_local = recording_local_car.segments[0] traces_local_all = recording_segment_local.get_traces( start_frame=start_frame, end_frame=end_frame, channel_indices=all_indices ) diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index e9493145a6..141345ca46 100644 --- 
a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -66,7 +66,7 @@ def test_decimate_with_times(): # test with t_start rec = generate_recording(durations=[5, 10]) t_starts = [10, 20] - for t_start, rec_segment in zip(t_starts, rec._recording_segments): + for t_start, rec_segment in zip(t_starts, rec.segments): rec_segment.t_start = t_start decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) for segment_index in range(rec.get_num_segments()): diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 75279bcae0..35f398f985 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -262,7 +262,7 @@ def reduce_high_freq_power_in_non_noisy_channels(recording, is_noisy, not_noisy) """ from scipy.signal import welch - for iseg, __ in enumerate(recording._recording_segments): + for iseg, __ in enumerate(recording.segments): data = recording.get_traces(iseg).T num_samples = recording.get_num_samples(iseg) @@ -291,7 +291,7 @@ def add_dead_channels(recording, is_dead): data[:, is_dead] = np.random.normal( mean, std * 0.1, size=(is_dead.size, recording.get_num_samples(segment_index)) ).T - recording._recording_segments[segment_index]._traces = data + recording.segments[segment_index]._traces = data if __name__ == "__main__": diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py index 7e1d173fdb..c53b7b42bd 100644 --- a/src/spikeinterface/preprocessing/tests/test_resample.py +++ b/src/spikeinterface/preprocessing/tests/test_resample.py @@ -216,10 +216,10 @@ def test_resample_preserves_t_start(): t_start = 100.5 traces = np.random.randn(sampling_frequency * 2, 2).astype(np.float32) parent_rec = 
NumpyRecording(traces, sampling_frequency) - parent_rec._recording_segments[0].t_start = t_start + parent_rec.segments[0].t_start = t_start resampled = resample(parent_rec, 500) - assert resampled._recording_segments[0].t_start == t_start + assert resampled.segments[0].t_start == t_start assert not resampled.has_time_vector() assert np.isclose(resampled.get_times()[0], t_start) diff --git a/src/spikeinterface/preprocessing/tests/test_whiten.py b/src/spikeinterface/preprocessing/tests/test_whiten.py index 7c414df738..f4c0e4d166 100644 --- a/src/spikeinterface/preprocessing/tests/test_whiten.py +++ b/src/spikeinterface/preprocessing/tests/test_whiten.py @@ -366,8 +366,8 @@ def test_passed_W_and_M(self): whitened_recording = whiten(recording, W=test_W, M=test_M) for seg_idx in [0, 1]: - assert np.array_equal(whitened_recording._recording_segments[seg_idx].W, test_W) - assert np.array_equal(whitened_recording._recording_segments[seg_idx].M, test_M) + assert np.array_equal(whitened_recording.segments[seg_idx].W, test_W) + assert np.array_equal(whitened_recording.segments[seg_idx].M, test_M) assert whitened_recording._kwargs["W"] == test_W.tolist() assert whitened_recording._kwargs["M"] == test_M.tolist() diff --git a/src/spikeinterface/preprocessing/unsigned_to_signed.py b/src/spikeinterface/preprocessing/unsigned_to_signed.py index 62107155ee..ae1ce12281 100644 --- a/src/spikeinterface/preprocessing/unsigned_to_signed.py +++ b/src/spikeinterface/preprocessing/unsigned_to_signed.py @@ -31,7 +31,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_signed) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = UnsignedToSignedRecordingSegment(parent_segment, dtype_signed, bit_depth) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/whiten.py b/src/spikeinterface/preprocessing/whiten.py index d5f26d9b01..1d723b63a0 100644 --- 
a/src/spikeinterface/preprocessing/whiten.py +++ b/src/spikeinterface/preprocessing/whiten.py @@ -101,7 +101,7 @@ def __init__( BasePreprocessor.__init__(self, recording, dtype=dtype_) - for parent_segment in recording._recording_segments: + for parent_segment in recording.segments: rec_segment = WhitenRecordingSegment(parent_segment, W, M, dtype_, int_scale) self.add_recording_segment(rec_segment) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 35b984449d..45d4809cd8 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -32,7 +32,7 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end self.padding_start = padding_start self.padding_end = padding_end self.fill_value = fill_value - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = TracePaddedRecordingSegment( segment, recording.get_num_channels(), @@ -164,7 +164,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: self.parent_recording = recording self.num_channels = num_channels - for segment in recording._recording_segments: + for segment in recording.segments: recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping) self.add_recording_segment(recording_segment) diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 37c13b9395..947eaf391f 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -202,7 +202,7 @@ def __init__( # interpolation bins edges self.interpolation_time_bins_s = [] self.interpolation_time_bin_edges_s = [] - for segment_index, parent_segment in enumerate(recording._recording_segments): + for segment_index, 
parent_segment in enumerate(recording.segments): # in this case, interpolation_time_bin_size_s is set. s_end = parent_segment.get_num_samples() t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a50b9609b9..7c4c4b166e 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -422,7 +422,7 @@ def __init__( interpolation_time_bin_centers_s, interpolation_time_bin_edges_s ) - for segment_index, parent_segment in enumerate(recording._recording_segments): + for segment_index, parent_segment in enumerate(recording.segments): # finish the per-segment part of the time bin logic if interpolation_time_bin_centers_s is None: # in this case, interpolation_time_bin_size_s is set. diff --git a/src/spikeinterface/widgets/bombcell_curation.py b/src/spikeinterface/widgets/bombcell_curation.py index 1a1212ba5c..8c7c275ad1 100644 --- a/src/spikeinterface/widgets/bombcell_curation.py +++ b/src/spikeinterface/widgets/bombcell_curation.py @@ -1,7 +1,5 @@ """Widgets for visualizing unit labeling results.""" -from __future__ import annotations - import warnings import numpy as np diff --git a/src/spikeinterface/widgets/unit_labels.py b/src/spikeinterface/widgets/unit_labels.py index c5b55041c1..348f0e3b8d 100644 --- a/src/spikeinterface/widgets/unit_labels.py +++ b/src/spikeinterface/widgets/unit_labels.py @@ -1,7 +1,5 @@ """Widgets for visualizing unit labeling results.""" -from __future__ import annotations - import numpy as np from spikeinterface.curation.curation_tools import is_threshold_disabled From 54a4b18ffa172faadd1847a1212d2f5837c835e5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 15:00:06 +0100 Subject: [PATCH 4/7] Add chunkable files 1 --- src/spikeinterface/core/chunkable.py 
| 470 +++++++++++++++++++++++++++ 1 file changed, 470 insertions(+) create mode 100644 src/spikeinterface/core/chunkable.py diff --git a/src/spikeinterface/core/chunkable.py b/src/spikeinterface/core/chunkable.py new file mode 100644 index 0000000000..c3329a10e3 --- /dev/null +++ b/src/spikeinterface/core/chunkable.py @@ -0,0 +1,470 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Optional +import warnings + +import numpy as np + +from spikeinterface.core.base import BaseExtractor, BaseSegment + + +class ChunkableMixin(ABC): + """ + Abstract mixin class for chunkable objects. Note that the mixin can only be used + for classes that inherit from BaseExtractor. + Provides methods to handle chunked data access, that can be used for parallelization. + In addition, since chunkable objects are continuous data, time handling methods are provided. + + The Mixin is abstract since all methods need to be implemented in the child class in order + for it to function properly. + """ + + _preferred_mp_context = None + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if not issubclass(cls, BaseExtractor): + raise TypeError(f"{cls.__name__} must inherit from BaseExtractor to use Chunkable mixin.") + + @abstractmethod + def get_sampling_frequency(self) -> float: + raise NotImplementedError + + @abstractmethod + def get_num_samples(self, segment_index: int | None = None) -> int: + raise NotImplementedError + + @abstractmethod + def get_sample_size_in_bytes(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + raise NotImplementedError + + @abstractmethod + def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: + raise NotImplementedError + + def _extra_copy_metadata(self, other: "ChunkableMixin", **kwargs) -> None: + """ + Copy metadata from another Chunkable object. 
+ + Parameters + ---------- + other : ChunkableMixin + The object from which to copy metadata. + """ + # inherit preferred mp context if any + if self.__class__._preferred_mp_context is not None: + other.__class__._preferred_mp_context = self.__class__._preferred_mp_context + + def get_preferred_mp_context(self): + """ + Get the preferred context for multiprocessing. + If None, the context is set by the multiprocessing package. + """ + return self.__class__._preferred_mp_context + + def get_memory_size(self, segment_index=None) -> int: + """ + Returns the memory size of segment_index in bytes. + + Parameters + ---------- + segment_index : int or None, default: None + The index of the segment for which the memory size should be calculated. + For multi-segment objects, it is required, default: None + With single segment recording returns the memory size of the single segment + + Returns + ------- + int + The memory size of the specified segment in bytes. + """ + segment_index = self._check_segment_index(segment_index) + num_samples = self.get_num_samples(segment_index=segment_index) + sample_size_in_bytes = self.get_sample_size_in_bytes() + + memory_bytes = num_samples * sample_size_in_bytes + + return memory_bytes + + def get_total_memory_size(self) -> int: + """ + Returns the sum in bytes of all the memory sizes of the segments. + + Returns + ------- + int + The total memory size in bytes for all segments. + """ + memory_per_segment = (self.get_memory_size(segment_index) for segment_index in range(self.get_num_segments())) + return sum(memory_per_segment) + + # Add time handling + def get_time_info(self, segment_index=None) -> dict: + """ + Retrieves the timing attributes for a given segment index. As with + other recorders this method only needs a segment index in the case + of multi-segment recordings. + + Returns + ------- + dict + A dictionary containing the following key-value pairs: + + - "sampling_frequency" : The sampling frequency of the RecordingSegment. 
+ - "t_start" : The start time of the RecordingSegment. + - "time_vector" : The time vector of the RecordingSegment. + + Notes + ----- + The keys are always present, but the values may be None. + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + time_kwargs = rs.get_times_kwargs() + + return time_kwargs + + def get_times(self, segment_index=None, start_frame=None, end_frame=None) -> np.ndarray: + """Get time vector for a recording segment. + + If the segment has a time_vector, then it is returned. Otherwise + a time_vector is constructed on the fly with sampling frequency. + If t_start is defined and the time vector is constructed on the fly, + the first time will be t_start. Otherwise it will start from 0. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + start_frame : int or None, default: None + The start frame for the time vector + end_frame : int or None, default: None + The end frame for the time vector + + Returns + ------- + np.array + The 1d times array + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + times = rs.get_times(start_frame=start_frame, end_frame=end_frame) + return times + + def get_start_time(self, segment_index=None) -> float: + """Get the start time of the recording segment. + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The start time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.get_start_time() + + def get_end_time(self, segment_index=None) -> float: + """Get the stop time of the recording segment. 
+
+        Parameters
+        ----------
+        segment_index : int or None, default: None
+            The segment index (required for multi-segment)
+
+        Returns
+        -------
+        float
+            The stop time in seconds
+        """
+        segment_index = self._check_segment_index(segment_index)
+        rs = self.segments[segment_index]
+        return rs.get_end_time()
+
+    def has_time_vector(self, segment_index: Optional[int] = None):
+        """Check if the segment of the recording has a time vector.
+
+        Parameters
+        ----------
+        segment_index : int or None, default: None
+            The segment index (required for multi-segment)
+
+        Returns
+        -------
+        bool
+            True if the recording has time vectors, False otherwise
+        """
+        segment_index = self._check_segment_index(segment_index)
+        rs = self.segments[segment_index]
+        d = rs.get_times_kwargs()
+        return d["time_vector"] is not None
+
+    def set_times(self, times, segment_index=None, with_warning=True):
+        """Set times for a recording segment.
+
+        Parameters
+        ----------
+        times : 1d np.array
+            The time vector
+        segment_index : int or None, default: None
+            The segment index (required for multi-segment)
+        with_warning : bool, default: True
+            If True, a warning is printed
+        """
+        segment_index = self._check_segment_index(segment_index)
+        rs = self.segments[segment_index]
+
+        assert times.ndim == 1, "Time must have ndim=1"
+        assert rs.get_num_samples() == times.shape[0], "times have wrong shape"
+
+        rs.t_start = None
+        rs.time_vector = times.astype("float64", copy=False)
+
+        if with_warning:
+            warnings.warn(
+                "Setting times with Recording.set_times() is not recommended because "
+                "times are not always propagated across preprocessing. "
+                "Use this carefully!"
+            )
+
+    def reset_times(self):
+        """
+        Reset time information in-memory for all segments that have a time vector.
+        If the timestamps come from a file, the files won't be modified, but only the in-memory
+        attributes of the recording objects are deleted. 
Also `t_start` is set to None and the + segment's sampling frequency is set to the recording's sampling frequency. + """ + for segment_index in range(self.get_num_segments()): + rs = self.segments[segment_index] + if self.has_time_vector(segment_index): + rs.time_vector = None + rs.t_start = None + rs.sampling_frequency = self.sampling_frequency + + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + e.g. shifting by 1 will be like the recording started 1 second later. + If negative, the start time will be decreased i.e. as if the recording + started earlier. + + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for segment_index in segments_to_shift: + rs = self.segments[segment_index] + + if self.has_time_vector(segment_index=segment_index): + rs.time_vector += shift + else: + new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift + rs.t_start = new_start_time + + def sample_index_to_time(self, sample_ind, segment_index=None): + """ + Transform sample index into time in seconds + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.sample_index_to_time(sample_ind) + + def time_to_sample_index(self, time_s, segment_index=None): + """ + Transform time in seconds into sample index + """ + segment_index = self._check_segment_index(segment_index) + rs = self.segments[segment_index] + return rs.time_to_sample_index(time_s) + + def get_total_samples(self) -> int: + """ + Returns the sum of the number of samples in each segment. 
+
+        Returns
+        -------
+        int
+            The total number of samples
+        """
+        num_segments = self.get_num_segments()
+        samples_per_segment = (self.get_num_samples(segment_index) for segment_index in range(num_segments))
+
+        return sum(samples_per_segment)
+
+    def get_duration(self, segment_index=None) -> float:
+        """
+        Returns the duration in seconds.
+
+        Parameters
+        ----------
+        segment_index : int or None, default: None
+            The segment index to retrieve the duration for.
+            For multi-segment objects, it is required, default: None
+            With single segment recording returns the duration of the single segment
+
+        Returns
+        -------
+        float
+            The duration in seconds
+        """
+        segment_duration = (
+            self.get_end_time(segment_index) - self.get_start_time(segment_index) + (1 / self.get_sampling_frequency())
+        )
+        return segment_duration
+
+    def get_total_duration(self) -> float:
+        """
+        Returns the total duration in seconds
+
+        Returns
+        -------
+        float
+            The duration in seconds
+        """
+        duration = sum([self.get_duration(segment_index) for segment_index in range(self.get_num_segments())])
+        return duration
+
+    def _get_t_starts(self):
+        # handle t_starts
+        t_starts = []
+        for rs in self.segments:
+            d = rs.get_times_kwargs()
+            t_starts.append(d["t_start"])
+
+        if all(t_start is None for t_start in t_starts):
+            t_starts = None
+        return t_starts
+
+    def _get_time_vectors(self):
+        time_vectors = []
+        for rs in self.segments:
+            d = rs.get_times_kwargs()
+            time_vectors.append(d["time_vector"])
+        if all(time_vector is None for time_vector in time_vectors):
+            time_vectors = None
+        return time_vectors
+
+
+class ChunkableSegment(BaseSegment):
+    """Class for chunkable segments, which provide methods to handle time kwargs."""
+
+    def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):
+        # sampling_frequency and time_vector are exclusive
+        if sampling_frequency is None:
+            assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'"
+            assert time_vector.ndim == 
1, "time_vector should be a 1D array" + + if time_vector is None: + assert sampling_frequency is not None, "Pass either 'sampling_frequency' or 'time_vector'" + + self.sampling_frequency = sampling_frequency + self.t_start = t_start + self.time_vector = time_vector + + BaseSegment.__init__(self) + + def get_times(self, start_frame: int | None = None, end_frame: int | None = None) -> np.ndarray: + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + if self.time_vector is not None: + self.time_vector = np.asarray(self.time_vector) + return self.time_vector[start_frame:end_frame] + else: + time_vector = np.arange(start_frame, end_frame, dtype="float64") + time_vector /= self.sampling_frequency + if self.t_start is not None: + time_vector += self.t_start + return time_vector + + def get_start_time(self) -> float: + if self.time_vector is not None: + return self.time_vector[0] + else: + return self.t_start if self.t_start is not None else 0.0 + + def get_end_time(self) -> float: + if self.time_vector is not None: + return self.time_vector[-1] + else: + t_stop = (self.get_num_samples() - 1) / self.sampling_frequency + if self.t_start is not None: + t_stop += self.t_start + return t_stop + + def get_times_kwargs(self) -> dict: + """ + Retrieves the timing attributes characterizing a RecordingSegment + + Returns + ------- + dict + A dictionary containing the following key-value pairs: + + - "sampling_frequency" : The sampling frequency of the RecordingSegment. + - "t_start" : The start time of the RecordingSegment. + - "time_vector" : The time vector of the RecordingSegment. + + Notes + ----- + The keys are always present, but the values may be None. 
+ """ + time_kwargs = dict( + sampling_frequency=self.sampling_frequency, t_start=self.t_start, time_vector=self.time_vector + ) + return time_kwargs + + def sample_index_to_time(self, sample_ind): + """ + Transform sample index into time in seconds + """ + if self.time_vector is None: + time_s = sample_ind / self.sampling_frequency + if self.t_start is not None: + time_s += self.t_start + else: + time_s = self.time_vector[sample_ind] + return time_s + + def time_to_sample_index(self, time_s): + """ + Transform time in seconds into sample index + """ + if self.time_vector is None: + if self.t_start is None: + sample_index = time_s * self.sampling_frequency + else: + sample_index = (time_s - self.t_start) * self.sampling_frequency + sample_index = np.round(sample_index).astype(np.int64) + else: + sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 + + return sample_index + + def get_num_samples(self) -> int: + """Returns the number of samples in this signal segment + + Returns: + SampleIndex : Number of samples in the signal segment + """ + # must be implemented in subclass + raise NotImplementedError From eb19b3d840bae6adad54909266e9363ee6ba1b54 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 15:03:31 +0100 Subject: [PATCH 5/7] Add chunkable files 2 --- src/spikeinterface/core/baserecording.py | 452 ++---------- src/spikeinterface/core/chunkable_tools.py | 689 ++++++++++++++++++ .../core/tests/test_chunkable_tools.py | 174 +++++ 3 files changed, 915 insertions(+), 400 deletions(-) create mode 100644 src/spikeinterface/core/chunkable_tools.py create mode 100644 src/spikeinterface/core/tests/test_chunkable_tools.py diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 322e1c7547..1f95d04781 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -5,14 +5,13 @@ import numpy as np from probeinterface import read_probeinterface, 
write_probeinterface -from .base import BaseSegment +from .chunkable import ChunkableSegment, ChunkableMixin from .baserecordingsnippets import BaseRecordingSnippets from .core_tools import convert_bytes_to_str, convert_seconds_to_str from .job_tools import split_job_kwargs -from .recording_tools import write_binary_recording -class BaseRecording(BaseRecordingSnippets): +class BaseRecording(BaseRecordingSnippets, ChunkableMixin): """ Abstract class representing several a multichannel timeseries (or block of raw ephys traces). Internally handle list of RecordingSegment @@ -43,6 +42,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list, dtype): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype ) + # initialize main annotation and properties self.annotate(is_filtered=False) @@ -178,110 +178,40 @@ def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> No """ super().add_segment(recording_segment) - def get_num_samples(self, segment_index: int | None = None) -> int: + def get_sample_size_in_bytes(self): """ - Returns the number of samples for a segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index to retrieve the number of samples for. - For multi-segment objects, it is required, default: None - With single segment recording returns the number of samples in the segment + Returns the size of a single sample across all channels in bytes. Returns ------- int - The number of samples - """ - segment_index = self._check_segment_index(segment_index) - return int(self.segments[segment_index].get_num_samples()) - - get_num_frames = get_num_samples - - def get_total_samples(self) -> int: - """ - Returns the sum of the number of samples in each segment. 
- - Returns - ------- - int - The total number of samples - """ - num_segments = self.get_num_segments() - samples_per_segment = (self.get_num_samples(segment_index) for segment_index in range(num_segments)) - - return sum(samples_per_segment) - - def get_duration(self, segment_index=None) -> float: - """ - Returns the duration in seconds. - - Parameters - ---------- - segment_index : int or None, default: None - The sample index to retrieve the duration for. - For multi-segment objects, it is required, default: None - With single segment recording returns the duration of the single segment - - Returns - ------- - float - The duration in seconds - """ - segment_duration = ( - self.get_end_time(segment_index) - self.get_start_time(segment_index) + (1 / self.get_sampling_frequency()) - ) - return segment_duration - - def get_total_duration(self) -> float: + The size of a single sample in bytes """ - Returns the total duration in seconds - - Returns - ------- - float - The duration in seconds - """ - duration = sum([self.get_duration(segment_index) for segment_index in range(self.get_num_segments())]) - return duration + num_channels = self.get_num_channels() + dtype_size_bytes = self.get_dtype().itemsize + sample_size = num_channels * dtype_size_bytes + return sample_size - def get_memory_size(self, segment_index=None) -> int: + def get_num_samples(self, segment_index: int | None = None) -> int: """ - Returns the memory size of segment_index in bytes. + Returns the number of samples for a segment. Parameters ---------- segment_index : int or None, default: None - The index of the segment for which the memory size should be calculated. + The segment index to retrieve the number of samples for. 
For multi-segment objects, it is required, default: None - With single segment recording returns the memory size of the single segment + With single segment recording returns the number of samples in the segment Returns ------- int - The memory size of the specified segment in bytes. + The number of samples """ segment_index = self._check_segment_index(segment_index) - num_samples = self.get_num_samples(segment_index=segment_index) - num_channels = self.get_num_channels() - dtype_size_bytes = self.get_dtype().itemsize - - memory_bytes = num_samples * num_channels * dtype_size_bytes - - return memory_bytes - - def get_total_memory_size(self) -> int: - """ - Returns the sum in bytes of all the memory sizes of the segments. + return int(self.segments[segment_index].get_num_samples()) - Returns - ------- - int - The total memory size in bytes for all segments. - """ - memory_per_segment = (self.get_memory_size(segment_index) for segment_index in range(self.get_num_segments())) - return sum(memory_per_segment) + get_num_frames = get_num_samples def get_traces( self, @@ -364,228 +294,33 @@ def get_traces( traces = traces.astype("float32", copy=False) * gains + offsets return traces - def get_time_info(self, segment_index=None) -> dict: - """ - Retrieves the timing attributes for a given segment index. As with - other recorders this method only needs a segment index in the case - of multi-segment recordings. - - Returns - ------- - dict - A dictionary containing the following key-value pairs: - - - "sampling_frequency" : The sampling frequency of the RecordingSegment. - - "t_start" : The start time of the RecordingSegment. - - "time_vector" : The time vector of the RecordingSegment. - - Notes - ----- - The keys are always present, but the values may be None. 
- """ - - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - time_kwargs = rs.get_times_kwargs() - - return time_kwargs - - def get_times(self, segment_index=None) -> np.ndarray: - """Get time vector for a recording segment. - - If the segment has a time_vector, then it is returned. Otherwise - a time_vector is constructed on the fly with sampling frequency. - If t_start is defined and the time vector is constructed on the fly, - the first time will be t_start. Otherwise it will start from 0. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - np.array - The 1d times array - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - times = rs.get_times() - return times - - def get_start_time(self, segment_index=None) -> float: - """Get the start time of the recording segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - float - The start time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.get_start_time() - - def get_end_time(self, segment_index=None) -> float: - """Get the stop time of the recording segment. - - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - float - The stop time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.get_end_time() - - def has_time_vector(self, segment_index: int | None = None): - """Check if the segment of the recording has a time vector. 
- - Parameters - ---------- - segment_index : int or None, default: None - The segment index (required for multi-segment) - - Returns - ------- - bool - True if the recording has time vectors, False otherwise - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - d = rs.get_times_kwargs() - return d["time_vector"] is not None - - def set_times(self, times, segment_index=None, with_warning=True): - """Set times for a recording segment. - - Parameters - ---------- - times : 1d np.array - The time vector - segment_index : int or None, default: None - The segment index (required for multi-segment) - with_warning : bool, default: True - If True, a warning is printed - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - - assert times.ndim == 1, "Time must have ndim=1" - assert rs.get_num_samples() == times.shape[0], "times have wrong shape" - - rs.t_start = None - rs.time_vector = times.astype("float64", copy=False) - - if with_warning: - warnings.warn( - "Setting times with Recording.set_times() is not recommended because " - "times are not always propagated across preprocessing" - "Use this carefully!" - ) - - def reset_times(self): - """ - Reset time information in-memory for all segments that have a time vector. - If the timestamps come from a file, the files won't be modified. but only the in-memory - attributes of the recording objects are deleted. Also `t_start` is set to None and the - segment's sampling frequency is set to the recording's sampling frequency. 
+ def get_data(self, start_frame: int, end_frame: int, segment_index: int | None = None, **kwargs) -> np.ndarray: """ - for segment_index in range(self.get_num_segments()): - rs = self.segments[segment_index] - if self.has_time_vector(segment_index): - rs.time_vector = None - rs.t_start = None - rs.sampling_frequency = self.sampling_frequency - - def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + General retrieval function for chunkable objects """ - Shift all times by a scalar value. + return self.get_traces(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame, **kwargs) - Parameters - ---------- - shift : int | float - The shift to apply. If positive, times will be increased by `shift`. - e.g. shifting by 1 will be like the recording started 1 second later. - If negative, the start time will be decreased i.e. as if the recording - started earlier. - - segment_index : int | None - The segment on which to shift the times. - If `None`, all segments will be shifted. 
- """ - if segment_index is None: - segments_to_shift = range(self.get_num_segments()) - else: - segments_to_shift = (segment_index,) - - for segment_index in segments_to_shift: - rs = self.segments[segment_index] - - if self.has_time_vector(segment_index=segment_index): - rs.time_vector += shift - else: - new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift - rs.t_start = new_start_time - - def sample_index_to_time(self, sample_ind, segment_index=None): - """ - Transform sample index into time in seconds - """ - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.sample_index_to_time(sample_ind) - - def time_to_sample_index(self, time_s, segment_index=None): - segment_index = self._check_segment_index(segment_index) - rs = self.segments[segment_index] - return rs.time_to_sample_index(time_s) - - def _get_t_starts(self): - # handle t_starts - t_starts = [] - has_time_vectors = [] - for rs in self.segments: - d = rs.get_times_kwargs() - t_starts.append(d["t_start"]) - - if all(t_start is None for t_start in t_starts): - t_starts = None - return t_starts - - def _get_time_vectors(self): - time_vectors = [] - for rs in self.segments: - d = rs.get_times_kwargs() - time_vectors.append(d["time_vector"]) - if all(time_vector is None for time_vector in time_vectors): - time_vectors = None - return time_vectors + def get_shape(self, segment_index: int | None = None) -> tuple[int, ...]: + return (self.get_num_samples(segment_index=segment_index), self.get_num_channels()) def _save(self, format="binary", verbose: bool = False, **save_kwargs): kwargs, job_kwargs = split_job_kwargs(save_kwargs) if format == "binary": + from .chunkable_tools import write_binary + folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() t_starts = self._get_t_starts() - write_binary_recording(self, 
file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) + write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) from .binaryrecordingextractor import BinaryRecordingExtractor # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` + binary_rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=self.get_sampling_frequency(), @@ -605,6 +340,13 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): cached = BinaryFolderRecording(folder_path=folder) + # timestamps are not saved in binary, so we have to set them explicitly + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + # the use of get_times is preferred since timestamps are converted to array + time_vector = self.get_times(segment_index=segment_index) + cached.set_times(time_vector, segment_index=segment_index) + elif format == "memory": if kwargs.get("sharedmem", True): from .numpyextractors import SharedMemoryRecording @@ -615,6 +357,13 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): cached = NumpyRecording.from_recording(self, **job_kwargs) + # timestamps are not saved in memory, so we have to set them explicitly + for segment_index in range(self.get_num_segments()): + if self.has_time_vector(segment_index): + # the use of get_times is preferred since timestamps are converted to array + time_vector = self.get_times(segment_index=segment_index) + cached.set_times(time_vector, segment_index=segment_index) + elif format == "zarr": from .zarrextractors import ZarrRecordingExtractor @@ -625,6 +374,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): ) cached = ZarrRecordingExtractor(zarr_path, storage_options) + # timestamps are saved and restored in zarr, so no need to set them explicitly + elif format == "nwb": # TODO implement a format based on 
zarr
        raise NotImplementedError
@@ -636,12 +387,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
-        for segment_index in range(self.get_num_segments()):
-            if self.has_time_vector(segment_index):
-                # the use of get_times is preferred since timestamps are converted to array
-                time_vector = self.get_times(segment_index=segment_index)
-                cached.set_times(time_vector, segment_index=segment_index)
-
         return cached
 
     def _extra_metadata_from_folder(self, folder):
@@ -719,7 +464,7 @@ def _remove_channels(self, remove_channel_ids):
         sub_recording = ChannelSliceRecording(self, new_channel_ids)
         return sub_recording
 
-    def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording":
+    def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording:
         """
         Returns a new recording with sliced frames. Note that this operation is not in place.
@@ -741,7 +486,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRe
         sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame)
         return sub_recording
 
-    def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording":
+    def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording:
         """
         Returns a new recording object, restricted to the time interval [start_time, end_time].
@@ -799,7 +544,7 @@ def _select_segments(self, segment_indices):
     def get_channel_locations(
         self,
         channel_ids: list | np.ndarray | tuple | None = None,
-        axes: Literal["xy", "yz", "xz", "xyz"] = "xy",
+        axes: Literal["xy", "yz", "xz", "xyz"] = "xy",
     ) -> np.ndarray:
         """
         Get the physical locations of specified channels.
@@ -888,110 +633,11 @@ def astype(self, dtype, round: bool | None = None): return astype(self, dtype=dtype, round=round) -class BaseRecordingSegment(BaseSegment): +class BaseRecordingSegment(ChunkableSegment): """ Abstract class representing a multichannel timeseries, or block of raw ephys traces """ - def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): - # sampling_frequency and time_vector are exclusive - if sampling_frequency is None: - assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'" - assert time_vector.ndim == 1, "time_vector should be a 1D array" - - if time_vector is None: - assert sampling_frequency is not None, "Pass either 'sampling_frequency' or 'time_vector'" - - self.sampling_frequency = sampling_frequency - self.t_start = t_start - self.time_vector = time_vector - - BaseSegment.__init__(self) - - def get_times(self) -> np.ndarray: - if self.time_vector is not None: - self.time_vector = np.asarray(self.time_vector) - return self.time_vector - else: - time_vector = np.arange(self.get_num_samples(), dtype="float64") - time_vector /= self.sampling_frequency - if self.t_start is not None: - time_vector += self.t_start - return time_vector - - def get_start_time(self) -> float: - if self.time_vector is not None: - return self.time_vector[0] - else: - return self.t_start if self.t_start is not None else 0.0 - - def get_end_time(self) -> float: - if self.time_vector is not None: - return self.time_vector[-1] - else: - t_stop = (self.get_num_samples() - 1) / self.sampling_frequency - if self.t_start is not None: - t_stop += self.t_start - return t_stop - - def get_times_kwargs(self) -> dict: - """ - Retrieves the timing attributes characterizing a RecordingSegment - - Returns - ------- - dict - A dictionary containing the following key-value pairs: - - - "sampling_frequency" : The sampling frequency of the RecordingSegment. - - "t_start" : The start time of the RecordingSegment. 
- - "time_vector" : The time vector of the RecordingSegment. - - Notes - ----- - The keys are always present, but the values may be None. - """ - time_kwargs = dict( - sampling_frequency=self.sampling_frequency, t_start=self.t_start, time_vector=self.time_vector - ) - return time_kwargs - - def sample_index_to_time(self, sample_ind): - """ - Transform sample index into time in seconds - """ - if self.time_vector is None: - time_s = sample_ind / self.sampling_frequency - if self.t_start is not None: - time_s += self.t_start - else: - time_s = self.time_vector[sample_ind] - return time_s - - def time_to_sample_index(self, time_s): - """ - Transform time in seconds into sample index - """ - if self.time_vector is None: - if self.t_start is None: - sample_index = time_s * self.sampling_frequency - else: - sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = np.round(sample_index).astype(np.int64) - else: - sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 - - return sample_index - - def get_num_samples(self) -> int: - """Returns the number of samples in this signal segment - - Returns: - SampleIndex : Number of samples in the signal segment - """ - # must be implemented in subclass - raise NotImplementedError - def get_traces( self, start_frame: int | None = None, @@ -1017,3 +663,9 @@ def get_traces( """ # must be implemented in subclass raise NotImplementedError + + def get_data(self, start_frame: int, end_frame: int, indices: list | np.array | tuple | None = None) -> np.ndarray: + """ + General retrieval function for chunkable objects + """ + return self.get_traces(start_frame=start_frame, end_frame=end_frame, channel_indices=indices) diff --git a/src/spikeinterface/core/chunkable_tools.py b/src/spikeinterface/core/chunkable_tools.py new file mode 100644 index 0000000000..558536332d --- /dev/null +++ b/src/spikeinterface/core/chunkable_tools.py @@ -0,0 +1,689 @@ +from __future__ import annotations +from pathlib 
import Path +import warnings + + +import numpy as np + +from .core_tools import add_suffix, make_shared_array +from .job_tools import ( + chunk_duration_to_chunk_size, + ensure_n_jobs, + fix_job_kwargs, + ChunkExecutor, + _shared_job_kwargs_doc, +) +from .chunkable import ChunkableMixin, ChunkableSegment + + +def write_binary( + chunkable: ChunkableMixin, + file_paths: list[Path | str] | Path | str, + file_timestamps_paths: list[Path | str] | Path | str | None = None, + dtype: np.typing.DTypeLike = None, + add_file_extension: bool = True, + byte_offset: int = 0, + verbose: bool = False, + **job_kwargs, +): + """ + Save the data of a chunkable object to binary format. + + Note : + time_axis is always 0 (contrary to previous version. + to get time_axis=1 (which is a bad idea) use `write_binary_file_handle()` + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to be saved to binary file + file_paths : list[Path | str] | Path | str + The path to the files to save data for each segment. + file_timestamps_paths : list[Path | str] | Path | str | None, default: None + The path to the timestamps file. If None, timestamps are not saved. + dtype : dtype or None, default: None + Type of the saved data + add_file_extension, bool, default: True + If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension. + byte_offset : int, default: 0 + Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data + to an existing file where you wrote a header or other data before. 
+ verbose : bool + This is the verbosity of the ChunkExecutor + {} + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths + num_segments = chunkable.get_num_segments() + if len(file_path_list) != num_segments: + raise ValueError("'file_paths' must be a list of the same size as the number of segments in the chunkable") + + file_path_list = [Path(file_path) for file_path in file_path_list] + if add_file_extension: + file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] + + dtype = dtype if dtype is not None else chunkable.get_dtype() + + sample_size_bytes = chunkable.get_sample_size_in_bytes() + + file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} + if file_timestamps_paths is not None: + file_timestamps_path_dict = { + segment_index: file_path for segment_index, file_path in enumerate(file_timestamps_paths) + } + else: + file_timestamps_path_dict = None + for segment_index, file_path in file_path_dict.items(): + num_samples = chunkable.get_num_samples(segment_index=segment_index) + data_size_bytes = sample_size_bytes * num_samples + file_size_bytes = data_size_bytes + byte_offset + + # Create an empty file with file_size_bytes + with open(file_path, "wb+") as file: + # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) + file.seek(file_size_bytes - 1) + file.write(b"\0") + + if file_timestamps_path_dict is not None: + file_timestamps_path = file_timestamps_path_dict[segment_index] + with open(file_timestamps_path, "wb+") as file: + file.seek(num_samples * 8 - 1) # 8 bytes for float64 timestamps + file.write(b"\0") + + assert Path(file_path).is_file() + + # use executor (loop or workers) + func = _write_binary_chunk + init_func = _init_binary_worker + init_args = (chunkable, file_path_dict, dtype, byte_offset, file_timestamps_path_dict) + executor = 
ChunkExecutor(
+        chunkable, func, init_func, init_args, job_name="write_binary", verbose=verbose, **job_kwargs
+    )
+    executor.run()
+
+
+# used by write_binary + ChunkExecutor
+def _init_binary_worker(chunkable, file_path_dict, dtype, byte_offset, file_timestamps_path_dict=None):
+    # create a local dict per worker
+    worker_ctx = {}
+    worker_ctx["chunkable"] = chunkable
+    worker_ctx["byte_offset"] = byte_offset
+    worker_ctx["dtype"] = np.dtype(dtype)
+
+    file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()}
+    worker_ctx["file_dict"] = file_dict
+    worker_ctx["file_timestamps_dict"] = None if file_timestamps_path_dict is None else {segment_index: open(file_path, "rb+") for segment_index, file_path in file_timestamps_path_dict.items()}
+
+    return worker_ctx
+
+
+# used by write_binary + ChunkExecutor
+def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
+    # recover variables of the worker
+    chunkable = worker_ctx["chunkable"]
+    dtype = worker_ctx["dtype"]
+    byte_offset = worker_ctx["byte_offset"]
+    file = worker_ctx["file_dict"][segment_index]
+    file_timestamps_dict = worker_ctx["file_timestamps_dict"]
+    sample_size_bytes = chunkable.get_sample_size_in_bytes()
+
+    # Calculate byte offsets for the start frames relative to the entire recording
+    start_byte = byte_offset + start_frame * sample_size_bytes
+
+    data = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
+    data = data.astype(dtype, order="c", copy=False)
+
+    file.seek(start_byte)
+    file.write(data.data)
+    # flush is important!!
+    file.flush()
+
+    if file_timestamps_dict is not None:
+        file_timestamps = file_timestamps_dict[segment_index]
+        timestamps = chunkable.get_times(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
+        timestamps = timestamps.astype("float64", order="c", copy=False)
+        timestamp_byte_offset = start_frame * 8  # 8 bytes for float64
+        file_timestamps.seek(timestamp_byte_offset)
+        file_timestamps.write(timestamps.data)
+        file_timestamps.flush()
+
+
+write_binary.__doc__ = write_binary.__doc__.format(_shared_job_kwargs_doc)
+
+
+# used by write_memory
+def _init_memory_worker(chunkable, arrays, shm_names, shapes, dtype):
+    # create a local dict per worker
+    worker_ctx = {}
+    worker_ctx["chunkable"] = chunkable
+    worker_ctx["dtype"] = np.dtype(dtype)
+
+    if arrays is None:
+        # create it from share memory name
+        from multiprocessing.shared_memory import SharedMemory
+
+        arrays = []
+        # keep shm alive
+        worker_ctx["shms"] = []
+        for i in range(len(shm_names)):
+            shm = SharedMemory(shm_names[i])
+            worker_ctx["shms"].append(shm)
+            arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf)
+            arrays.append(arr)
+
+    worker_ctx["arrays"] = arrays
+
+    return worker_ctx
+
+
+# used by write_memory
+def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx):
+    # recover variables of the worker
+    chunkable = worker_ctx["chunkable"]
+    dtype = worker_ctx["dtype"]
+    arr = worker_ctx["arrays"][segment_index]
+
+    # apply function
+    traces = chunkable.get_data(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index)
+    traces = traces.astype(dtype, copy=False)
+    arr[start_frame:end_frame, :] = traces
+
+
+def write_memory(
+    chunkable: ChunkableMixin, dtype=None, verbose=False, buffer_type="auto", job_name="write_memory", **job_kwargs
+):
+    """
+    Save the traces into numpy arrays (memory).
+ try to use the SharedMemory introduce in py3.8 if n_jobs > 1 + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to be saved to memory + dtype : dtype, default: None + Type of the saved data + verbose : bool, default: False + If True, output is verbose (when chunks are used) + buffer_type : "auto" | "numpy" | "sharedmem", + The type of buffer to use for storing the data. + job_name : str, default: "write_memory" + Name of the job + {} + + Returns + --------- + arrays : one array per segment + """ + job_kwargs = fix_job_kwargs(job_kwargs) + + if dtype is None: + dtype = chunkable.get_dtype() + + # create sharedmmep + arrays = [] + shm_names = [] + shms = [] + shapes = [] + + n_jobs = ensure_n_jobs(chunkable, n_jobs=job_kwargs.get("n_jobs", 1)) + if buffer_type == "auto": + if n_jobs > 1: + buffer_type = "sharedmem" + else: + buffer_type = "numpy" + + for segment_index in range(chunkable.get_num_segments()): + shape = chunkable.get_shape(segment_index=segment_index) + shapes.append(shape) + if buffer_type == "sharedmem": + arr, shm = make_shared_array(shape, dtype) + shm_names.append(shm.name) + shms.append(shm) + else: + arr = np.zeros(shape, dtype=dtype) + shms.append(None) + arrays.append(arr) + + # use executor (loop or workers) + func = _write_memory_chunk + init_func = _init_memory_worker + if n_jobs > 1: + init_args = (chunkable, None, shm_names, shapes, dtype) + else: + init_args = (chunkable, arrays, None, None, dtype) + + executor = ChunkExecutor(chunkable, func, init_func, init_args, verbose=verbose, job_name=job_name, **job_kwargs) + executor.run() + + return arrays, shms + + +write_memory.__doc__ = write_memory.__doc__.format(_shared_job_kwargs_doc) + + +def write_chunkable_to_zarr( + chunkable: ChunkableMixin, + zarr_group, + dataset_paths, + dataset_timestamps_paths=None, + extra_chunks=None, + dtype=None, + compressor_data=None, + filters_data=None, + compressor_times=None, + filters_times=None, + verbose=False, + 
**job_kwargs, +): + """ + Save the trace of a chunkable object in several zarr format. + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to be saved in .dat format + zarr_group : zarr.Group + The zarr group to add traces to + dataset_paths : list + List of paths to traces datasets in the zarr group + dataset_timestamps_paths : list or None, default: None + List of paths to timestamps datasets in the zarr group. If None, timestamps are not saved. + extra_chunks : tuple or None, default: None + Extra chunking dimensions to use for the zarr dataset. + The first dimension is always time and controlled by the job_kwargs. + This is for example useful to chunk by channel, with `extra_chunks=(channel_chunk_size,)`. + dtype : dtype, default: None + Type of the saved data + compressor_data : zarr compressor or None, default: None + Zarr compressor for data + filters_data : list, default: None + List of zarr filters for data + compressor_times : zarr compressor or None, default: None + Zarr compressor for timestamps + filters_times : list, default: None + List of zarr filters for timestamps + verbose : bool, default: False + If True, output is verbose (when chunks are used) + {} + """ + from .job_tools import ( + ensure_chunk_size, + fix_job_kwargs, + ChunkExecutor, + ) + + assert dataset_paths is not None, "Provide 'dataset_paths' to save data in zarr format" + if dataset_timestamps_paths is not None: + assert ( + len(dataset_timestamps_paths) == chunkable.get_num_segments() + ), "dataset_timestamps_paths should have the same length as the number of segments in the chunkable" + else: + dataset_timestamps_paths = [None] * chunkable.get_num_segments() + + if not isinstance(dataset_paths, list): + dataset_paths = [dataset_paths] + assert len(dataset_paths) == chunkable.get_num_segments() + + if dtype is None: + dtype = chunkable.get_dtype() + + job_kwargs = fix_job_kwargs(job_kwargs) + chunk_size = ensure_chunk_size(chunkable, **job_kwargs) + + if 
extra_chunks is not None: + assert len(extra_chunks) == len(chunkable.get_shape(0)[1:]), ( + "extra_chunks should have the same length as the number of dimensions " + "of the chunkable minus one (time axis)" + ) + + # create zarr datasets files + zarr_datasets = [] + zarr_timestamps_datasets = [] + + for segment_index in range(chunkable.get_num_segments()): + num_samples = chunkable.get_num_samples(segment_index) + dset_name = dataset_paths[segment_index] + shape = chunkable.get_shape(segment_index) + dset = zarr_group.create_dataset( + name=dset_name, + shape=shape, + chunks=(chunk_size,) + extra_chunks if extra_chunks is not None else (chunk_size,), + dtype=dtype, + filters=filters_data, + compressor=compressor_data, + ) + zarr_datasets.append(dset) + if dataset_timestamps_paths[segment_index] is not None: + tset_name = dataset_timestamps_paths[segment_index] + zarr_timestamps_datasets.append( + zarr_group.create_dataset( + name=tset_name, + shape=(num_samples,), + chunks=(chunk_size,), + dtype="float64", + filters=filters_times, + compressor=compressor_times, + ) + ) + else: + zarr_timestamps_datasets.append(None) + + # use executor (loop or workers) + func = _write_zarr_chunk + init_func = _init_zarr_worker + init_args = (chunkable, zarr_datasets, dtype, zarr_timestamps_datasets) + executor = ChunkExecutor( + chunkable, func, init_func, init_args, verbose=verbose, job_name="write_zarr", **job_kwargs + ) + executor.run() + + # save t_starts + t_starts = np.zeros(chunkable.get_num_segments(), dtype="float64") * np.nan + for segment_index in range(chunkable.get_num_segments()): + time_info = chunkable.get_time_info(segment_index) + if time_info["t_start"] is not None: + t_starts[segment_index] = time_info["t_start"] + + if np.any(~np.isnan(t_starts)): + zarr_group.create_dataset(name="t_starts", data=t_starts, compressor=None) + + +# used by write_zarr_recording + ChunkExecutor +def _init_zarr_worker(chunkable, zarr_datasets, dtype, zarr_timestamps_datasets=None): 
+ import zarr + + # create a local dict per worker + worker_ctx = {} + worker_ctx["chunkable"] = chunkable + worker_ctx["zarr_datasets"] = zarr_datasets + if zarr_timestamps_datasets is not None and len(zarr_timestamps_datasets) > 0: + worker_ctx["zarr_timestamps_datasets"] = zarr_timestamps_datasets + else: + worker_ctx["zarr_timestamps_datasets"] = None + worker_ctx["dtype"] = np.dtype(dtype) + + return worker_ctx + + +# used by write_zarr_recording + ChunkExecutor +def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): + import gc + + # recover variables of the worker + chunkable = worker_ctx["chunkable"] + dtype = worker_ctx["dtype"] + zarr_dataset = worker_ctx["zarr_datasets"][segment_index] + if worker_ctx["zarr_timestamps_datasets"] is not None: + zarr_timestamps_dataset = worker_ctx["zarr_timestamps_datasets"][segment_index] + else: + zarr_timestamps_dataset = None + + # apply function + data = chunkable.get_data( + start_frame=start_frame, + end_frame=end_frame, + segment_index=segment_index, + ) + data = data.astype(dtype) + zarr_dataset[start_frame:end_frame, :] = data + + if zarr_timestamps_dataset is not None: + timestamps = chunkable.get_times(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) + zarr_timestamps_dataset[start_frame:end_frame] = timestamps + + # fix memory leak by forcing garbage collection + del data + gc.collect() + + +def get_random_sample_slices( + chunkable: ChunkableMixin, + method="full_random", + num_chunks_per_segment=20, + chunk_duration="500ms", + chunk_size=None, + margin_frames=0, + seed=None, +): + """ + Get random slice of a chunkable object across segments. + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to get random chunks from + method : "full_random" + The method used to get random slices. + * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices + and they can overlap. 
+ num_chunks_per_segment : int, default: 20 + Number of chunks per segment + chunk_duration : str | float | None, default "500ms" + The duration of each chunk in 's' or 'ms' + chunk_size : int | None + Size of a chunk in number of frames. This is used only if chunk_duration is None. + This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. + concatenated : bool, default: True + If True chunk are concatenated along time axis + seed : int, default: None + Random seed + margin_frames : int, default: 0 + Margin in number of frames to avoid edge effects + + Returns + ------- + chunk_list : np.array + Array of concatenate chunks per segment + + + """ + # TODO: if segment have differents length make another sampling that dependant on the length of the segment + # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY + # And randomize the number of chunk per segment weighted by segment duration + + if method == "full_random": + if chunk_size is None: + if chunk_duration is not None: + chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable) + else: + raise ValueError("get_random_sample_slices need chunk_size or chunk_duration") + + # check chunk size + num_segments = chunkable.get_num_segments() + for segment_index in range(num_segments): + chunk_size_limit = chunkable.get_num_samples(segment_index) - 2 * margin_frames + if chunk_size > chunk_size_limit: + chunk_size = chunk_size_limit - 1 + warnings.warn( + f"chunk_size is greater than the number " + f"of samples for segment index {segment_index}. " + f"Using {chunk_size}." 
+ ) + rng = np.random.default_rng(seed) + slices = [] + low = margin_frames + size = num_chunks_per_segment + for segment_index in range(num_segments): + num_frames = chunkable.get_num_samples(segment_index) + high = num_frames - chunk_size - margin_frames + # here we set endpoint to True, because the this represents the start of the + # chunk, and should be inclusive + random_starts = rng.integers(low=low, high=high, size=size, endpoint=True) + random_starts = np.sort(random_starts) + slices += [(segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts] + else: + raise ValueError(f"get_random_sample_slices : wrong method {method}") + + return slices + + +def get_chunks(chunkable: ChunkableMixin, concatenated=True, get_data_kwargs=None, **random_slices_kwargs): + """ + Extract random chunks across segments. + + Internally, it uses `get_random_sample_slices()` and retrieves the traces chunk as a list + or a concatenated unique array. + + Please read `get_random_sample_slices()` for more details on parameters. + + # TODO: handle this in recording tools: + return * will be get_data_kwargs + + Parameters + ---------- + chunkable : ChunkableMixin + The chunkable object to get random chunks from + return_scaled : bool | None, default: None + DEPRECATED. Use return_in_uV instead. + return_in_uV : bool, default: False + If True and the chunkable has scaling (gain_to_uV and offset_to_uV properties), + traces are scaled to uV + num_chunks_per_segment : int, default: 20 + Number of chunks per segment + concatenated : bool, default: True + If True chunk are concatenated along time axis + **random_slices_kwargs : dict + Options transmited to get_random_sample_slices(), please read documentation from this + function for more details. 
+ + Returns + ------- + chunk_list : np.array | list of np.array + Array of concatenate chunks per segment + """ + slices = get_random_sample_slices(chunkable, **random_slices_kwargs) + + chunk_list = [] + get_data_kwargs = get_data_kwargs if get_data_kwargs is not None else {} + for segment_index, start_frame, end_frame in slices: + traces_chunk = chunkable.get_data( + start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, **get_data_kwargs + ) + chunk_list.append(traces_chunk) + + if concatenated: + return np.concatenate(chunk_list, axis=0) + else: + return chunk_list + + +def get_chunk_with_margin( + chunkable_segment: ChunkableSegment, + start_frame, + end_frame, + last_dimension_indices, + margin, + add_zeros=False, + add_reflect_padding=False, + window_on_margin=False, + dtype=None, +): + """ + Helper to get chunk with margin + + The margin is extracted from the recording when possible. If + at the edge of the recording, no margin is used unless one + of `add_zeros` or `add_reflect_padding` is True. In the first + case zero padding is used, in the second case np.pad is called + with mod="reflect". 
+ """ + length = int(chunkable_segment.get_num_samples()) + + if last_dimension_indices is None: + last_dimension_indices = slice(None) + + if not (add_zeros or add_reflect_padding): + if window_on_margin and not add_zeros: + raise ValueError("window_on_margin requires add_zeros=True") + + if start_frame is None: + left_margin = 0 + start_frame = 0 + elif start_frame < margin: + left_margin = start_frame + else: + left_margin = margin + + if end_frame is None: + right_margin = 0 + end_frame = length + elif end_frame > (length - margin): + right_margin = length - end_frame + else: + right_margin = margin + + data_chunk = chunkable_segment.get_data( + start_frame - left_margin, + end_frame + right_margin, + last_dimension_indices, + ) + + else: + # either add_zeros or reflect_padding + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = length + + chunk_size = end_frame - start_frame + full_size = chunk_size + 2 * margin + + if start_frame < margin: + start_frame2 = 0 + left_pad = margin - start_frame + else: + start_frame2 = start_frame - margin + left_pad = 0 + + if end_frame > (length - margin): + end_frame2 = length + right_pad = end_frame + margin - length + else: + end_frame2 = end_frame + margin + right_pad = 0 + + data_chunk = chunkable_segment.get_data(start_frame2, end_frame2, last_dimension_indices) + + if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0: + need_copy = True + else: + need_copy = False + + left_margin = margin + right_margin = margin + + if need_copy: + if dtype is None: + dtype = data_chunk.dtype + + left_margin = margin + if end_frame < (length + margin): + right_margin = margin + else: + right_margin = end_frame + margin - length + + if add_zeros: + data_chunk2 = np.zeros((full_size, data_chunk.shape[1]), dtype=dtype) + i0 = left_pad + i1 = left_pad + data_chunk.shape[0] + data_chunk2[i0:i1, :] = data_chunk + if window_on_margin: + # apply inplace taper on border + taper = (1 - 
np.cos(np.arange(margin) / margin * np.pi)) / 2 + taper = taper[:, np.newaxis] + data_chunk2[:margin] *= taper + data_chunk2[-margin:] *= taper[::-1] + data_chunk = data_chunk2 + elif add_reflect_padding: + # in this case, we don't want to taper + data_chunk = np.pad( + data_chunk.astype(dtype, copy=False), + [(left_pad, right_pad)] + [(0, 0)] * (data_chunk.ndim - 1), + mode="reflect", + ) + else: + # we need a copy to change the dtype + data_chunk = np.asarray(data_chunk, dtype=dtype) + + return data_chunk, left_margin, right_margin diff --git a/src/spikeinterface/core/tests/test_chunkable_tools.py b/src/spikeinterface/core/tests/test_chunkable_tools.py new file mode 100644 index 0000000000..3d686166e7 --- /dev/null +++ b/src/spikeinterface/core/tests/test_chunkable_tools.py @@ -0,0 +1,174 @@ +import numpy as np + +from spikeinterface.core import generate_recording + +from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor +from spikeinterface.core.generate import NoiseGeneratorRecording + + +from spikeinterface.core.chunkable_tools import ( + write_binary, + write_memory, + get_random_sample_slices, + get_chunks, +) + + +def test_write_binary(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def 
test_write_binary_offset(tmp_path): + # Test write_binary() with loop (n_jobs=1) + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + + durations = [10.0] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=1) + byte_offset = 125 + write_binary(recording, file_paths=file_paths, dtype=dtype, byte_offset=byte_offset, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + file_offset=byte_offset, + ) + assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) + + +def test_write_binary_parallel(tmp_path): + # Test write_binary() with parallel processing (n_jobs=2) + + # Setup + sampling_frequency = 30_000 + num_channels = 2 + dtype = "float32" + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + dtype=dtype, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert 
np.allclose(binary_traces, recording_traces) + + +def test_write_binary_multiple_segment(tmp_path): + # Test write_binary() with multiple segments (n_jobs=2) + # Setup + sampling_frequency = 30_000 + num_channels = 10 + dtype = "float32" + + durations = [10.30, 3.5] + recording = NoiseGeneratorRecording( + durations=durations, + num_channels=num_channels, + sampling_frequency=sampling_frequency, + strategy="tile_pregenerated", + ) + file_paths = [tmp_path / "binary01.raw", tmp_path / "binary02.raw"] + + # Write binary recording + job_kwargs = dict(n_jobs=2, chunk_memory="100k", mp_context="spawn") + write_binary(recording, file_paths=file_paths, dtype=dtype, verbose=False, **job_kwargs) + + # Check if written data matches original data + recorder_binary = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype + ) + + for segment_index in range(recording.get_num_segments()): + binary_traces = recorder_binary.get_traces(segment_index=segment_index) + recording_traces = recording.get_traces(segment_index=segment_index) + assert np.allclose(binary_traces, recording_traces) + + +def test_write_memory_recording(): + # 2 segments + recording = NoiseGeneratorRecording( + num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated" + ) + recording = recording.save() + + # write with loop + traces_list, shms = write_memory(recording, dtype=None, verbose=True, n_jobs=1) + + traces_list, shms = write_memory( + recording, dtype=None, verbose=True, n_jobs=1, chunk_memory="100k", progress_bar=True + ) + + # write parallel + traces_list, shms = write_memory(recording, dtype=None, verbose=False, n_jobs=2, chunk_memory="100k") + # need to clean the buffer + del traces_list + for shm in shms: + shm.unlink() + + +def test_get_random_sample_slices(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + rec_slices = 
get_random_sample_slices( + rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 + ) + assert len(rec_slices) == 40 + for seg_ind, start, stop in rec_slices: + assert stop - start == 500 + assert seg_ind in (0, 1) + + +def test_get_chunks(): + rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) + chunks = get_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) + assert chunks.shape == (50000, 1) From e24251fd0891fb13559c360a24bdc2b580f63b7b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 15:58:16 +0100 Subject: [PATCH 6/7] Modify node pipeline to work with chunkable objects --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/baserecording.py | 10 +- src/spikeinterface/core/chunkable.py | 1 - src/spikeinterface/core/chunkable_tools.py | 3 +- src/spikeinterface/core/job_tools.py | 107 ++-- src/spikeinterface/core/node_pipeline.py | 132 ++--- src/spikeinterface/core/recording_tools.py | 484 +----------------- .../core/tests/test_job_tools.py | 24 +- .../core/tests/test_node_pipeline.py | 16 +- .../core/tests/test_recording_tools.py | 15 +- src/spikeinterface/core/waveform_tools.py | 18 +- src/spikeinterface/core/zarrextractors.py | 8 +- src/spikeinterface/exporters/to_ibl.py | 4 +- .../postprocessing/amplitude_scalings.py | 5 +- .../postprocessing/principal_component.py | 4 +- .../postprocessing/spike_amplitudes.py | 6 +- .../preprocessing/detect_artifacts.py | 4 +- .../sortingcomponents/clustering/main.py | 2 +- .../sortingcomponents/matching/base.py | 11 +- .../sortingcomponents/matching/circus.py | 4 +- .../sortingcomponents/matching/main.py | 2 +- .../sortingcomponents/matching/nearest.py | 4 +- .../sortingcomponents/matching/tdc_peeler.py | 8 +- .../sortingcomponents/matching/wobble.py | 2 +- .../motion/motion_interpolation.py | 2 +- .../peak_detection/by_channel.py | 9 +- .../peak_detection/iterative.py | 14 +- 
.../peak_detection/locally_exclusive.py | 17 +- .../peak_detection/matched_filtering.py | 6 +- .../tests/test_peak_detection.py | 12 +- .../peak_localization/base.py | 2 +- .../peak_localization/center_of_mass.py | 2 +- .../peak_localization/grid.py | 4 +- .../peak_localization/monopolar.py | 2 +- src/spikeinterface/sortingcomponents/tools.py | 14 +- .../waveforms/features_from_peaks.py | 6 +- .../waveforms/hanning_filter.py | 2 +- .../waveforms/neural_network_denoiser.py | 2 +- .../waveforms/savgol_denoiser.py | 2 +- .../waveforms/temporal_pca.py | 6 +- .../waveforms/tests/test_temporal_pca.py | 2 +- .../waveforms/waveform_thresholder.py | 2 +- 42 files changed, 272 insertions(+), 710 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 168494caf7..66b2e0ef10 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -95,7 +95,7 @@ get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, split_job_kwargs, fix_job_kwargs, ) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 1f95d04781..651208776e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -464,7 +464,7 @@ def _remove_channels(self, remove_channel_ids): sub_recording = ChannelSliceRecording(self, new_channel_ids) return sub_recording - def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRecording: + def frame_slice(self, start_frame: int | None, end_frame: int | None) -> "BaseRecording": """ Returns a new recording with sliced frames. Note that this operation is not in place. 
@@ -486,7 +486,7 @@ def frame_slice(self, start_frame: int | None, end_frame: int | None) -> BaseRec sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame) return sub_recording - def time_slice(self, start_time: float | None, end_time: float | None) -> BaseRecording: + def time_slice(self, start_time: float | None, end_time: float | None) -> "BaseRecording": """ Returns a new recording object, restricted to the time interval [start_time, end_time]. @@ -544,7 +544,7 @@ def _select_segments(self, segment_indices): def get_channel_locations( self, channel_ids: list | np.ndarray | tuple | None = None, - axes: "xy" | "yz" | "xz" | "xyz" = "xy", + axes: Literal["xy", "yz", "xz", "xyz"] = "xy", ) -> np.ndarray: """ Get the physical locations of specified channels. @@ -664,7 +664,9 @@ def get_traces( # must be implemented in subclass raise NotImplementedError - def get_data(self, start_frame: int, end_frame: int, indices: list | np.array | tuple | None = None) -> np.ndarray: + def get_data( + self, start_frame: int, end_frame: int, indices: list | np.ndarray | tuple | None = None + ) -> np.ndarray: """ General retrieval function for chunkable objects """ diff --git a/src/spikeinterface/core/chunkable.py b/src/spikeinterface/core/chunkable.py index c3329a10e3..e7cddbc65e 100644 --- a/src/spikeinterface/core/chunkable.py +++ b/src/spikeinterface/core/chunkable.py @@ -1,4 +1,3 @@ -from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional import warnings diff --git a/src/spikeinterface/core/chunkable_tools.py b/src/spikeinterface/core/chunkable_tools.py index 558536332d..160bdd74fc 100644 --- a/src/spikeinterface/core/chunkable_tools.py +++ b/src/spikeinterface/core/chunkable_tools.py @@ -1,4 +1,3 @@ -from __future__ import annotations from pathlib import Path import warnings @@ -549,7 +548,7 @@ def get_chunks(chunkable: ChunkableMixin, concatenated=True, get_data_kwargs=Non Returns ------- - chunk_list 
: np.array | list of np.array + chunk_list : np.ndarray | list of np.array Array of concatenate chunks per segment """ slices = get_random_sample_slices(chunkable, **random_slices_kwargs) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 4bb1356769..1fab85e673 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -6,7 +6,6 @@ import platform import os import warnings -from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str import sys from tqdm.auto import tqdm @@ -16,6 +15,8 @@ import threading from threadpoolctl import threadpool_limits +from spikeinterface.core.core_tools import convert_string_to_bytes, convert_bytes_to_str, convert_seconds_to_str + _shared_job_kwargs_doc = """**job_kwargs : keyword arguments for parallel processing: * chunk_duration or chunk_size or chunk_memory or total_memory - chunk_size : int @@ -204,16 +205,16 @@ def divide_segment_into_chunks(num_frames, chunk_size): return chunks -def divide_recording_into_chunks(recording, chunk_size): - recording_slices = [] +def divide_chunkable_into_chunks(recording, chunk_size): + slices = [] for segment_index in range(recording.get_num_segments()): num_frames = recording.get_num_samples(segment_index) chunks = divide_segment_into_chunks(num_frames, chunk_size) - recording_slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) - return recording_slices + slices.extend([(segment_index, frame_start, frame_stop) for frame_start, frame_stop in chunks]) + return slices -def ensure_n_jobs(recording, n_jobs=1): +def ensure_n_jobs(extractor, n_jobs=1): if n_jobs == -1: n_jobs = os.cpu_count() elif n_jobs == 0: @@ -231,19 +232,19 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_memory_serializable(): + if not 
extractor.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not serializable to memory and can't be processed in parallel. " + "Extractor is not serializable to memory and can't be processed in parallel. " "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." ) return n_jobs -def chunk_duration_to_chunk_size(chunk_duration, recording): +def chunk_duration_to_chunk_size(chunk_duration, chunkable: "ChunkableMixin"): if isinstance(chunk_duration, float): - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + chunk_size = int(chunk_duration * chunkable.get_sampling_frequency()) elif isinstance(chunk_duration, str): if chunk_duration.endswith("ms"): chunk_duration = float(chunk_duration.replace("ms", "")) / 1000.0 @@ -251,17 +252,23 @@ def chunk_duration_to_chunk_size(chunk_duration, recording): chunk_duration = float(chunk_duration.replace("s", "")) else: raise ValueError("chunk_duration must ends with s or ms") - chunk_size = int(chunk_duration * recording.get_sampling_frequency()) + chunk_size = int(chunk_duration * chunkable.get_sampling_frequency()) else: raise ValueError("chunk_duration must be str or float") return chunk_size def ensure_chunk_size( - recording, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs + chunkable: "ChunkableMixin", + total_memory=None, + chunk_size=None, + chunk_memory=None, + chunk_duration=None, + n_jobs=1, + **other_kwargs, ): """ - "chunk_size" is the traces.shape[0] for each worker. + "chunk_size" is the number of samples for each worker. Flexible chunk_size setter with 3 ways: * "chunk_size" : is the length in sample for each chunk independently of channel count and dtype. 
@@ -292,24 +299,20 @@ def ensure_chunk_size( assert total_memory is None # set by memory per worker size chunk_memory = convert_string_to_bytes(chunk_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(chunk_memory / (num_channels * n_bytes)) + chunk_size = int(chunk_memory / chunkable.get_sample_size_in_bytes()) elif total_memory is not None: # clip by total memory size - n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) + n_jobs = ensure_n_jobs(chunkable, n_jobs=n_jobs) total_memory = convert_string_to_bytes(total_memory) - n_bytes = np.dtype(recording.get_dtype()).itemsize - num_channels = recording.get_num_channels() - chunk_size = int(total_memory / (num_channels * n_bytes * n_jobs)) + chunk_size = int(total_memory / (chunkable.get_sample_size_in_bytes() * n_jobs)) elif chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) + chunk_size = chunk_duration_to_chunk_size(chunk_duration, chunkable) else: # Edge case to define single chunk per segment for n_jobs=1. # All chunking parameters equal None mean single chunk per segment if n_jobs == 1: - num_segments = recording.get_num_segments() - samples_in_larger_segment = max([recording.get_num_samples(segment) for segment in range(num_segments)]) + num_segments = chunkable.get_num_segments() + samples_in_larger_segment = max([chunkable.get_num_samples(segment) for segment in range(num_segments)]) chunk_size = samples_in_larger_segment else: raise ValueError("For n_jobs >1 you must specify total_memory or chunk_size or chunk_memory") @@ -317,9 +320,9 @@ def ensure_chunk_size( return chunk_size -class ChunkRecordingExecutor: +class ChunkExecutor: """ - Core class for parallel processing to run a "function" over chunks on a recording. + Core class for parallel processing to run a "function" over chunks on a chunkable extractor. 
It supports running a function: * in loop with chunk processing (low RAM usage) @@ -331,8 +334,8 @@ class ChunkRecordingExecutor: Parameters ---------- - recording : RecordingExtractor - The recording to be processed + chunkable : ChunkableMixin + The chunkable object to be processed. func : function Function that runs on each chunk init_func : function @@ -380,7 +383,7 @@ class ChunkRecordingExecutor: def __init__( self, - recording, + chunkable: "ChunkableMixin", func, init_func, init_args, @@ -399,7 +402,7 @@ def __init__( max_threads_per_worker=1, need_worker_index=False, ): - self.recording = recording + self.chunkable = chunkable self.func = func self.init_func = init_func self.init_args = init_args @@ -418,7 +421,7 @@ def __init__( else: mp_context = "spawn" - preferred_mp_context = recording.get_preferred_mp_context() + preferred_mp_context = chunkable.get_preferred_mp_context() if preferred_mp_context is not None and preferred_mp_context != mp_context: warnings.warn( f"Your processing chain using pool_engine='process' and mp_context='{mp_context}' is not possible." 
@@ -434,9 +437,8 @@ def __init__( self.handle_returns = handle_returns self.gather_func = gather_func - self.n_jobs = ensure_n_jobs(recording, n_jobs=n_jobs) - self.chunk_size = ensure_chunk_size( - recording, + self.n_jobs = ensure_n_jobs(self.chunkable, n_jobs=n_jobs) + self.chunk_size = self.ensure_chunk_size( total_memory=total_memory, chunk_size=chunk_size, chunk_memory=chunk_memory, @@ -451,9 +453,9 @@ def __init__( self.need_worker_index = need_worker_index if verbose: - chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize + chunk_memory = self.get_chunk_memory() total_memory = chunk_memory * self.n_jobs - chunk_duration = self.chunk_size / recording.get_sampling_frequency() + chunk_duration = self.chunk_size / chunkable.sampling_frequency chunk_memory_str = convert_bytes_to_str(chunk_memory) total_memory_str = convert_bytes_to_str(total_memory) chunk_duration_str = convert_seconds_to_str(chunk_duration) @@ -468,13 +470,24 @@ def __init__( f"chunk_duration={chunk_duration_str}", ) - def run(self, recording_slices=None): + def get_chunk_memory(self): + return self.chunk_size * self.chunkable.get_sample_size_in_bytes() + + def ensure_chunk_size( + self, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None, n_jobs=1, **other_kwargs + ): + return ensure_chunk_size( + self.chunkable, total_memory, chunk_size, chunk_memory, chunk_duration, n_jobs, **other_kwargs + ) + + def run(self, slices=None): """ Runs the defined jobs. 
""" - if recording_slices is None: - recording_slices = divide_recording_into_chunks(self.recording, self.chunk_size) + if slices is None: + # TODO: rename + slices = divide_chunkable_into_chunks(self.chunkable, self.chunk_size) if self.handle_returns: returns = [] @@ -483,9 +496,7 @@ def run(self, recording_slices=None): if self.n_jobs == 1: if self.progress_bar: - recording_slices = tqdm( - recording_slices, desc=f"{self.job_name} (no parallelization)", total=len(recording_slices) - ) + slices = tqdm(slices, desc=f"{self.job_name} (no parallelization)", total=len(slices)) init_args = self.init_args if self.need_worker_index: @@ -496,7 +507,7 @@ def run(self, recording_slices=None): if self.need_worker_index: worker_dict["worker_index"] = worker_index - for segment_index, frame_start, frame_stop in recording_slices: + for segment_index, frame_start, frame_stop in slices: res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) @@ -504,7 +515,7 @@ def run(self, recording_slices=None): self.gather_func(res) else: - n_jobs = min(self.n_jobs, len(recording_slices)) + n_jobs = min(self.n_jobs, len(slices)) if self.pool_engine == "process": @@ -534,13 +545,13 @@ def run(self, recording_slices=None): array_pid, ), ) as executor: - results = executor.map(process_function_wrapper, recording_slices) + results = executor.map(process_function_wrapper, slices) if self.progress_bar: results = tqdm( results, desc=f"{self.job_name} (workers: {n_jobs} processes {self.mp_context})", - total=len(recording_slices), + total=len(slices), ) for res in results: @@ -559,7 +570,7 @@ def run(self, recording_slices=None): if self.progress_bar: # here the tqdm threading do not work (maybe collision) so we need to create a pbar # before thread spawning - pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(recording_slices)) + pbar = tqdm(desc=f"{self.job_name} (workers: {n_jobs} threads)", total=len(slices)) if 
self.need_worker_index: lock = threading.Lock() @@ -580,8 +591,8 @@ def run(self, recording_slices=None): ), ) as executor: - recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices] - results = executor.map(thread_function_wrapper, recording_slices2) + slices2 = [(thread_local_data,) + tuple(args) for args in slices] + results = executor.map(thread_function_wrapper, slices2) for res in results: if self.progress_bar: diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 43cdd30c87..a91a4909b0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -11,8 +11,9 @@ import numpy as np from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype +from spikeinterface.core.chunkable import ChunkableMixin from spikeinterface.core import BaseRecording, get_chunk_with_margin -from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import ChunkExecutor, fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core import get_channel_distances @@ -24,7 +25,7 @@ class PipelineNode: def __init__( self, - recording: BaseRecording, + chunkable: ChunkableMixin, return_output: bool | tuple[bool] = True, parents: list[Type["PipelineNode"]] | None = None, ): @@ -36,8 +37,8 @@ def __init__( Parameters ---------- - recording : BaseRecording - The recording object. + chunkable : ChunkableMixin + The chunkable object. return_output : bool or tuple[bool], default: True Whether or not the output of the node is returned by the pipeline. When a Node have several toutputs then this can be a tuple of bool @@ -45,7 +46,7 @@ def __init__( Pass parents nodes to perform a previous computation. 
""" - self.recording = recording + self.chunkable = chunkable self.return_output = return_output if isinstance(parents, str): # only one parents is allowed @@ -54,14 +55,14 @@ def __init__( self._kwargs = dict() - def get_trace_margin(self): + def get_data_margin(self): # can optionaly be overwritten return 0 def get_dtype(self): raise NotImplementedError - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *args): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, *args): raise NotImplementedError @@ -76,7 +77,7 @@ class PeakSource(PipelineNode): # between processes or threads need_first_call_before_pipeline = False - def get_trace_margin(self): + def get_data_margin(self): raise NotImplementedError def get_dtype(self): @@ -93,7 +94,7 @@ def get_peak_slice( def _first_call_before_pipeline(self): # see need_first_call_before_pipeline = True - margin = self.get_trace_margin() + margin = self.get_data_margin() traces = self.recording.get_traces(start_frame=0, end_frame=margin * 2 + 1, segment_index=0) self.compute(traces, 0, margin * 2 + 1, 0, margin) @@ -116,7 +117,7 @@ def __init__(self, recording, peaks): i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) - def get_trace_margin(self): + def get_data_margin(self): return 0 def get_dtype(self): @@ -128,7 +129,7 @@ def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) return i0, i1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peak_slice): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] @@ -154,6 +155,9 @@ class SpikeRetriever(PeakSource): * compute_spike_amplitudes() * compute_principal_components() + Parameters + 
---------- + sorting : BaseSorting The sorting object. recording : BaseRecording @@ -208,7 +212,7 @@ def __init__( i0, i1 = np.searchsorted(self.peaks["segment_index"], [segment_index, segment_index + 1]) self.segment_slices.append(slice(i0, i1)) - def get_trace_margin(self): + def get_data_margin(self): return 0 def get_dtype(self): @@ -225,7 +229,7 @@ def get_peak_slice(self, segment_index, start_frame, end_frame, max_margin): i0, i1 = np.searchsorted(peaks_in_segment["sample_index"], [start_frame, end_frame]) return i0, i1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peak_slice): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peak_slice): # get local peaks sl = self.segment_slices[segment_index] peaks_in_segment = self.peaks[sl] @@ -242,14 +246,14 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea local_peaks["in_margin"][:] = False mask = local_peaks["sample_index"] < max_margin local_peaks["in_margin"][mask] = True - mask = local_peaks["sample_index"] >= traces.shape[0] - max_margin + mask = local_peaks["sample_index"] >= chunk.shape[0] - max_margin local_peaks["in_margin"][mask] = True if not self.channel_from_template: # handle channel spike per spike for i, peak in enumerate(local_peaks): chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) - sparse_wfs = traces[peak["sample_index"], chans] + sparse_wfs = chunk[peak["sample_index"], chans] if self.peak_sign == "neg": local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] elif self.peak_sign == "pos": @@ -259,7 +263,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea # handle amplitude for i, peak in enumerate(local_peaks): - local_peaks["amplitude"][i] = traces[peak["sample_index"], peak["channel_index"]] + local_peaks["amplitude"][i] = chunk[peak["sample_index"], peak["channel_index"]] return (local_peaks,) @@ -311,7 +315,8 @@ def __init__( Whether or 
not the output of the node is returned by the pipeline """ - PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output) + PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) + self.recording = recording self.ms_before = ms_before self.ms_after = ms_after self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0) @@ -350,18 +355,18 @@ def __init__( WaveformsNode.__init__( self, - recording=recording, + recording, parents=parents, ms_before=ms_before, ms_after=ms_after, return_output=return_output, ) - def get_trace_margin(self): + def get_data_margin(self): return max(self.nbefore, self.nafter) - def compute(self, traces, peaks): - waveforms = traces[peaks["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] + def compute(self, chunk, peaks): + waveforms = chunk[peaks["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] return waveforms @@ -407,7 +412,7 @@ def __init__( """ WaveformsNode.__init__( self, - recording=recording, + recording, parents=parents, ms_before=ms_before, ms_after=ms_after, @@ -425,15 +430,15 @@ def __init__( self.neighbours_mask = self.channel_distance <= radius_um self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1)) - def get_trace_margin(self): + def get_data_margin(self): return max(self.nbefore, self.nafter) - def compute(self, traces, peaks): - sparse_wfs = np.zeros((peaks.shape[0], self.nbefore + self.nafter, self.max_num_chans), dtype=traces.dtype) + def compute(self, chunk, peaks): + sparse_wfs = np.zeros((peaks.shape[0], self.nbefore + self.nafter, self.max_num_chans), dtype=chunk.dtype) for i, peak in enumerate(peaks): (chans,) = np.nonzero(self.neighbours_mask[peak["channel_index"]]) - sparse_wfs[i, :, : len(chans)] = traces[ + sparse_wfs[i, :, : len(chans)] = chunk[ peak["sample_index"] - self.nbefore : peak["sample_index"] + self.nafter, : ][:, chans] @@ -500,12 +505,11 @@ def check_graph(nodes, 
check_for_peak_source=True): Check that node list is orderd in a good (parents are before children) """ - if check_for_peak_source: - node0 = nodes[0] - if not isinstance(node0, PeakSource): - raise ValueError( - "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" - ) + node0 = nodes[0] + if not isinstance(node0, PeakSource) and check_for_peak_source: + raise ValueError( + "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" @@ -521,19 +525,19 @@ def check_graph(nodes, check_for_peak_source=True): def run_node_pipeline( - recording, - nodes, - job_kwargs, - job_name="pipeline", - gather_mode="memory", - gather_kwargs={}, - squeeze_output=True, - folder=None, - names=None, - verbose=False, - skip_after_n_peaks=None, - recording_slices=None, - check_for_peak_source=True, + chunkable: ChunkableMixin, + nodes: list[PipelineNode], + job_kwargs: dict, + job_name: str = "pipeline", + gather_mode: str = "memory", + gather_kwargs: dict = {}, + squeeze_output: bool = True, + folder: str | None = None, + names: list[str] | None = None, + verbose: bool = False, + skip_after_n_peaks: int | None = None, + slices: list[tuple] | None = None, + check_for_peak_source: bool = False, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -561,11 +565,12 @@ def run_node_pipeline( Parameters ---------- - - recording: Recording - + chunkable: ChunkableMixin + The chunkable object to run the pipeline on. This is typically a recording but it can be anything that have the + same interface for getting chunks with margin. nodes: a list of PipelineNode - + The list of nodes to run in the pipeline. The order of the nodes is important as it defines + the order of computation. 
job_kwargs: dict The classical job_kwargs job_name : str @@ -585,12 +590,12 @@ def run_node_pipeline( skip_after_n_peaks : None | int Skip the computation after n_peaks. This is not an exact because internally this skip is done per worker in average. - recording_slices : None | list[tuple] + slices : None | list[tuple] Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. - check_for_peak_source : bool, default True - Whether to check that the first node is a PeakSource (PeakDetector or PeakRetriever or + check_for_peak_source : bool, default False + Whether to check the graph of PeakSource nodes. Returns ------- @@ -598,7 +603,6 @@ def run_node_pipeline( a tuple of vector for the output of nodes having return_output=True. If squeeze_output=True and only one output then directly np.array. """ - check_graph(nodes, check_for_peak_source=check_for_peak_source) job_kwargs = fix_job_kwargs(job_kwargs) @@ -621,10 +625,10 @@ def run_node_pipeline( # See need_first_call_before_pipeline : this trigger numba compilation before the run node0._first_call_before_pipeline() - init_args = (recording, nodes, skip_after_n_peaks_per_worker) + init_args = (chunkable, nodes, skip_after_n_peaks_per_worker) - processor = ChunkRecordingExecutor( - recording, + processor = ChunkExecutor( + chunkable, _compute_peak_pipeline_chunk, _init_peak_pipeline, init_args, @@ -634,30 +638,30 @@ def run_node_pipeline( **job_kwargs, ) - processor.run(recording_slices=recording_slices) + processor.run(slices=slices) outs = gather_func.finalize_buffers(squeeze_output=squeeze_output) return outs -def _init_peak_pipeline(recording, nodes, skip_after_n_peaks_per_worker): +def _init_peak_pipeline(chunkable, nodes, skip_after_n_peaks_per_worker): # create a local dict per worker worker_ctx = {} - worker_ctx["recording"] = 
recording + worker_ctx["chunkable"] = chunkable worker_ctx["nodes"] = nodes - worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes) + worker_ctx["max_margin"] = max(node.get_data_margin() for node in nodes) worker_ctx["skip_after_n_peaks_per_worker"] = skip_after_n_peaks_per_worker worker_ctx["num_peaks"] = 0 return worker_ctx def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_ctx): - recording = worker_ctx["recording"] + chunkable = worker_ctx["chunkable"] max_margin = worker_ctx["max_margin"] nodes = worker_ctx["nodes"] skip_after_n_peaks_per_worker = worker_ctx["skip_after_n_peaks_per_worker"] - recording_segment = recording.segments[segment_index] + chunkable_segment = chunkable.segments[segment_index] retrievers = find_parents_of_type(nodes, (SpikeRetriever, PeakRetriever)) # get peak slices once for all retrievers peak_slice_by_retriever = {} @@ -678,7 +682,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c if load_trace_and_compute: traces_chunk, left_margin, right_margin = get_chunk_with_margin( - recording_segment, start_frame, end_frame, None, max_margin, add_zeros=True + chunkable_segment, start_frame, end_frame, None, max_margin, add_zeros=True ) # compute the graph pipeline_outputs = {} @@ -693,7 +697,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c # to handle compatibility peak detector is a special case # with specific margin # TODO later when in master: change this later - extra_margin = max_margin - node.get_trace_margin() + extra_margin = max_margin - node.get_data_margin() if extra_margin: trace_detection = traces_chunk[extra_margin:-extra_margin] else: diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 48eb2d7fd4..084cac6aba 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -3,7 +3,6 @@ import warnings from pathlib 
import Path import os -import mmap import tqdm import numpy.typing as npt @@ -12,15 +11,19 @@ from .core_tools import add_suffix, make_shared_array from .job_tools import ( ensure_chunk_size, - ensure_n_jobs, divide_segment_into_chunks, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, _shared_job_kwargs_doc, - chunk_duration_to_chunk_size, split_job_kwargs, ) +from .chunkable_tools import get_random_sample_slices, get_chunks, get_chunk_with_margin + +# for back-compatibility imports +from .chunkable_tools import write_binary as write_binary_recording +from .chunkable_tools import write_memory as write_memory_recording + def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): """ @@ -52,124 +55,11 @@ def read_binary_recording(file, num_channels, dtype, time_axis=0, offset=0): return samples -# used by write_binary_recording + ChunkRecordingExecutor -def _init_binary_worker(recording, file_path_dict, dtype, byte_offest): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["byte_offset"] = byte_offest - worker_ctx["dtype"] = np.dtype(dtype) - - file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()} - worker_ctx["file_dict"] = file_dict - - return worker_ctx - - -def write_binary_recording( - recording: "BaseRecording", - file_paths: list[Path | str] | Path | str, - dtype: npt.DTypeLike | None = None, - add_file_extension: bool = True, - byte_offset: int = 0, - verbose: bool = False, - **job_kwargs, -): - """ - Save the trace of a recording extractor in several binary .dat format. - - Note : - time_axis is always 0 (contrary to previous version. - to get time_axis=1 (which is a bad idea) use `write_binary_recording_file_handle()` - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - file_path : str or list[str] - The path to the file. 
- dtype : dtype or None, default: None - Type of the saved data - add_file_extension, bool, default: True - If True, and the file path does not end in "raw", "bin", or "dat" then "raw" is added as an extension. - byte_offset : int, default: 0 - Offset in bytes for the binary file (e.g. to write a header). This is useful in case you want to append data - to an existing file where you wrote a header or other data before. - verbose : bool - This is the verbosity of the ChunkRecordingExecutor - {} - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - file_path_list = [file_paths] if not isinstance(file_paths, list) else file_paths - num_segments = recording.get_num_segments() - if len(file_path_list) != num_segments: - raise ValueError("'file_paths' must be a list of the same size as the number of segments in the recording") - - file_path_list = [Path(file_path) for file_path in file_path_list] - if add_file_extension: - file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] - - dtype = dtype if dtype is not None else recording.get_dtype() - - dtype_size_bytes = np.dtype(dtype).itemsize - num_channels = recording.get_num_channels() - - file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} - for segment_index, file_path in file_path_dict.items(): - num_frames = recording.get_num_frames(segment_index=segment_index) - data_size_bytes = dtype_size_bytes * num_frames * num_channels - file_size_bytes = data_size_bytes + byte_offset - - # Create an empty file with file_size_bytes - with open(file_path, "wb+") as file: - # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) - file.seek(file_size_bytes - 1) - file.write(b"\0") - - assert Path(file_path).is_file() - - # use executor (loop or workers) - func = _write_binary_chunk - init_func = _init_binary_worker - init_args = (recording, file_path_dict, dtype, byte_offset) - executor = 
ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="write_binary_recording", verbose=verbose, **job_kwargs - ) - executor.run() - - -# used by write_binary_recording + ChunkRecordingExecutor -def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - byte_offset = worker_ctx["byte_offset"] - file = worker_ctx["file_dict"][segment_index] - - num_channels = recording.get_num_channels() - dtype_size_bytes = np.dtype(dtype).itemsize - - # Calculate byte offsets for the start frames relative to the entire recording - start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, order="c", copy=False) - - file.seek(start_byte) - file.write(traces.data) - # flush is important!! - file.flush() - - -write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) - - -def write_binary_recording_file_handle( +def write_binary_file_handle( recording, file_handle=None, time_axis=0, dtype=None, byte_offset=0, verbose=False, **job_kwargs ): """ - Old variant version of write_binary_recording with one file handle. + Old variant version of write_binary with one file handle. Can be useful in some case ??? Not used anymore at the moment. 
@@ -209,115 +99,6 @@ def write_binary_recording_file_handle( file_handle.write(traces.tobytes()) -# used by write_memory_recording -def _init_memory_worker(recording, arrays, shm_names, shapes, dtype): - # create a local dict per worker - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["dtype"] = np.dtype(dtype) - - if arrays is None: - # create it from share memory name - from multiprocessing.shared_memory import SharedMemory - - arrays = [] - # keep shm alive - worker_ctx["shms"] = [] - for i in range(len(shm_names)): - shm = SharedMemory(shm_names[i]) - worker_ctx["shms"].append(shm) - arr = np.ndarray(shape=shapes[i], dtype=dtype, buffer=shm.buf) - arrays.append(arr) - - worker_ctx["arrays"] = arrays - - return worker_ctx - - -# used by write_memory_recording -def _write_memory_chunk(segment_index, start_frame, end_frame, worker_ctx): - # recover variables of the worker - recording = worker_ctx["recording"] - dtype = worker_ctx["dtype"] - arr = worker_ctx["arrays"][segment_index] - - # apply function - traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - traces = traces.astype(dtype, copy=False) - arr[start_frame:end_frame, :] = traces - - -def write_memory_recording(recording, dtype=None, verbose=False, buffer_type="auto", **job_kwargs): - """ - Save the traces into numpy arrays (memory). 
- try to use the SharedMemory introduce in py3.8 if n_jobs > 1 - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor object to be saved in .dat format - dtype : dtype, default: None - Type of the saved data - verbose : bool, default: False - If True, output is verbose (when chunks are used) - buffer_type : "auto" | "numpy" | "sharedmem" - {} - - Returns - --------- - arrays : one array per segment - """ - job_kwargs = fix_job_kwargs(job_kwargs) - - if dtype is None: - dtype = recording.get_dtype() - - # create sharedmmep - arrays = [] - shm_names = [] - shms = [] - shapes = [] - - n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) - if buffer_type == "auto": - if n_jobs > 1: - buffer_type = "sharedmem" - else: - buffer_type = "numpy" - - for segment_index in range(recording.get_num_segments()): - num_frames = recording.get_num_samples(segment_index) - num_channels = recording.get_num_channels() - shape = (num_frames, num_channels) - shapes.append(shape) - if buffer_type == "sharedmem": - arr, shm = make_shared_array(shape, dtype) - shm_names.append(shm.name) - shms.append(shm) - else: - arr = np.zeros(shape, dtype=dtype) - shms.append(None) - arrays.append(arr) - - # use executor (loop or workers) - func = _write_memory_chunk - init_func = _init_memory_worker - if n_jobs > 1: - init_args = (recording, None, shm_names, shapes, dtype) - else: - init_args = (recording, arrays, None, None, dtype) - - executor = ChunkRecordingExecutor( - recording, func, init_func, init_args, verbose=verbose, job_name="write_memory_recording", **job_kwargs - ) - executor.run() - - return arrays, shms - - -write_memory_recording.__doc__ = write_memory_recording.__doc__.format(_shared_job_kwargs_doc) - - def write_to_h5_dataset_format( recording, dataset_path, @@ -458,101 +239,16 @@ def write_to_h5_dataset_format( return save_path -def get_random_recording_slices( - recording, - method="full_random", - num_chunks_per_segment=20, - 
chunk_duration="500ms", - chunk_size=None, - margin_frames=0, - seed=None, -): - """ - Get random slice of a recording across segments. - - This is used for instance in get_noise_levels() and get_random_data_chunks() to estimate noise on traces. - - Parameters - ---------- - recording : BaseRecording - The recording to get random chunks from - method : "full_random" - The method used to get random slices. - * "full_random" : legacy method, used until version 0.101.0, there is no constrain on slices - and they can overlap. - num_chunks_per_segment : int, default: 20 - Number of chunks per segment - chunk_duration : str | float | None, default "500ms" - The duration of each chunk in 's' or 'ms' - chunk_size : int | None - Size of a chunk in number of frames. This is used only if chunk_duration is None. - This is kept for backward compatibility, you should prefer 'chunk_duration=500ms' instead. - concatenated : bool, default: True - If True chunk are concatenated along time axis - seed : int, default: None - Random seed - margin_frames : int, default: 0 - Margin in number of frames to avoid edge effects - - Returns - ------- - chunk_list : np.array - Array of concatenate chunks per segment - - - """ - # TODO: if segment have differents length make another sampling that dependant on the length of the segment - # Should be done by changing kwargs with total_num_chunks=XXX and total_duration=YYYY - # And randomize the number of chunk per segment weighted by segment duration - - if method == "full_random": - if chunk_size is None: - if chunk_duration is not None: - chunk_size = chunk_duration_to_chunk_size(chunk_duration, recording) - else: - raise ValueError("get_random_recording_slices need chunk_size or chunk_duration") - - # check chunk size - num_segments = recording.get_num_segments() - for segment_index in range(num_segments): - chunk_size_limit = recording.get_num_frames(segment_index) - 2 * margin_frames - if chunk_size > chunk_size_limit: - chunk_size = 
chunk_size_limit - 1 - warnings.warn( - f"chunk_size is greater than the number " - f"of samples for segment index {segment_index}. " - f"Using {chunk_size}." - ) - rng = np.random.default_rng(seed) - recording_slices = [] - low = margin_frames - size = num_chunks_per_segment - for segment_index in range(num_segments): - num_frames = recording.get_num_frames(segment_index) - high = num_frames - chunk_size - margin_frames - # here we set endpoint to True, because the this represents the start of the - # chunk, and should be inclusive - random_starts = rng.integers(low=low, high=high, size=size, endpoint=True) - random_starts = np.sort(random_starts) - recording_slices += [ - (segment_index, start_frame, (start_frame + chunk_size)) for start_frame in random_starts - ] - else: - raise ValueError(f"get_random_recording_slices : wrong method {method}") - - return recording_slices - - def get_random_data_chunks( recording, return_scaled=None, return_in_uV=False, concatenated=True, **random_slices_kwargs ): """ Extract random chunks across segments. - Internally, it uses `get_random_recording_slices()` and retrieves the traces chunk as a list + Internally, it uses `get_random_sample_slices()` and retrieves the traces chunk as a list or a concatenated unique array. - Please read `get_random_recording_slices()` for more details on parameters. + Please read `get_random_sample_slices()` for more details on parameters. Parameters @@ -569,7 +265,7 @@ def get_random_data_chunks( concatenated : bool, default: True If True chunk are concatenated along time axis **random_slices_kwargs : dict - Options transmited to get_random_recording_slices(), please read documentation from this + Options transmited to get_random_sample_slices(), please read documentation from this function for more details. 
Returns @@ -586,22 +282,12 @@ def get_random_data_chunks( ) return_in_uV = return_scaled - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) - - chunk_list = [] - for segment_index, start_frame, end_frame in recording_slices: - traces_chunk = recording.get_traces( - start_frame=start_frame, - end_frame=end_frame, - segment_index=segment_index, - return_in_uV=return_in_uV, - ) - chunk_list.append(traces_chunk) - - if concatenated: - return np.concatenate(chunk_list, axis=0) - else: - return chunk_list + return get_chunks( + recording, + concatenated=concatenated, + get_data_kwargs=dict(return_in_uV=return_in_uV), + **random_slices_kwargs, + ) def get_channel_distances(recording): @@ -718,7 +404,7 @@ def get_noise_levels( force_recompute : bool If True, noise levels are recomputed even if they are already stored in the recording extractor random_slices_kwargs : dict - Options transmited to get_random_recording_slices(), please read documentation from this + Options transmitted to get_random_sample_slices(), please read documentation from this function for more details. {} @@ -753,7 +439,7 @@ def get_noise_levels( msg = ( "get_noise_levels(recording, num_chunks_per_segment=20) is deprecated\n" "Now, you need to use get_noise_levels(recording, random_slices_kwargs=dict(num_chunks_per_segment=20, chunk_size=1000))\n" - "Please read get_random_recording_slices() documentation for more options." + "Please read get_random_sample_slices() documentation for more options." 
) # if the user use both the old and the new behavior then an error is raised assert len(random_slices_kwargs) == 0, msg @@ -762,7 +448,7 @@ def get_noise_levels( if "chunk_size" in job_kwargs: random_slices_kwargs["chunk_size"] = job_kwargs["chunk_size"] - recording_slices = get_random_recording_slices(recording, **random_slices_kwargs) + slices = get_random_sample_slices(recording, **random_slices_kwargs) noise_levels_chunks = [] @@ -772,7 +458,7 @@ def append_noise_chunk(res): func = _noise_level_chunk init_func = _noise_level_chunk_init init_args = (recording, return_in_uV, method) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, @@ -782,7 +468,7 @@ def append_noise_chunk(res): gather_func=append_noise_chunk, **job_kwargs, ) - executor.run(recording_slices=recording_slices) + executor.run(slices=slices) noise_levels_chunks = np.stack(noise_levels_chunks) noise_levels = np.mean(noise_levels_chunks, axis=0) @@ -795,130 +481,6 @@ def append_noise_chunk(res): get_noise_levels.__doc__ = get_noise_levels.__doc__.format(_shared_job_kwargs_doc) -def get_chunk_with_margin( - rec_segment, - start_frame, - end_frame, - channel_indices, - margin, - add_zeros=False, - add_reflect_padding=False, - window_on_margin=False, - dtype=None, -): - """ - Helper to get chunk with margin - - The margin is extracted from the recording when possible. If - at the edge of the recording, no margin is used unless one - of `add_zeros` or `add_reflect_padding` is True. In the first - case zero padding is used, in the second case np.pad is called - with mod="reflect". 
- """ - length = int(rec_segment.get_num_samples()) - - if channel_indices is None: - channel_indices = slice(None) - - if not (add_zeros or add_reflect_padding): - if window_on_margin and not add_zeros: - raise ValueError("window_on_margin requires add_zeros=True") - - if start_frame is None: - left_margin = 0 - start_frame = 0 - elif start_frame < margin: - left_margin = start_frame - else: - left_margin = margin - - if end_frame is None: - right_margin = 0 - end_frame = length - elif end_frame > (length - margin): - right_margin = length - end_frame - else: - right_margin = margin - - traces_chunk = rec_segment.get_traces( - start_frame - left_margin, - end_frame + right_margin, - channel_indices, - ) - - else: - # either add_zeros or reflect_padding - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = length - - chunk_size = end_frame - start_frame - full_size = chunk_size + 2 * margin - - if start_frame < margin: - start_frame2 = 0 - left_pad = margin - start_frame - else: - start_frame2 = start_frame - margin - left_pad = 0 - - if end_frame > (length - margin): - end_frame2 = length - right_pad = end_frame + margin - length - else: - end_frame2 = end_frame + margin - right_pad = 0 - - traces_chunk = rec_segment.get_traces(start_frame2, end_frame2, channel_indices) - - if dtype is not None or window_on_margin or left_pad > 0 or right_pad > 0: - need_copy = True - else: - need_copy = False - - left_margin = margin - right_margin = margin - - if need_copy: - if dtype is None: - dtype = traces_chunk.dtype - - left_margin = margin - if end_frame < (length + margin): - right_margin = margin - else: - right_margin = end_frame + margin - length - - if add_zeros: - traces_chunk2 = np.zeros((full_size, traces_chunk.shape[1]), dtype=dtype) - i0 = left_pad - i1 = left_pad + traces_chunk.shape[0] - traces_chunk2[i0:i1, :] = traces_chunk - if window_on_margin: - # apply inplace taper on border - taper = (1 - np.cos(np.arange(margin) / margin * 
np.pi)) / 2 - taper = taper[:, np.newaxis] - traces_chunk2[:margin] *= taper - traces_chunk2[-margin:] *= taper[::-1] - # enforce non writable when original was not - # (this help numba to have the same signature and not compile twice) - traces_chunk2.flags.writeable = traces_chunk.flags.writeable - traces_chunk = traces_chunk2 - elif add_reflect_padding: - # in this case, we don't want to taper - traces_chunk = np.pad( - traces_chunk.astype(dtype, copy=False), - [(left_pad, right_pad), (0, 0)], - mode="reflect", - ) - else: - # we need a copy to change the dtype - traces_chunk = np.asarray(traces_chunk, dtype=dtype) - - return traces_chunk, left_margin, right_margin - - def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), flip=False): """ Order channels by depth, by first ordering the x-axis, and then the y-axis. diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 7a2accc887..7e635922db 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -9,10 +9,10 @@ divide_segment_into_chunks, ensure_n_jobs, ensure_chunk_size, - ChunkRecordingExecutor, + ChunkExecutor, fix_job_kwargs, split_job_kwargs, - divide_recording_into_chunks, + divide_chunkable_into_chunks, ) @@ -71,7 +71,7 @@ def test_ensure_chunk_size(): # Test edge case to define single chunk for n_jobs=1 chunk_size = ensure_chunk_size(recording, n_jobs=1, chunk_size=None) - chunks = divide_recording_into_chunks(recording, chunk_size) + chunks = divide_chunkable_into_chunks(recording, chunk_size) assert len(chunks) == recording.get_num_segments() for chunk in chunks: segment_index, start_frame, end_frame = chunk @@ -96,13 +96,13 @@ def init_func(arg1, arg2, arg3): return worker_dict -def test_ChunkRecordingExecutor(): +def test_ChunkExecutor(): recording = generate_recording(num_channels=2) init_args = "a", 120, "yep" # no chunk - processor = 
ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, verbose=True, progress_bar=False, n_jobs=1, chunk_size=None ) processor.run() @@ -113,7 +113,7 @@ def gathering_result(res): pass # chunk + loop + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -139,7 +139,7 @@ def __call__(self, res): gathering_func2 = GatherClass() # process + gather_func - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -153,12 +153,12 @@ def __call__(self, res): job_name="job_name", ) processor.run() - num_chunks = len(divide_recording_into_chunks(recording, processor.chunk_size)) + num_chunks = len(divide_chunkable_into_chunks(recording, processor.chunk_size)) assert gathering_func2.pos == num_chunks # process spawn - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -174,7 +174,7 @@ def __call__(self, res): processor.run() # thread - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, @@ -258,7 +258,7 @@ def test_worker_index(): # making this 2 times ensure to test that global variables are correctly reset for pool_engine in ("process", "thread"): # print(pool_engine) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func2, init_func2, @@ -322,7 +322,7 @@ def test_get_best_job_kwargs(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() + # test_ChunkExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() test_worker_index() diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4f8e600a3f..90e430a0db 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -5,7 +5,7 @@ from spikeinterface import create_sorting_analyzer, 
get_template_extremum_channel, generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype -from spikeinterface.core.job_tools import divide_recording_into_chunks +from spikeinterface.core.job_tools import divide_chunkable_into_chunks # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( @@ -27,12 +27,12 @@ def __init__(self, recording, parents=None, return_output=True, param0=5.5): def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): amps = np.zeros(peaks.size, dtype=self._dtype) amps["abs_amplitude"] = np.abs(peaks["amplitude"]) return amps - def get_trace_margin(self): + def get_data_margin(self): return 5 @@ -44,7 +44,7 @@ def __init__(self, recording, return_output=True, parents=None): def get_dtype(self): return np.dtype("float32") - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): kernel = np.array([0.1, 0.8, 0.1]) denoised_waveforms = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=waveforms) return denoised_waveforms @@ -57,7 +57,7 @@ def __init__(self, recording, return_output=True, parents=None): def get_dtype(self): return np.dtype("float32") - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): rms_by_channels = np.sum(waveforms**2, axis=1) return rms_by_channels @@ -220,11 +220,9 @@ def test_skip_after_n_peaks_and_recording_slices(): assert some_amplitudes.size < spikes.size # slices : 1 every 4 - recording_slices = divide_recording_into_chunks(recording, 10_000) + recording_slices = divide_chunkable_into_chunks(recording, 10_000) recording_slices = recording_slices[::4] - some_amplitudes = run_node_pipeline( - recording, nodes, job_kwargs, gather_mode="memory", recording_slices=recording_slices - ) + some_amplitudes = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory", 
slices=recording_slices) tolerance = 1.2 assert some_amplitudes.size < (spikes.size // 4) * tolerance diff --git a/src/spikeinterface/core/tests/test_recording_tools.py b/src/spikeinterface/core/tests/test_recording_tools.py index 405a2ecccf..7a327ea44f 100644 --- a/src/spikeinterface/core/tests/test_recording_tools.py +++ b/src/spikeinterface/core/tests/test_recording_tools.py @@ -11,7 +11,6 @@ from spikeinterface.core.recording_tools import ( write_binary_recording, write_memory_recording, - get_random_recording_slices, get_random_data_chunks, get_chunk_with_margin, get_closest_channels, @@ -168,17 +167,6 @@ def test_write_memory_recording(): shm.unlink() -def test_get_random_recording_slices(): - rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) - rec_slices = get_random_recording_slices( - rec, method="full_random", num_chunks_per_segment=20, chunk_duration="500ms", margin_frames=0, seed=0 - ) - assert len(rec_slices) == 40 - for seg_ind, start, stop in rec_slices: - assert stop - start == 500 - assert seg_ind in (0, 1) - - def test_get_random_data_chunks(): rec = generate_recording(num_channels=1, sampling_frequency=1000.0, durations=[10.0, 20.0]) chunks = get_random_data_chunks(rec, num_chunks_per_segment=50, chunk_size=500, seed=0) @@ -366,9 +354,8 @@ def test_do_recording_attributes_match(): # test_write_binary_recording(tmp_path) # test_write_memory_recording() - test_get_random_recording_slices() # test_get_random_data_chunks() # test_get_closest_channels() # test_get_noise_levels() # test_get_noise_levels_output() - # test_order_channels_by_depth() + test_order_channels_by_depth() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 58aac7faf2..f69e640efa 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -17,7 +17,7 @@ from spikeinterface.core.baserecording import BaseRecording from .baserecording import 
BaseRecording -from .job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc +from .job_tools import ChunkExecutor, _shared_job_kwargs_doc from .core_tools import make_shared_array from .job_tools import fix_job_kwargs @@ -294,16 +294,14 @@ def distribute_waveforms_to_buffers( ) if job_name is None: job_name = f"extract waveforms {mode} multi buffer" - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs - ) + processor = ChunkExecutor(recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs) processor.run() distribute_waveforms_to_buffers.__doc__ = distribute_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_in_uV, inds_by_unit, mode, sparsity_mask ): @@ -350,7 +348,7 @@ def _init_worker_distribute_buffers( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -563,7 +561,7 @@ def extract_waveforms_to_single_buffer( if job_name is None: job_name = f"extract waveforms {mode} mono buffer" - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs ) processor.run() @@ -620,7 +618,7 @@ def _init_worker_distribute_single_buffer( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] @@ -948,7 +946,7 @@ def estimate_templates_with_accumulator( if job_name is None: job_name = "estimate_templates_with_accumulator" - processor = 
ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -1035,7 +1033,7 @@ def _init_worker_estimate_templates( return worker_dict -# used by ChunkRecordingExecutor +# used by ChunkExecutor def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker recording = worker_dict["recording"] diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 8f266e0123..b4be717ed0 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -560,7 +560,7 @@ def add_traces_to_zarr( from .job_tools import ( ensure_chunk_size, fix_job_kwargs, - ChunkRecordingExecutor, + ChunkExecutor, ) assert dataset_paths is not None, "Provide 'file_path'" @@ -597,13 +597,13 @@ def add_traces_to_zarr( func = _write_zarr_chunk init_func = _init_zarr_worker init_args = (recording, zarr_datasets, dtype) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, init_args, verbose=verbose, job_name="write_zarr_recording", **job_kwargs ) executor.run() -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _init_zarr_worker(recording, zarr_datasets, dtype): import zarr @@ -616,7 +616,7 @@ def _init_zarr_worker(recording, zarr_datasets, dtype): return worker_ctx -# used by write_zarr_recording + ChunkRecordingExecutor +# used by write_zarr_recording + ChunkExecutor def _write_zarr_chunk(segment_index, start_frame, end_frame, worker_ctx): import gc diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 8f18536daa..faef0d9560 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -7,7 +7,7 @@ import numpy as np from spikeinterface.core import 
SortingAnalyzer, BaseRecording, get_random_data_chunks -from spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc +from spikeinterface.core.job_tools import fix_job_kwargs, ChunkExecutor, _shared_job_kwargs_doc from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -258,7 +258,7 @@ def compute_rms( func = _compute_rms_chunk init_func = _init_rms_worker init_args = (recording,) - executor = ChunkRecordingExecutor( + executor = ChunkExecutor( recording, func, init_func, diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 310be8cceb..b78e60e94e 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -230,9 +230,10 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): from scipy.stats import linregress + traces = chunk # scale traces with margin to match scaling of templates if self._gains is not None: traces = traces.astype("float32") * self._gains + self._offsets @@ -330,7 +331,7 @@ def compute(self, traces, peaks): # TODO: switch to collision mask and return that (to use concatenation) return (scalings, spike_collision_mask) - def get_trace_margin(self): + def get_data_margin(self): return self._margin diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index bb48a08e64..bce6c8e6a4 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -11,7 +11,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs +from 
spikeinterface.core.job_tools import ChunkExecutor, _shared_job_kwargs_doc, fix_job_kwargs from spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms @@ -412,7 +412,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): unit_channels, pca_model, ) - processor = ChunkRecordingExecutor( + processor = ChunkExecutor( recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs ) processor.run() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 0495e2c56e..c837e52f58 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -90,11 +90,13 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks): + def compute(self, chunk, peaks): sample_indices = peaks["sample_index"].copy() unit_index = peaks["unit_index"] chan_inds = peaks["channel_index"] + traces = chunk + # apply shifts per spike sample_indices += self._peak_shifts[unit_index] @@ -110,5 +112,5 @@ def compute(self, traces, peaks): return amplitudes - def get_trace_margin(self): + def get_data_margin(self): return self._margin diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 92b07b8f35..d40e043e20 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -114,7 +114,7 @@ def __init__( else: self.diff_threshold_unscaled = None - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: """Return the number of margin samples required on each side of a chunk.""" return 0 @@ -326,7 +326,7 @@ def __init__( # internal dtype self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: 
"""Return the number of margin samples required on each side of a chunk.""" return 0 diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index b262d36dab..b091734c1c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -26,7 +26,7 @@ def find_clusters_from_peaks( verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py index 88a6522148..483d6f867d 100644 --- a/src/spikeinterface/sortingcomponents/matching/base.py +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -20,21 +20,22 @@ def __init__(self, recording, templates, return_output=True): templates, Templates ), f"The templates supplied is of type {type(templates)} and must be a Templates" self.templates = templates + self.recording = recording PeakDetector.__init__(self, recording, return_output=return_output, parents=None) def get_dtype(self): return np.dtype(_base_matching_dtype) - def get_trace_margin(self): + def get_data_margin(self): raise NotImplementedError - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - spikes = self.compute_matching(traces, start_frame, end_frame, segment_index) + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + spikes = self.compute_matching(chunk, start_frame, end_frame, segment_index) spikes["segment_index"] = segment_index - margin = self.get_trace_margin() + margin = self.get_data_margin() if margin > 0 and spikes.size > 0: - keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (traces.shape[0] - margin)) + keep = (spikes["sample_index"] >= margin) & (spikes["sample_index"] < (chunk.shape[0] - margin)) spikes = 
spikes[keep] # node pipeline need to return a tuple diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 8a19ad458b..96c1d725ba 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -322,7 +322,7 @@ def get_extra_outputs(self): output[key] = getattr(self, key) return output - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -709,7 +709,7 @@ def _prepare_templates(self): self.circus_templates = templates_array - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 116aff416e..2ec8b0a0fa 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -39,7 +39,7 @@ def find_spikes_from_templates( verbose : Bool, default: False If True, output is verbose job_kwargs : dict - Parameters for ChunkRecordingExecutor + Parameters for ChunkExecutor {method_doc} diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3e0eb0b632..15b5327219 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -94,7 +94,7 @@ def __init__( self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i]) self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i]) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -191,7 +191,7 @@ def __init__( 
projected_temporal_templates = self.svd_model.transform(temporal_templates) self.svd_templates = from_temporal_representation(projected_temporal_templates, self.num_channels) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 947eaf391f..f870fed243 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -318,12 +318,12 @@ def __init__( # noise_levels=None, ) - self.detector_margin0 = self.fast_spike_detector.get_trace_margin() - self.detector_margin1 = self.fine_spike_detector.get_trace_margin() if use_fine_detector else 0 + self.detector_margin0 = self.fast_spike_detector.get_data_margin() + self.detector_margin1 = self.fine_spike_detector.get_data_margin() if use_fine_detector else 0 self.peeler_margin = max(self.nbefore, self.nafter) * 2 self.margin = max(self.peeler_margin, self.detector_margin0, self.detector_margin1) - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): @@ -505,7 +505,7 @@ def _find_spikes_one_level(self, traces, spikes_prev_loop, use_fine_detector, le peak_detector = self.fast_spike_detector # print('peak_detector', peak_detector) - detector_margin = peak_detector.get_trace_margin() + detector_margin = peak_detector.get_data_margin() if self.peeler_margin > detector_margin: margin_shift = self.peeler_margin - detector_margin diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 5245f3230d..9e5402b399 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -492,7 
+492,7 @@ def _push_to_torch(self): self.template_data.compressed_templates = (temporal, singular, spatial, temporal_jittered) self.is_pushed = True - def get_trace_margin(self): + def get_data_margin(self): return self.margin def compute_matching(self, traces, start_frame, end_frame, segment_index): diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 7c4c4b166e..a433eeb643 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -223,7 +223,7 @@ def interpolate_motion_on_traces( # here we use a simple np.matmul even if dirft_kernel can be super sparse. # because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing - # in ChunkRecordingExecutor) + # in ChunkExecutor) np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin]) current_start_index = next_start_index diff --git a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py index 732ada21bc..37ccffad3b 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py @@ -56,11 +56,11 @@ def __init__( self.peak_sign = peak_sign self.detect_threshold = detect_threshold - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + traces = chunk traces_center = traces[self.exclude_sweep_size : -self.exclude_sweep_size, :] length = traces_center.shape[0] @@ -162,7 +162,8 @@ def __init__( self.device = device self.return_tensor = return_tensor - def compute(self, traces, start_frame, end_frame, 
segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + traces = chunk peak_sample_ind, peak_chan_ind, peak_amplitude = _torch_detect_peaks( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, None, self.device ) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py index 63bce6d921..4547319934 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py @@ -56,7 +56,7 @@ def __init__( self.num_iterations = num_iterations self.tresholds = tresholds - def get_trace_margin(self) -> int: + def get_data_margin(self) -> int: """ Calculate the maximum trace margin from the internal pipeline. Using the strategy as use by the Node pipeline @@ -68,10 +68,10 @@ def get_trace_margin(self) -> int: The maximum trace margin. """ internal_pipeline = (self.peak_detector_node, self.waveform_extraction_node, self.waveform_denoising_node) - pipeline_margin = (node.get_trace_margin() for node in internal_pipeline if hasattr(node, "get_trace_margin")) + pipeline_margin = [node.get_data_margin() for node in internal_pipeline] return max(pipeline_margin) - def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margin) -> Tuple[np.ndarray, np.ndarray]: + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin) -> Tuple[np.ndarray, np.ndarray]: """ Perform the iterative peak detection, waveform extraction, and denoising. @@ -94,7 +94,7 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi A tuple containing a single ndarray with the detected peaks. 
""" - traces_chunk = np.array(traces_chunk, copy=True, dtype="float32") + traces_chunk = np.array(chunk, copy=True, dtype="float32") local_peaks_list = [] all_waveforms = [] @@ -110,7 +110,7 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi ) (local_peaks,) = self.peak_detector_node.compute( - traces=traces_chunk, + traces_chunk, start_frame=start_frame, end_frame=end_frame, segment_index=segment_index, @@ -124,9 +124,9 @@ def compute(self, traces_chunk, start_frame, end_frame, segment_index, max_margi if local_peaks.size == 0: break - waveforms = self.waveform_extraction_node.compute(traces=traces_chunk, peaks=local_peaks) + waveforms = self.waveform_extraction_node.compute(traces_chunk, peaks=local_peaks) denoised_waveforms = self.waveform_denoising_node.compute( - traces=traces_chunk, peaks=local_peaks, waveforms=waveforms + traces_chunk, peaks=local_peaks, waveforms=waveforms ) self.substract_waveforms_from_traces( diff --git a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py index e0c8a29cfd..ff075092e6 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py @@ -64,6 +64,7 @@ def __init__( assert peak_sign in ("both", "neg", "pos") assert noise_levels is not None + self.recording = recording self.noise_levels = noise_levels self.abs_thresholds = self.noise_levels * detect_threshold @@ -83,13 +84,13 @@ def __init__( self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= radius_um - def get_trace_margin(self): + def get_data_margin(self): # the +1 in the border is important because we need peak in the border return self.exclude_sweep_size + 1 - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, 
segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" - + traces = chunk peak_sample_ind, peak_chan_ind = detect_peaks_numba_locally_exclusive_on_chunk( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, self.neighbours_mask ) @@ -238,12 +239,14 @@ def __init__( for i, neigh in enumerate(self.neighbour_indices_by_chan): self.neighbours_idxs[i, : len(neigh)] = neigh - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): from .by_channel import _torch_detect_peaks + traces = chunk + peak_sample_ind, peak_chan_ind, peak_amplitude = _torch_detect_peaks( traces, self.peak_sign, self.abs_thresholds, self.exclude_sweep_size, self.neighbours_idxs, self.device ) @@ -291,7 +294,9 @@ def __init__( self.abs_thresholds, self.exclude_sweep_size, self.neighbours_mask, self.peak_sign, **opencl_context_kwargs ) - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): + traces = chunk + peak_sample_ind, peak_chan_ind = self.executor.detect_peak(traces) peak_sample_ind += self.exclude_sweep_size peak_amplitude = traces[peak_sample_ind, peak_chan_ind] diff --git a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py index 509c3f76f8..e7d9081929 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py @@ -98,12 +98,12 @@ def __init__( def get_dtype(self): return self._dtype - def get_trace_margin(self): + def get_data_margin(self): return self.exclude_sweep_size + self.conv_margin + 1 - def compute(self, traces, start_frame, end_frame, segment_index, 
max_margin): - + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" + traces = chunk conv_traces = self.get_convolved_traces(traces) conv_traces = conv_traces[:, self.conv_margin : -self.conv_margin] conv_traces = conv_traces.reshape(self.num_z_factors, self.num_templates, conv_traces.shape[1]) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py index e0d2a49b9d..1a5afa9612 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py @@ -152,7 +152,7 @@ def test_iterative_peak_detection(recording, job_kwargs, pca_model_folder_path, return_output=(True, True), ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -197,7 +197,7 @@ def test_iterative_peak_detection_sparse(recording, job_kwargs, pca_model_folder return_output=(True, True), ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -239,7 +239,7 @@ def test_iterative_peak_detection_thresholds(recording, job_kwargs, pca_model_fo tresholds=tresholds, ) - peaks, waveforms = run_node_pipeline(recording=recording, nodes=[iterative_peak_detector], 
job_kwargs=job_kwargs) + peaks, waveforms = run_node_pipeline(recording, nodes=[iterative_peak_detector], job_kwargs=job_kwargs) # Assert there is a field call iteration in structured array peaks assert "iteration" in peaks.dtype.names assert peaks.shape[0] == waveforms.shape[0] @@ -435,15 +435,15 @@ def test_peak_sign_consistency(recording, job_kwargs, detection_class): kwargs["peak_sign"] = "neg" peak_detection_node = detection_class(**kwargs) - negative_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + negative_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) kwargs["peak_sign"] = "pos" peak_detection_node = detection_class(**kwargs) - positive_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + positive_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) kwargs["peak_sign"] = "both" peak_detection_node = detection_class(**kwargs) - all_peaks = run_node_pipeline(recording=recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) + all_peaks = run_node_pipeline(recording, nodes=[peak_detection_node], job_kwargs=job_kwargs) # To account for exclusion of positive peaks that are to close to negative peaks. 
# This should be excluded by the detection method when is exclusive so using peak_sign="both" should diff --git a/src/spikeinterface/sortingcomponents/peak_localization/base.py b/src/spikeinterface/sortingcomponents/peak_localization/base.py index 17853c85aa..5e6a16fe4c 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/base.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/base.py @@ -9,7 +9,7 @@ class LocalizeBase(PipelineNode): def __init__(self, recording, parents, return_output=True, radius_um=75.0): PipelineNode.__init__(self, recording, parents=parents, return_output=return_output) - + self.recording = recording self.radius_um = radius_um self.contact_locations = recording.get_channel_locations() self.channel_distance = get_channel_distances(recording) diff --git a/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py b/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py index 9f868c3cd7..ffa97c144f 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/center_of_mass.py @@ -42,7 +42,7 @@ def __init__(self, recording, parents, return_output=True, radius_um=75.0, featu self.nbefore = waveform_extractor.nbefore self._kwargs.update(dict(feature=feature)) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) for main_chan in np.unique(peaks["channel_index"]): diff --git a/src/spikeinterface/sortingcomponents/peak_localization/grid.py b/src/spikeinterface/sortingcomponents/peak_localization/grid.py index e39773d66e..43d1324a9e 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/grid.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/grid.py @@ -61,7 +61,7 @@ def __init__( peak_sign="neg", weight_method={}, ): - PipelineNode.__init__(self, recording, return_output=return_output, 
parents=parents) + LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents) self.radius_um = radius_um self.margin_um = margin_um @@ -120,7 +120,7 @@ def __init__( ) ) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) nb_weights = self.weights.shape[0] diff --git a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py index 8840a5a00d..d9b75e265a 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py @@ -80,7 +80,7 @@ def __init__( self._dtype = np.dtype(dtype_localize_by_method["monopolar_triangulation"]) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): peak_locations = np.zeros(peaks.size, dtype=self._dtype) for i, peak in enumerate(peaks): diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 73a14bdee7..67850f241f 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -191,23 +191,15 @@ def get_prototype_and_waveforms_from_recording( nodes = [node0, node1] - recording_slices = get_shuffled_recording_slices(recording, job_kwargs=job_kwargs, seed=seed) - # res = detect_peaks( - # recording, - # pipeline_nodes=pipeline_nodes, - # skip_after_n_peaks=n_peaks, - # recording_slices=recording_slices, - # method="locally_exclusive", - # method_kwargs=detection_kwargs, - # job_kwargs=job_kwargs, - # ) + slices = get_shuffled_recording_slices(recording, job_kwargs=job_kwargs, seed=seed) + res = run_node_pipeline( recording, nodes, job_kwargs, job_name="get protoype waveforms", skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, + slices=slices, ) rng = np.random.default_rng(seed) diff --git 
a/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py b/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py index 1695f2cc59..ab546f7460 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py +++ b/src/spikeinterface/sortingcomponents/waveforms/features_from_peaks.py @@ -97,7 +97,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.all_channels: if self.peak_sign == "neg": amplitudes = np.min(waveforms, axis=1) @@ -131,7 +131,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.all_channels: all_ptps = np.ptp(waveforms, axis=1) else: @@ -182,7 +182,7 @@ def __init__( def get_dtype(self): return self._dtype - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): all_projections = np.zeros((peaks.size, self.projections.shape[1]), dtype=self._dtype) for main_chan in np.unique(peaks["channel_index"]): diff --git a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py index c6d1070e6d..ecb2edba39 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py +++ b/src/spikeinterface/sortingcomponents/waveforms/hanning_filter.py @@ -42,6 +42,6 @@ def __init__( self.hanning = hanning[:, None] self._kwargs.update(dict()) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): denoised_waveforms = waveforms * self.hanning return denoised_waveforms diff --git a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py index d094bae3e0..257c35f860 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py +++ 
b/src/spikeinterface/sortingcomponents/waveforms/neural_network_denoiser.py @@ -91,7 +91,7 @@ def load_model(self): return denoiser - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): num_channels = waveforms.shape[2] # Collapse channels and transform to torch tensor diff --git a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py index 1ed9e4bffa..b9cb6ec1ab 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/savgol_denoiser.py @@ -49,7 +49,7 @@ def __init__( self.order = min(self.order, self.window_length - 1) self._kwargs.update(dict(order=order, window_length_ms=window_length_ms)) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): # Denoise import scipy.signal diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index 0170038c96..9b642126b6 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -216,7 +216,7 @@ def __init__( self.n_components = self.pca_model.n_components self.dtype = np.dtype(dtype) - def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: + def compute(self, chunk: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. 
@@ -285,7 +285,7 @@ def __init__( model_folder_path=model_folder_path, ) - def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: + def compute(self, chunk: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. @@ -374,7 +374,7 @@ def __init__( # this is the final sparse channel count self.out_num_channels = max(np.sum(self.final_sparsity_mask, axis=1)) - def compute(self, traces, start_frame, end_frame, segment_index, max_margin, peaks, waveforms) -> np.ndarray: + def compute(self, chunk, start_frame, end_frame, segment_index, max_margin, peaks, waveforms) -> np.ndarray: """ Projects the waveforms using the PCA model trained in the fit method or loaded from the model_folder_path. diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py index 286c103d22..dc01ab3431 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py @@ -179,7 +179,7 @@ def test_pca_projection_sparsity(generated_recording, detected_peaks, model_path def test_initialization_with_wrong_parents_failure(generated_recording, model_path_of_trained_pca): recording = generated_recording model_folder_path = model_path_of_trained_pca - dummy_parent = PipelineNode(recording=recording) + dummy_parent = PipelineNode(recording) extract_waveforms = ExtractSparseWaveforms( recording=recording, ms_before=1, ms_after=1, radius_um=40, return_output=True ) diff --git a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py index ec223d0047..2191c6dd2a 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py +++ 
b/src/spikeinterface/sortingcomponents/waveforms/waveform_thresholder.py @@ -74,7 +74,7 @@ def __init__( dict(feature=feature, threshold=threshold, operator=operator, noise_levels=self.noise_levels) ) - def compute(self, traces, peaks, waveforms): + def compute(self, chunk, peaks, waveforms): if self.feature == "ptp": wf_data = np.ptp(waveforms, axis=1) / self.noise_levels elif self.feature == "mean": From 34f6bab05bb65b8b71dc8288b5f2f947d06aa934 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Mar 2026 09:21:15 +0100 Subject: [PATCH 7/7] add segments to base recording and sorting for typing --- src/spikeinterface/core/baserecording.py | 5 +++++ src/spikeinterface/core/basesorting.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 322e1c7547..f23b524271 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -168,6 +168,11 @@ def __sub__(self, other): return SubtractRecordings(self, other) + @property + def segments(self) -> list["BaseRecordingSegment"]: + """List of recording segments.""" + return self._segments + def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> None: """Adds a recording segment. diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 5afa8ac495..cb68f3d455 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -60,6 +60,11 @@ def _repr_html_(self, display_name=True): html_repr = html_header + html_unit_ids + html_extra return html_repr + @property + def segments(self) -> list["BaseSortingSegment"]: + """List of sorting segments.""" + return self._segments + @property def unit_ids(self): return self._main_ids