diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index f1c88e0f15..8b5b6de3b8 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -11,7 +11,6 @@ from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel class ClusteringBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py index be1cf18fbf..eae0bf0e59 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py @@ -5,8 +5,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel from pathlib import Path @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder): # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) # sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index") + extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True) + spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py index b9207caaa3..82a51e8292 
100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_detection.py @@ -6,7 +6,6 @@ from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.core.template_tools import get_template_extremum_channel @pytest.mark.skip() @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder): sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("templates", **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 168494caf7..decf034877 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -163,9 +163,13 @@ # template tools from .template_tools import ( get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, + + # this is not needed anymore + get_template_extremum_channel, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.105 + get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.105 ) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 
d49065e28d..a9ef4e7d8b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -13,6 +13,9 @@ class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. """ + _main_properties = [ + "main_channel_index", + ] def __init__(self, sampling_frequency: float, unit_ids: list): BaseExtractor.__init__(self, unit_ids) @@ -784,6 +787,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices + # TODO sam : change extremum_channel_inds to main_channel_index with vector def to_spike_vector( self, concatenated=True, @@ -804,7 +808,8 @@ def to_spike_vector( extremum_channel_inds : None or dict, default: None If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index". This can be convinient for computing spikes postion after sorter. - This dict can be computed with `get_template_extremum_channel(we, outputs="index")` + This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True) + use_cache : bool, default: True When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`). This caching only occurs when extremum_channel_inds=None. 
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 9b40a23dbd..50c8e1329a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2443,6 +2443,10 @@ def generate_ground_truth_recording( **generate_templates_kwargs, ) sorting.set_property("gt_unit_locations", unit_locations) + distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) + main_channel_index = np.argmin(distances, axis=1) + sorting.set_property("main_channel_index", main_channel_index) + else: assert templates.shape[0] == num_units diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 2c38248c1a..b8e5fac0d8 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -143,6 +143,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea return (local_peaks,) +# TODO sam replace extremum_channels_indices by main_channel_index + # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..3892e68c51 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -44,6 +44,10 @@ from .zarrextractors import get_default_zarr_compressor, ZarrSortingExtractor, super_zarr_open from .node_pipeline import run_node_pipeline +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection + + # high level function def create_sorting_analyzer( @@ -51,6 +55,10 @@ def create_sorting_analyzer( recording, format="memory", folder=None, + main_channel_index=None, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", + num_spikes_for_main_channel=100, sparse=True, sparsity=None, set_sparsity_by_dict_key=False, 
@@ -58,7 +66,9 @@ def create_sorting_analyzer( return_in_uV=True, overwrite=False, backend_options=None, - **sparsity_kwargs, + sparsity_kwargs=None, + seed=None, + **job_kwargs ) -> "SortingAnalyzer": """ Create a SortingAnalyzer by pairing a Sorting and the corresponding Recording. @@ -68,6 +78,11 @@ def create_sorting_analyzer( This object will be also use used for plotting purpose. + The main_channel_index can be externally provided. If not then this is taken from + sorting property. If not then the main_channel_index is estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. + Parameters ---------- @@ -81,6 +96,17 @@ def create_sorting_analyzer( The mode to store analyzer. If "folder", the analyzer is stored on disk in the specified folder. The "folder" argument must be specified in case of mode "folder". If "memory" is used, the analyzer is stored in RAM. Use this option carefully! + main_channel_index : None | np.array + The main_channel_index can be externally provided + main_channel_peak_sign : "both" | "neg" + In case when the main_channel_index is estimated wich sign to consider "both" or "neg". + main_channel_peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum" : take the peak value (max or min depending on `peak_sign`) + * "at_index" : take value at `nbefore` index + * "peak_to_peak" : take the peak-to-peak amplitude + num_spikes_for_main_channel : int, default: 100 + How many spikes per units to compute the main channel. sparse : bool, default: True If True, then a sparsity mask is computed using the `estimate_sparsity()` function using a few spikes to get an estimate of dense templates to create a ChannelSparsity object. 
@@ -106,8 +132,8 @@ def create_sorting_analyzer( * storage_options: dict | None (fsspec storage options) * saving_options: dict | None (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) - - sparsity_kwargs : keyword arguments + sparsity_kwargs : dict | None + Dict given to estimate the sparsity. Returns ------- @@ -143,6 +169,9 @@ def create_sorting_analyzer( sparsity off (or give external sparsity) like this. """ + if sparsity_kwargs is None: + sparsity_kwargs = dict() + if isinstance(sorting, dict) and isinstance(recording, dict): if sorting.keys() != recording.keys(): @@ -167,27 +196,54 @@ def create_sorting_analyzer( return_in_uV=return_in_uV, overwrite=overwrite, backend_options=backend_options, - **sparsity_kwargs, + sparsity_kwargs=sparsity_kwargs, + **job_kwargs ) - if format != "memory" and not is_path_remote(folder): - folder = clean_zarr_folder_name(folder) if format == "zarr" else folder - if Path(folder).is_dir(): - if overwrite: - shutil.rmtree(folder) - else: - raise ValueError(f"Folder {folder} already exists! Use overwrite=True to overwrite it.") + if format != "memory": + if format == "zarr": + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + if not is_path_remote(folder): + if Path(folder).is_dir(): + if not overwrite: + raise ValueError(f"Folder already exists {folder}! 
Use overwrite=True to overwrite it.") + else: + shutil.rmtree(folder) + + + + # retrieve or compute the main channel index per unit + if main_channel_index is None: + if "main_channel_index" in sorting.get_property_keys(): + main_channel_index = sorting.get_property("main_channel_index") + + if main_channel_index is None: + # this is weird but due to the cyclic import + from .template_tools import estimate_main_channel_from_recording + main_channel_index = estimate_main_channel_from_recording( + recording, + sorting, + peak_sign=main_channel_peak_sign, + peak_mode=main_channel_peak_mode, + num_spikes_for_main_channel=num_spikes_for_main_channel, + seed=seed, + **job_kwargs + ) # handle sparsity if sparsity is not None: # some checks assert isinstance(sparsity, ChannelSparsity), "'sparsity' must be a ChannelSparsity object" - error_msg = "If external sparsity is given, unit_ids must match sorting" - assert np.array_equal(sorting.unit_ids, sparsity.unit_ids), error_msg - error_msg = "If external sparsity is given, channel_ids must match recording" - assert np.array_equal(recording.channel_ids, sparsity.channel_ids), error_msg + assert np.array_equal( + sorting.unit_ids, sparsity.unit_ids + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert np.array_equal( + recording.channel_ids, sparsity.channel_ids + ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond" + assert all(sparsity.mask[u, c] for u, c in enumerate(main_channel_index)), "sparsity si not constistentent with main_channel_index" elif sparse: - sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs) + sparsity = estimate_sparsity(sorting, recording, main_channel_index=main_channel_index, **sparsity_kwargs) else: sparsity = None @@ -209,6 +265,9 @@ def create_sorting_analyzer( recording, format=format, folder=folder, + main_channel_index=main_channel_index, + main_channel_peak_sign=main_channel_peak_sign, + 
main_channel_peak_mode=main_channel_peak_mode, sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, @@ -278,6 +337,8 @@ def __init__( format: str | None = None, sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options: dict | None = None, ): # very fast init because checks are done in load and create @@ -288,6 +349,8 @@ def __init__( self.format = format self.sparsity = sparsity self.return_in_uV = return_in_uV + self.main_channel_peak_sign = main_channel_peak_sign + self.main_channel_peak_mode = main_channel_peak_mode # For backward compatibility self.return_scaled = return_in_uV @@ -341,12 +404,17 @@ def create( "zarr", ] = "memory", folder=None, + main_channel_index=None, sparsity=None, return_scaled=None, return_in_uV=True, + main_channel_peak_sign="both", + main_channel_peak_mode="extremum", backend_options=None, ): assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" + assert main_channel_index is not None, "To create a SortingAnalyzer you need to specify the main_channel_index" + # some checks if sorting.sampling_frequency != recording.sampling_frequency: if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5): @@ -375,9 +443,16 @@ def create( from spikeinterface.curation.remove_excess_spikes import RemoveExcessSpikesSorting sorting = RemoveExcessSpikesSorting(sorting=sorting, recording=recording) - + + # This will ensure that the sorting saved always will have this main_channel + + + if format == "memory": - sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, rec_attributes=None) + sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes=None) elif format == "binary_folder": sorting_analyzer = cls.create_binary_folder( folder, @@ 
-385,6 +460,8 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) @@ -398,12 +475,16 @@ def create( recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, rec_attributes=None, backend_options=backend_options, ) else: raise ValueError("SortingAnalyzer.create: wrong format") + sorting_analyzer.set_sorting_property.set_sorting_property("main_channel_index", main_channel_index, save=True) + return sorting_analyzer @classmethod @@ -436,7 +517,10 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe return sorting_analyzer @classmethod - def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attributes): + def create_memory(cls, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + rec_attributes): # used by create and save_as if rec_attributes is None: @@ -457,11 +541,18 @@ def create_memory(cls, sorting, recording, sparsity, return_in_uV, rec_attribute format="memory", sparsity=sparsity, return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as folder = Path(folder) @@ -525,16 +616,111 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV settings_file = folder / f"settings.json" settings = dict( return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, ) with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, 
indent=4) return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod + def _handle_backward_compatibility_settings_pre_init(cls, settings, sorting, sparsity): + """ + backward compatibility before the __init__ to handle the settings: + * return_scaled > return_in_uV + * main_channel_peak_sign + * main_channel_peak_mode + + Note : + * see also _handle_backward_compatibility_settings_post_init + * there is also something at extension level to handle changes in paramaters with deferents mechanism + """ + + new_settings = dict() + new_settings.update(settings) + if "return_scaled" in settings: + new_settings["return_in_uV"] = new_settings.pop("return_scaled") + elif "return_in_uV" in settings: + pass + else: + # old version did not have settings at all + new_settings["return_in_uV"] = True + + if "main_channel_peak_sign" not in settings: + # before 0.104.0 was not in main_channel_peak_sign + # TODO make something more fancy that exlore the previous params of extension + new_settings["main_channel_peak_sign"] = "both" + new_settings["main_channel_peak_mode"] = "extremum" + + return new_settings + + def _handle_backward_compatibility_settings_post_init(self): + """ + backward compatibility after the __init__ to : + * main_channel_index + + Note : + * see also _handle_backward_compatibility_settings_pre_init + * there is also something at extension level to handle changes in paramaters with deferents mechanism + """ + + + if "main_channel_index" not in self.sorting.get_property_keys(): + + warnings.warn("This loaded analyzer is from an older verion main_channel_index need to be computed from templates") + + main_channel_index = None + if self.has_extension("templates"): + # first try to load templates extension + ext = self.get_extension("templates") + + for k in ("average", "median"): + if k in ext.data: + from .template_tools import _get_main_channel_from_template_array + templates_array = ext.data[k] + # TODO 
@alessio @chris : we need to discuss this + peak_sign = "both" # or "neg" ????? + peak_mode = "extremum" + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, ext.nbefore) + break + + if main_channel_index is None: + if not self.has_recording(): + # TODO @alessio @chris : we need to discuss this + # what to do in this case ??????? + raise ValueError("This analyzer cannot be load and is from an old version, the recording is not available") + else: + + # otherwise we need to estimate the + + from .template_tools import estimate_main_channel_from_recording + # TODO @alessio @chris : we need to discuss this + peak_sign = "both" # or "neg" ????? + peak_mode = "extremum" + + main_channel_index = estimate_main_channel_from_recording( + self.recording, + self.sorting, + peak_sign=peak_sign, + peak_mode=peak_mode, + num_spikes_for_main_channel=100, + seed=None, + ) + + # this is only in memory + self.sorting.set_property("main_channel_index", main_channel_index) + # TODO @alessio @chris : we need to discuss this + # this save also to disk but maybe there is no write for the analyzer... 
+ self.set_sorting_property("main_channel_index", main_channel_index, save=True) + + @classmethod def load_from_binary_folder(cls, folder, recording=None, backend_options=None): from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -585,13 +771,19 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): if settings_file.exists(): with open(settings_file, "r") as f: settings = json.load(f) + need_to_create = False else: + need_to_create = True + settings = dict() + + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) + + if need_to_create: warnings.warn("settings.json not found for this folder writing one with return_in_uV=True") - settings = dict(return_in_uV=True) with open(settings_file, "w") as f: json.dump(check_json(settings), f, indent=4) - return_in_uV = settings.get("return_in_uV", settings.get("return_scaled", True)) + sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -599,7 +791,9 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="binary_folder", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -614,7 +808,11 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_attributes, backend_options): + def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, + main_channel_peak_sign, + main_channel_peak_mode, + + rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs @@ -638,7 +836,11 @@ def 
create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) - settings = dict(return_in_uV=return_in_uV) + settings = dict( + return_in_uV=return_in_uV, + main_channel_peak_sign=main_channel_peak_sign, + main_channel_peak_mode=main_channel_peak_mode, + ) zarr_root.attrs["settings"] = check_json(settings) # the recording @@ -702,6 +904,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr from .loading import load + # TODO check that sorting has main_channel_index and ensure backward compatibility + backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) @@ -757,10 +961,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): ) else: sparsity = None - - return_in_uV = zarr_root.attrs["settings"].get( - "return_in_uV", zarr_root.attrs["settings"].get("return_scaled", True) - ) + + settings = zarr_root.attrs["settings"] + settings = cls._handle_backward_compatibility_settings_pre_init(settings, sorting, sparsity) sorting_analyzer = SortingAnalyzer( sorting=sorting, @@ -768,7 +971,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): rec_attributes=rec_attributes, format="zarr", sparsity=sparsity, - return_in_uV=return_in_uV, + return_in_uV = settings["return_in_uV"], + main_channel_peak_sign = settings["main_channel_peak_sign"], + main_channel_peak_mode = settings["main_channel_peak_mode"], backend_options=backend_options, ) sorting_analyzer.folder = folder @@ -870,6 +1075,24 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n Array of values for the property """ return self.sorting.get_property(key, ids=ids) + + def get_main_channels(self, outputs="index", with_dict=False): + """ + + """ + main_channel_index = 
self.get_sorting_property("main_channel_index") + if outputs == "index": + main_chans = main_channel_index + elif outputs == "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + def are_units_mergeable( self, @@ -1092,7 +1315,10 @@ def _save_or_select_or_merge_or_split( if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( - sorting_provenance, recording, sparsity, self.return_in_uV, self.rec_attributes + sorting_provenance, recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes ) elif format == "binary_folder": @@ -1105,6 +1331,9 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, + self.rec_attributes, backend_options=backend_options, ) @@ -1118,6 +1347,8 @@ def _save_or_select_or_merge_or_split( recording, sparsity, self.return_in_uV, + self.main_channel_peak_sign, + self.main_channel_peak_mode, self.rec_attributes, backend_options=backend_options, ) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 91eb7df864..9d766732fa 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -27,8 +27,6 @@ In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. - peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels. num_channels : int Number of channels for "best_channels" method. 
radius_um : float @@ -81,14 +79,14 @@ class ChannelSparsity: Using the N best channels (largest template amplitude): - >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_best_channels(sorting_analyzer, num_channels) Using a neighborhood by radius: - >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_radius(sorting_analyzer, radius_um) Using a SNR threshold: - >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold, peak_sign="neg") + >>> sparsity = ChannelSparsity.from_snr(sorting_analyzer, threshold) Using a template energy threshold: >>> sparsity = ChannelSparsity.from_energy(sorting_analyzer, threshold) @@ -367,15 +365,17 @@ def from_closest_channels(cls, templates_or_sorting_analyzer, num_channels, peak return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): + def from_radius_and_main_channel(cls, unit_ids, channel_ids, main_channel_index, channel_locations, radius_um): """ - Construct sparsity from a radius around the best channel. + Construct sparsity from a radius around the main channel. Use the "radius_um" argument to specify the radius in um. Parameters ---------- - templates_or_sorting_analyzer : Templates | SortingAnalyzer - A Templates or a SortingAnalyzer object. + main_channel_index : np.array + Main channel index per units. + channel_locations : np.array + Channel locations of the recording. radius_um : float Radius in um for "radius" method. peak_sign : "neg" | "pos" | "both" @@ -386,19 +386,40 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): sparsity : ChannelSparsity The estimated sparsity. 
""" - from .template_tools import get_template_extremum_channel - - mask = np.zeros( - (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" - ) - channel_locations = templates_or_sorting_analyzer.get_channel_locations() + mask = np.zeros((len(unit_ids), len(channel_ids)), dtype="bool") distances = np.linalg.norm(channel_locations[:, np.newaxis] - channel_locations[np.newaxis, :], axis=2) - best_chan = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, outputs="index") - for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): - chan_ind = best_chan[unit_id] - (chan_inds,) = np.nonzero(distances[chan_ind, :] <= radius_um) + for unit_ind, main_chan in enumerate(main_channel_index): + (chan_inds,) = np.nonzero(distances[main_chan, :] <= radius_um) mask[unit_ind, chan_inds] = True - return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) + return cls(mask, unit_ids, channel_ids) + + + @classmethod + def from_radius(cls, templates_or_sorting_analyzer, radius_um): + """ + Construct sparsity from a radius around the main channel. + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
+ """ + main_channel_index = templates_or_sorting_analyzer.get_main_channels(outputs="index") + channel_locations = templates_or_sorting_analyzer.get_channel_locations() + return cls.from_radius_and_main_channel( + templates_or_sorting_analyzer.unit_ids, + templates_or_sorting_analyzer.channel_ids, + main_channel_index, + channel_locations, + radius_um) @classmethod def from_snr( @@ -726,6 +747,7 @@ def estimate_sparsity( amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, noise_levels: np.ndarray | list | None = None, + main_channel_index: np.ndarray | list | None = None, **job_kwargs, ): """ @@ -734,11 +756,10 @@ def estimate_sparsity( For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. These can be computed with the `get_noise_levels()` function. - Contrary to the previous implementation: - * all units are computed in one read of recording - * it doesn't require a folder - * it doesn't consume too much memory - * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + If main_channel_index is given and method="radius" then there is not need estimate + the templates otherwise the templates must be estimated using + `estimate_templates_with_accumulator()` which is fast and parallel but need to traverse + the recording. Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. @@ -757,6 +778,9 @@ def estimate_sparsity( noise_levels : np.ndarray | None, default: None Noise levels required for the "snr" and "energy" methods. You can use the `get_noise_levels()` function to compute them. + main_channel_index : np.array | None, default: None + Main channel indicies can be provided in case of method="radius", this avoid the + `estimate_templates_with_accumulator()` which is slow. 
{} Returns @@ -781,7 +805,14 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - if method != "by_property": + if method == "radius" and main_channel_index is not None: + assert main_channel_index.size == sorting.unit_ids.size + chan_locs = recording.get_channel_locations() + sparsity = ChannelSparsity.from_radius_and_main_channel( + sorting.unit_ids, recording.channel_ids, main_channel_index, chan_locs, radius_um + ) + + elif method != "by_property": nbefore = int(ms_before * recording.sampling_frequency / 1000.0) nafter = int(ms_after * recording.sampling_frequency / 1000.0) @@ -827,7 +858,7 @@ def estimate_sparsity( sparsity = ChannelSparsity.from_closest_channels(templates, num_channels) elif method == "radius": assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" - sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + sparsity = ChannelSparsity.from_radius(templates, radius_um) elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" assert noise_levels is not None, ( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 3e5a517b0a..dbbf87c560 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -482,3 +482,27 @@ def get_channel_locations(self) -> np.ndarray: assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + def get_main_channels(self, + peak_sign: "neg" | "both" | "pos" = "both", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + outputs="index", + with_dict=True + ): + from .template_tools import _get_main_channel_from_template_array + + templates_array = self.get_dense_templates() + main_channel_index = 
_get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, self.nbefore) + + if outputs == "index": + main_chans = main_channel_index + elif outputs == "id": + main_chans = self.channel_ids[main_channel_index] + else: + raise ValueError("wrong outputs") + + if with_dict: + return dict(zip(self.unit_ids, main_chans)) + else: + return main_chans + diff --git a/src/spikeinterface/core/template_tools.py b/src/spikeinterface/core/template_tools.py index 0293c23876..427b763e48 100644 --- a/src/spikeinterface/core/template_tools.py +++ b/src/spikeinterface/core/template_tools.py @@ -3,8 +3,13 @@ import numpy as np from .template import Templates +from .waveform_tools import estimate_templates_with_accumulator +from .sorting_tools import random_spikes_selection from .sortinganalyzer import SortingAnalyzer +import warnings + + def get_dense_templates_array( one_object: Templates | SortingAnalyzer, return_in_uV: bool = True, operator="average" @@ -47,6 +52,33 @@ def get_dense_templates_array( raise ValueError("Input should be Templates or SortingAnalyzer") return templates_array +def get_main_channel_templates_array(one_object: Templates | SortingAnalyzer, return_in_uV: bool = True): + """ + Return dense templates as numpy array from either a Templates object or a SortingAnalyzer. + + Parameters + ---------- + one_object : Templates | SortingAnalyzer + The Templates or SortingAnalyzer objects. If SortingAnalyzer, it needs the "templates" extension. + return_in_uV : bool, default: True + If True, templates are scaled. 
+ + Returns + ------- + main_channel_templates : np.ndarray + The main-channel templates (num_units, num_samples) + """ + # TODO later: do not load the dense templates array if this is not necessary (when sparse internally) + main_channels = one_object.get_main_channels(outputs="index", with_dict=False) + templates_array = get_dense_templates_array(one_object, return_in_uV=return_in_uV) + num_units = templates_array.shape[0] + num_samples = templates_array.shape[1] + main_channel_templates = np.zeros((num_units, num_samples), dtype=templates_array.dtype) + for i in range(num_units): + main_channel_templates[i, :] = templates_array[i, :, main_channels[i]] + return main_channel_templates + + def _get_nbefore(one_object): if isinstance(one_object, Templates): @@ -62,9 +94,9 @@ def get_template_amplitudes( templates_or_sorting_analyzer, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", - return_in_uV: bool = True, + peak_sign: None | "neg" | "pos" | "both" = None, + peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + # return_in_uV: bool = True, abs_value: bool = True, operator: str = "average", ): """ @@ -75,15 +107,15 @@ ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None | "neg" | "pos" | "both" + Used only when input is Templates. Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : None | "extremum" | "at_index" | "peak_to_peak", default: None + Used only when input is Templates. Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index * "peak_to_peak" : take the peak-to-peak amplitude - return_in_uV : bool, default True - The amplitude is scaled or not. 
abs_value : bool = True Whether the extremum amplitude should be returned as an absolute value or not operator : str, default: "average" @@ -95,8 +127,24 @@ def get_template_amplitudes( peak_values : dict Dictionary with unit ids as keys and template amplitudes as values """ + + if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + assert peak_sign is None, "get_template_amplitudes() peak_sign is now contained in SortingAnalyzer, should be None here" + assert peak_mode is None, "get_template_amplitudes() peak_mode is now contained in SortingAnalyzer, should be None here" + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + peak_mode = templates_or_sorting_analyzer.main_channel_peak_mode + return_in_uV = templates_or_sorting_analyzer.return_in_uV + elif isinstance(templates_or_sorting_analyzer, Templates): + return_in_uV = templates_or_sorting_analyzer.is_in_uV + if peak_sign is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_sign") + peak_sign = "both" + if peak_mode is None: + warnings.warn("get_template_amplitudes() with Templates should provide a peak_mode") + peak_mode = "extremum" + assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" + assert peak_mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" unit_ids = templates_or_sorting_analyzer.unit_ids before = _get_nbefore(templates_or_sorting_analyzer) @@ -110,19 +158,19 @@ def get_template_amplitudes( for unit_ind, unit_id in enumerate(unit_ids): template = templates_array[unit_ind, :, :] - if mode == "extremum": + if peak_mode == "extremum": if peak_sign == "both": values = np.max(np.abs(template), axis=0) elif peak_sign == "neg": values = np.min(template, axis=0) elif peak_sign == "pos": values = np.max(template, axis=0) - elif mode == 
"at_index": + elif peak_mode == "at_index": if peak_sign == "both": values = np.abs(template[before, :]) elif peak_sign in ["neg", "pos"]: values = template[before, :] - elif mode == "peak_to_peak": + elif peak_mode == "peak_to_peak": values = np.ptp(template, axis=0) if abs_value: @@ -133,14 +181,100 @@ def get_template_amplitudes( return peak_values + +def _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore): + # Step1 : max on time axis + if peak_mode == "extremum": + if peak_sign == "both": + values = np.max(np.abs(templates_array), axis=1) + elif peak_sign == "neg": + values = -np.min(templates_array, axis=1) + elif peak_sign == "pos": + values = np.max(templates_array, axis=1) + elif peak_mode == "at_index": + if peak_sign == "both": + values = np.abs(templates_array[:, nbefore, :]) + elif peak_sign in ["neg", "pos"]: + values = templates_array[:, nbefore, :] + elif peak_mode == "peak_to_peak": + values = np.ptp(templates_array, axis=1) + + # Step2: max on channel axis + main_channel_index = np.argmax(values, axis=1) + + return main_channel_index + + +def estimate_main_channel_from_recording( + recording, + sorting, + peak_sign: "neg" | "both" | "pos" = "both", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + num_spikes_for_main_channel=100, + ms_before = 1.0, + ms_after = 2.5, + seed=None, + **job_kwargs +): + """ + Estimate the main channel from recording using `estimate_templates_with_accumulator()` + + """ + + if peak_sign == "pos": + warnings.warn( + "estimate_main_channel_from_recording() with peak_sign='pos' is a strange case maybe you " \ + "should revert the traces instead" + ) + + + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + 
method="uniform", + max_spikes_per_unit=num_spikes_for_main_channel, + margin_size=max(nbefore, nafter), + seed=seed, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_in_uV=False, + job_name="estimate_main_channel", + **job_kwargs, + ) + + main_channel_index = _get_main_channel_from_template_array(templates_array, peak_mode, peak_sign, nbefore) + + return main_channel_index + + + + + +# TODO remove this in 0.105.0 def get_template_extremum_channel( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", outputs: "id" | "index" = "id", operator: str = "average", ): """ + Deprecated, will be removed in 0.105.0. + Use analyzer.get_main_channels() or templates.get_main_channels(peak_sign=...) instead. + + Compute the channel with the extremum peak for each unit. 
Parameters @@ -149,7 +283,7 @@ def get_template_extremum_channel( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum" : take the peak value (max or min depending on `peak_sign`) * "at_index" : take value at `nbefore` index @@ -163,48 +297,32 @@ def get_template_extremum_channel( Returns ------- - extremum_channels : dict + main_channels : dict Dictionary with unit ids as keys and extremum channels (id or index based on "outputs") as values """ - assert peak_sign in ("both", "neg", "pos"), "`peak_sign` must be one of `both`, `neg`, or `pos`" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - assert outputs in ("id", "index"), "`outputs` must be either `id` or `index`" - - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids + warnings.warn("get_template_extremum_channel() is deprecated use analyzer.get_main_channels() instead") - # if SortingAnalyzer need to use global SortingAnalyzer return_scaled otherwise - # we use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + assert peak_sign is None, "get_template_extremum_channel() peak_sign is now contained in SortingAnalyzer, should be None here" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, with_dict=True) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is 
None: + warnings.warn("get_template_extremum_channel() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs=outputs, peak_sign=peak_sign, with_dict=True) + + return main_channels - peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, return_in_uV=return_in_uV, operator=operator - ) - extremum_channels_id = {} - extremum_channels_index = {} - for unit_id in unit_ids: - max_ind = np.argmax(np.abs(peak_values[unit_id])) - extremum_channels_id[unit_id] = channel_ids[max_ind] - extremum_channels_index[unit_id] = max_ind - if outputs == "id": - return extremum_channels_id - elif outputs == "index": - return extremum_channels_index -def get_template_extremum_channel_peak_shift( - templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", operator: str = "average" -): +# TODO remove this in 0.105.0 +def get_template_extremum_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None): """ + Deprecated, will be removed in 0.105.0. + Use get_template_main_channel_peak_shift() instead. + In some situations spike sorters could return a spike index with a small shift related to the waveform peak. This function estimates and return these alignment shifts for the mean template. This function is internally used by `compute_spike_amplitudes()` to accurately retrieve the spike amplitudes. 
@@ -213,7 +331,33 @@ def get_template_extremum_channel_peak_shift( ---------- templates_or_sorting_analyzer : Templates | SortingAnalyzer A Templates or a SortingAnalyzer object - peak_sign : "neg" | "pos" | "both" + peak_sign : None + Sign of the template to find extremum channels + + Returns + ------- + shifts : dict + Dictionary with unit ids as keys and shifts as values + """ + + warnings.warn("get_template_extremum_channel_peak_shift() is deprecated use get_template_main_channel_peak_shift() instead" + "Will be removed in 0.105.0" + ) + + return get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True) + + +def get_template_main_channel_peak_shift(templates_or_sorting_analyzer, peak_sign=None, with_dict=True): + """ + In some situations spike sorters could return a spike index with a small shift related to the waveform peak. + This function estimates and return these alignment shifts for the mean template. + This function is internally used by `compute_spike_amplitudes()` to accurately retrieve the spike amplitudes. 
+ + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : None | "neg" | "pos" | "both" Sign of the template to find extremum channels operator : str, default: "average" If the "templates" extension of the SortingAnalyzer contains several operators (e.g., "average" and "median"), @@ -224,53 +368,57 @@ def get_template_extremum_channel_peak_shift( shifts : dict Dictionary with unit ids as keys and shifts as values """ - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - nbefore = _get_nbefore(templates_or_sorting_analyzer) - - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign) + - shifts = {} - - # We need to use the SortingAnalyzer return_scaled - # We need to use the Templates is_in_uV if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + assert peak_sign is None + peak_sign = templates_or_sorting_analyzer.main_channel_peak_sign + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_in_uV=return_in_uV) - - for unit_ind, unit_id in enumerate(unit_ids): - template = templates_array[unit_ind, :, :] + unit_ids = 
templates_or_sorting_analyzer.unit_ids + nbefore = _get_nbefore(templates_or_sorting_analyzer) - chan_id = extremum_channels_ids[unit_id] - chan_ind = list(channel_ids).index(chan_id) + + main_channel_templates = get_main_channel_templates_array(templates_or_sorting_analyzer) + shifts = [] + for unit_ind, unit_id in enumerate(unit_ids): + chan_ind = main_channels[unit_ind] + template = main_channel_templates[chan_ind] if peak_sign == "both": - peak_pos = np.argmax(np.abs(template[:, chan_ind])) + peak_pos = np.argmax(np.abs(template)) elif peak_sign == "neg": - peak_pos = np.argmin(template[:, chan_ind]) + peak_pos = np.argmin(template) elif peak_sign == "pos": - peak_pos = np.argmax(template[:, chan_ind]) + peak_pos = np.argmax(template) shift = peak_pos - nbefore - shifts[unit_id] = shift + shifts.append(shift) + + if with_dict: + shifts = dict(zip(unit_ids, shifts)) + else: + shifts = np.array(shifts) return shifts +# TODO remove this in 0.105.0 def get_template_extremum_amplitude( templates_or_sorting_analyzer, peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", + peak_mode: "extremum" | "at_index" | "peak_to_peak" = "at_index", abs_value: bool = True, operator: str = "average", ): """ + Depracted will be removed in 0.105.0. + Use get_template_main_channel_amplitude() instead. + Computes amplitudes on the best channel. 
Parameters @@ -279,7 +427,7 @@ def get_template_extremum_amplitude( A Templates or a SortingAnalyzer object peak_sign : "neg" | "pos" | "both" Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" Where the amplitude is computed * "extremum": take the peak value (max or min depending on `peak_sign`) * "at_index": take value at `nbefore` index @@ -296,35 +444,67 @@ def get_template_extremum_amplitude( amplitudes : dict Dictionary with unit ids as keys and amplitudes as values """ - assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'" - assert mode in ("extremum", "at_index", "peak_to_peak"), "'mode' must be 'extremum', 'at_index', or 'peak_to_peak'" - unit_ids = templates_or_sorting_analyzer.unit_ids - channel_ids = templates_or_sorting_analyzer.channel_ids - extremum_channels_ids = get_template_extremum_channel(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode) + warnings.warn("get_template_extremum_amplitude() is deprecated use get_template_main_channel_amplitude() instead" + "Will be removed in 0.105.0" + ) + return get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign=peak_sign, + peak_mode=peak_mode, + abs_value=abs_value, + + ) +def get_template_main_channel_amplitude( + templates_or_sorting_analyzer, + peak_sign: None | "neg" | "pos" | "both" = None, + peak_mode: None | "extremum" | "at_index" | "peak_to_peak" = None, + abs_value: bool = True, + with_dict=True, +): + """ + Computes amplitudes on the best channel. 
+ + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object + peak_sign : "neg" | "pos" | "both" + Sign of the template to find extremum channels + peak_mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" + Where the amplitude is computed + * "extremum": take the peak value (max or min depending on `peak_sign`) + * "at_index": take value at `nbefore` index + * "peak_to_peak": take the peak-to-peak amplitude + abs_value : bool = True + Whether the extremum amplitude should be returned as an absolute value or not + + + Returns + ------- + amplitudes : dict + Dictionary with unit ids as keys and amplitudes as values + """ if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - # For backward compatibility - if hasattr(templates_or_sorting_analyzer, "return_scaled"): - return_in_uV = templates_or_sorting_analyzer.return_scaled - else: - return_in_uV = templates_or_sorting_analyzer.return_in_uV - else: - return_in_uV = templates_or_sorting_analyzer.is_in_uV + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=False) + elif isinstance(templates_or_sorting_analyzer, Templates): + if peak_sign is None: + warnings.warn("get_template_main_channel_peak_shift() with Templates should provide a peak_sign") + peak_sign = "both" + main_channels = templates_or_sorting_analyzer.get_main_channels(outputs="index", peak_sign=peak_sign, with_dict=False) extremum_amplitudes = get_template_amplitudes( - templates_or_sorting_analyzer, - peak_sign=peak_sign, - mode=mode, - return_in_uV=return_in_uV, - abs_value=abs_value, - operator=operator, + templates_or_sorting_analyzer, peak_sign=peak_sign, peak_mode=peak_mode, abs_value=abs_value ) - unit_amplitudes = {} - for unit_id in unit_ids: - channel_id = extremum_channels_ids[unit_id] - best_channel = list(channel_ids).index(channel_id) - unit_amplitudes[unit_id] = extremum_amplitudes[unit_id][best_channel] 
+ unit_ids = templates_or_sorting_analyzer.unit_ids + unit_amplitudes = [] + for unit_ind, unit_id in enumerate(unit_ids): + chan_ind = main_channels[unit_ind] + unit_amplitudes.append(extremum_amplitudes[unit_id][chan_ind]) + + if with_dict: + unit_amplitudes = dict(zip()) - return unit_amplitudes + return unit_amplitudes \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_loading.py b/src/spikeinterface/core/tests/test_loading.py index bfaf97ec4a..44a7da5424 100644 --- a/src/spikeinterface/core/tests/test_loading.py +++ b/src/spikeinterface/core/tests/test_loading.py @@ -227,3 +227,8 @@ def test_remote_analyzer(): "quality_metrics", ]: assert ext in analyzer.get_saved_extension_names() + + +if __name__ == "__main__": + test_remote_analyzer() + \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 4f8e600a3f..7a29a3cee6 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -3,7 +3,7 @@ from pathlib import Path import shutil -from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface import create_sorting_analyzer, generate_ground_truth_recording from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_recording_into_chunks @@ -80,7 +80,7 @@ def test_run_node_pipeline(cache_folder_creation): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # 
print(peaks.size) @@ -202,7 +202,7 @@ def test_skip_after_n_peaks_and_recording_slices(): # create peaks from spikes sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory") sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, peak_sign="neg", outputs="index") + extremum_channel_inds = sorting_analyzer.get_main_channels( outputs="index", with_dict=True) peaks = sorting_to_peaks(sorting, extremum_channel_inds, spike_peak_dtype) # print(peaks.size) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..765d5bec7e 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -52,6 +52,12 @@ def dataset(): def test_SortingAnalyzer_memory(tmp_path, dataset): recording, sorting = dataset + + # Note the sorting contain already main_channel_index + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + assert np.array_equal(sorting_analyzer.get_main_channels() , sorting.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) @@ -75,6 +81,16 @@ def test_SortingAnalyzer_memory(tmp_path, dataset): assert "quality" in sorting_analyzer.sorting.get_property_keys() assert "number" in sorting_analyzer.sorting.get_property_keys() + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + + # Create when main_channel_index is not given : this is estimated + sorting2 = sorting.clone() + 
sorting2._properties.pop("main_channel_index") + print(sorting2.get_property("main_channel_index")) + sorting_analyzer = create_sorting_analyzer(sorting2, recording, format="memory", sparse=False, sparsity=None) + _check_sorting_analyzers(sorting_analyzer, sorting2, cache_folder=tmp_path) + def test_SortingAnalyzer_binary_folder(tmp_path, dataset): recording, sorting = dataset @@ -349,7 +365,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): assert ext is None assert sorting_analyzer.has_recording() - # save to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory": @@ -615,12 +630,11 @@ def _set_params(self, param0=5.5): return params def _get_pipeline_nodes(self): - from spikeinterface.core.template_tools import get_template_extremum_channel recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - extremum_channel_inds = get_template_extremum_channel(self.sorting_analyzer, outputs="index") + extremum_channel_inds = self.sorting_analyzer.get_main_channels( outputs="index", with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds ) @@ -718,9 +732,9 @@ def test_runtime_dependencies(dataset): tmp_path = Path("test_SortingAnalyzer") dataset = get_dataset() test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies(dataset) diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 
6e85221621..071de348e0 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -304,7 +304,7 @@ def test_compute_sparsity(): # using object SortingAnalyzer sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0) sparsity = compute_sparsity(sorting_analyzer, method="closest_channels", num_channels=2) sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity( @@ -318,13 +318,13 @@ def test_compute_sparsity(): templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") - sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") + sparsity = compute_sparsity(templates, method="radius", radius_um=50.0) sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) if __name__ == "__main__": - # test_ChannelSparsity() - # test_estimate_sparsity() + test_ChannelSparsity() + test_estimate_sparsity() test_compute_sparsity() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index a28680612a..5e4cd1f12c 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -7,9 +7,8 @@ from spikeinterface import Templates from spikeinterface.core import ( 
get_template_amplitudes, - get_template_extremum_channel, - get_template_extremum_channel_peak_shift, - get_template_extremum_amplitude, + get_template_main_channel_peak_shift, + get_template_main_channel_amplitude, ) @@ -56,24 +55,17 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer): def test_get_template_amplitudes(sorting_analyzer): peak_values = get_template_amplitudes(sorting_analyzer) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - peak_values = get_template_amplitudes(templates, abs_value=True) - peak_to_peak_values = get_template_amplitudes(templates, mode="peak_to_peak") + peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="extremum", abs_value=True) + peak_to_peak_values = get_template_amplitudes(templates, peak_sign="both", peak_mode="peak_to_peak") assert np.all(ptp > p for ptp, p in zip(peak_to_peak_values.values(), peak_values.values())) -def test_get_template_extremum_channel(sorting_analyzer): - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign="both") - print(extremum_channels_ids) - templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_channel(templates, peak_sign="both") - print(extremum_channels_ids) - -def test_get_template_extremum_channel_peak_shift(sorting_analyzer): - shifts = get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign="neg") +def test_get_template_main_channel_peak_shift(sorting_analyzer): + shifts = get_template_main_channel_peak_shift(sorting_analyzer) print(shifts) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - shifts = get_template_extremum_channel_peak_shift(templates, peak_sign="neg") + shifts = get_template_main_channel_peak_shift(templates, peak_sign="both") # DEBUG # import matplotlib.pyplot as plt @@ -91,13 +83,14 @@ def test_get_template_extremum_channel_peak_shift(sorting_analyzer): # plt.show() -def 
test_get_template_extremum_amplitude(sorting_analyzer): +def test_get_template_main_channel_amplitude(sorting_analyzer): - extremum_channels_ids = get_template_extremum_amplitude(sorting_analyzer, peak_sign="both") + extremum_channels_ids = get_template_main_channel_amplitude(sorting_analyzer) print(extremum_channels_ids) templates = _get_templates_object_from_sorting_analyzer(sorting_analyzer) - extremum_channels_ids = get_template_extremum_amplitude(templates, peak_sign="both") + + extremum_channels_ids = get_template_main_channel_amplitude(templates, peak_sign="both", peak_mode="extremum") if __name__ == "__main__": @@ -107,6 +100,5 @@ def test_get_template_extremum_amplitude(sorting_analyzer): print(sorting_analyzer) test_get_template_amplitudes(sorting_analyzer) - test_get_template_extremum_channel(sorting_analyzer) - test_get_template_extremum_channel_peak_shift(sorting_analyzer) - test_get_template_extremum_amplitude(sorting_analyzer) + test_get_template_main_channel_peak_shift(sorting_analyzer) + test_get_template_main_channel_amplitude(sorting_analyzer) diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index aa7987b7fe..9675335848 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -3,7 +3,7 @@ from spikeinterface import SortingAnalyzer -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_amplitudes +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift, get_template_amplitudes from spikeinterface.postprocessing import align_sorting _remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") @@ -79,7 +79,8 @@ def remove_redundant_units( if align and unit_peak_shifts is None: assert sorting_analyzer is not None, "For align=True must give a SortingAnalyzer or explicit unit_peak_shifts" - unit_peak_shifts = 
get_template_extremum_channel_peak_shift(sorting_analyzer, peak_sign=peak_sign) + + unit_peak_shifts = get_template_main_channel_peak_shift(sorting_analyzer, with_dict=True) if align: sorting_aligned = align_sorting(sorting, unit_peak_shifts) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 4761660f68..97a7ce7678 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -4,7 +4,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw -from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude +from spikeinterface.core import get_template_main_channel_amplitude from spikeinterface.postprocessing import compute_correlograms @@ -99,10 +99,11 @@ def export_report( # unit list units = pd.DataFrame(index=unit_ids) #  , columns=['max_on_channel_id', 'amplitude']) units.index.name = "unit_id" - units["max_on_channel_id"] = pd.Series( - get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="id") - ) - units["amplitude"] = pd.Series(get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign)) + # max_on_channel_id is kept (old name) + units["max_on_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + units["main_channel_id"] = sorting_analyzer.get_main_channels(outputs="id", with_dict=False) + + units["amplitude"] = pd.Series(get_template_main_channel_amplitude(sorting_analyzer)) units.to_csv(output_folder / "unit list.csv", sep="\t") unit_colors = sw.get_unit_colors(sorting) diff --git a/src/spikeinterface/exporters/to_ibl.py b/src/spikeinterface/exporters/to_ibl.py index 8f18536daa..383e01f897 100644 --- a/src/spikeinterface/exporters/to_ibl.py +++ b/src/spikeinterface/exporters/to_ibl.py @@ -8,7 +8,6 @@ from spikeinterface.core import SortingAnalyzer, BaseRecording, get_random_data_chunks from
spikeinterface.core.job_tools import fix_job_kwargs, ChunkRecordingExecutor, _shared_job_kwargs_doc -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.exporters import export_to_phy @@ -100,7 +99,7 @@ def export_to_ibl_gui( output_folder.mkdir(parents=True, exist_ok=True) ### Save spikes info ### - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_indices) # spikes.clusters @@ -135,7 +134,8 @@ def export_to_ibl_gui( np.save(output_folder / "clusters.waveforms.npy", templates) # cluster channels - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + cluster_channels = np.array(list(extremum_channel_indices.values()), dtype="int32") np.save(output_folder / "clusters.channels.npy", cluster_channels) diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 4add37e8a6..36e627ef68 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -11,7 +11,6 @@ InjectTemplatesRecording, _ensure_seed, ) -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.motion import Motion @@ -126,8 +125,8 @@ def select_templates( min_amplitude is not None or max_amplitude is not None or min_depth is not None or max_depth is not None ), "At least one of min_amplitude, max_amplitude, min_depth, max_depth should be provided" # get template amplitudes and depth - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, 
dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) + mask = np.ones(templates.num_units, dtype=bool) if min_amplitude is not None or max_amplitude is not None: @@ -141,7 +140,7 @@ def select_templates( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) if min_amplitude is not None: mask &= amplitudes >= min_amplitude if max_amplitude is not None: @@ -150,7 +149,7 @@ def select_templates( assert templates.probe is not None, "Templates should have a probe to filter based on depth" depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] if min_depth is not None: mask &= unit_depths >= min_depth if max_depth is not None: @@ -189,8 +188,7 @@ def scale_template_to_range( Templates The scaled templates. 
""" - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) # get amplitudes if amplitude_function == "ptp": @@ -202,7 +200,7 @@ def scale_template_to_range( amplitudes = np.zeros(templates.num_units) templates_array = templates.templates_array for i in range(templates.num_units): - amplitudes[i] = amp_fun(templates_array[i, :, extremum_channel_indices[i]]) + amplitudes[i] = amp_fun(templates_array[i, :, main_channel_indices[i]]) # scale templates to meet min_amplitude and max_amplitude range min_scale = np.min(amplitudes) / min_amplitude @@ -263,11 +261,10 @@ def relocate_templates( """ seed = _ensure_seed(seed) - extremum_channel_indices = list(get_template_extremum_channel(templates, outputs="index").values()) - extremum_channel_indices = np.array(extremum_channel_indices, dtype=int) + main_channel_indices = templates.get_main_channels(outputs="index", with_dict=False) depth_dimension = ["x", "y"].index(depth_direction) channel_depths = templates.get_channel_locations()[:, depth_dimension] - unit_depths = channel_depths[extremum_channel_indices] + unit_depths = channel_depths[main_channel_indices] assert margin >= 0, "margin should be positive" top_margin = np.max(channel_depths) + margin diff --git a/src/spikeinterface/generation/splitting_tools.py b/src/spikeinterface/generation/splitting_tools.py index 1f404ea3f7..03d7cd5ef8 100644 --- a/src/spikeinterface/generation/splitting_tools.py +++ b/src/spikeinterface/generation/splitting_tools.py @@ -107,9 +107,9 @@ def split_sorting_by_amplitudes( rng = np.random.default_rng(seed) fs = sorting_analyzer.sampling_frequency - from spikeinterface.core.template_tools import get_template_extremum_channel - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds 
= sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + spikes = sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, concatenated=False) new_spikes = spikes[0].copy() amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d6911b383..d5c3e037da 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -17,8 +17,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core import SortingAnalyzer, NumpySorting from spikeinterface.core.template_tools import ( - get_template_extremum_channel, - get_template_extremum_amplitude, + get_template_main_channel_amplitude, get_dense_templates_array, ) from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate @@ -143,9 +142,6 @@ class PresenceRatio(BaseMetric): def compute_snrs( sorting_analyzer, unit_ids=None, - peak_sign: str = "both", - peak_mode: str = "extremum", - operator: str = "median", ): """ Compute signal to noise ratio. @@ -177,34 +173,19 @@ def compute_snrs( noise_levels = sorting_analyzer.get_extension("noise_levels").get_data() - assert peak_sign in ("neg", "pos", "both") - assert peak_mode in ("extremum", "at_index", "peak_to_peak") - channel_ids = sorting_analyzer.channel_ids - if operator not in ("median", "average"): - raise ValueError(f"Invalid operator: {operator}. Expected 'median' or 'average'.") - if operator == "median" and not sorting_analyzer.has_extension("waveforms"): - warnings.warn( - "Operator 'median' requires 'waveforms' extension. Falling back to 'average'. 
" - "To use 'median', please compute the 'waveforms' extension first with: analyzer.compute('waveforms')" - ) - operator = "average" - extremum_channels_ids = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, mode=peak_mode, operator=operator - ) - unit_amplitudes = get_template_extremum_amplitude( - sorting_analyzer, peak_sign=peak_sign, mode=peak_mode, operator=operator - ) + main_channel_index = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer, with_dict=True) # make a dict to access by chan_id noise_levels = dict(zip(channel_ids, noise_levels)) snrs = {} for unit_id in unit_ids: - chan_id = extremum_channels_ids[unit_id] - noise = noise_levels[chan_id] + chan_ind = main_channel_index[unit_id] + noise = noise_levels[chan_ind] amplitude = unit_amplitudes[unit_id] snrs[unit_id] = np.abs(amplitude) / noise @@ -1430,7 +1411,10 @@ def compute_sd_ratio( noise_levels = get_noise_levels( sorting_analyzer.recording, return_in_uV=sorting_analyzer.return_in_uV, method="std", **job_kwargs ) - best_channels = get_template_extremum_channel(sorting_analyzer, outputs="index", peak_sign=peak_sign) + + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) if correct_for_template_itself: n_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) @@ -1466,7 +1450,7 @@ def compute_sd_ratio( else: unit_std = np.std(spk_amp) - best_channel = best_channels[unit_id] + best_channel = main_channels[unit_id] std_noise = noise_levels[best_channel] if correct_for_template_itself: diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index c6b539fdc1..3a1d7b5c60 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ 
b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -3,7 +3,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension @@ -143,7 +142,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): all_labels = sorting_analyzer.sorting.unit_ids[spike_unit_indices] # Get extremum channels for neighbor selection in sparse mode - extremum_channels = get_template_extremum_channel(sorting_analyzer, peak_sign=self.params["peak_sign"]) + + + main_channels = sorting_analyzer.get_main_channels(outputs="id", with_dict=True) # Pre-compute spike counts and firing rates if advanced NN metrics are requested advanced_nn_metrics = ["nn_advanced"] # Our grouped advanced NN metric @@ -158,7 +159,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids=None): if sorting_analyzer.is_sparse(): neighbor_channel_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] neighbor_unit_ids = [ - other_unit for other_unit in unit_ids if extremum_channels[other_unit] in neighbor_channel_ids + other_unit for other_unit in unit_ids if main_channels[other_unit] in neighbor_channel_ids ] neighbor_channel_indices = sorting_analyzer.channel_ids_to_indices(neighbor_channel_ids) else: diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 96eda404dd..cc0e38cacf 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -3,7 +3,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array +from spikeinterface.core.template_tools import 
get_dense_templates_array from .metrics import ( get_trough_and_peak_idx, @@ -249,12 +249,13 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) + + extremum_channel_indices = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + operator = self.params["template_operator"] - extremum_channel_indices = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, outputs="index", operator=operator - ) all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True, operator=operator) + channel_locations = sorting_analyzer.get_channel_locations() main_channel_templates = [] diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 310be8cceb..a9c1b9da06 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -1,7 +1,7 @@ import numpy as np from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -102,10 +102,8 @@ def _get_pipeline_nodes(self): else: cut_out_after = nafter - peak_sign = "neg" if np.abs(np.min(all_templates)) > np.max(all_templates) else "pos" - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + # collisions handle_collisions = self.params["handle_collisions"] diff --git a/src/spikeinterface/postprocessing/localization_tools.py 
b/src/spikeinterface/postprocessing/localization_tools.py index 448be8d055..9022ea12a1 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -6,7 +6,7 @@ import numpy as np from spikeinterface.core import SortingAnalyzer, Templates, compute_sparsity -from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array, get_template_extremum_channel +from spikeinterface.core.template_tools import _get_nbefore, get_dense_templates_array numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -99,7 +99,9 @@ def compute_monopolar_triangulation( chan_inds = sparsity.unit_id_to_channel_indices[unit_id] neighbours_mask[i, chan_inds] = True enforce_decrease_radial_parents = make_radial_order_parents(contact_locations, neighbours_mask) - best_channels = get_template_extremum_channel(sorting_analyzer_or_templates, outputs="index") + + best_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) + unit_location = np.zeros((unit_ids.size, 4), dtype="float64") for i, unit_id in enumerate(unit_ids): @@ -278,7 +280,7 @@ def compute_grid_convolution( contact_locations, radius_um, upsampling_um, margin_um, weight_method ) - peak_channels = get_template_extremum_channel(sorting_analyzer_or_templates, peak_sign, outputs="index") + main_channels = sorting_analyzer_or_templates.get_main_channels(outputs="index", with_dict=True) weights_sparsity_mask = weights > 0 @@ -286,7 +288,7 @@ def compute_grid_convolution( unit_location = np.zeros((len(unit_ids), 3), dtype="float64") for i, unit_id in enumerate(unit_ids): - main_chan = peak_channels[unit_id] + main_chan = main_channels[unit_id] wf = templates[i, :, :] nearest_mask = nearest_template_mask[main_chan, :] channel_mask = np.sum(weights_sparsity_mask[:, :, nearest_mask], axis=(0, 2)) > 0 @@ -661,14 +663,10 @@ def get_convolution_weights( def compute_location_max_channel( 
templates_or_sorting_analyzer: SortingAnalyzer | Templates, unit_ids=None, - peak_sign: "neg" | "pos" | "both" = "neg", - mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> np.ndarray: """ Localize a unit using max channel. - This uses internally `get_template_extremum_channel()` - Parameters ---------- @@ -676,22 +674,14 @@ def compute_location_max_channel( A SortingAnalyzer or Templates object unit_ids: list[str] | list[int] | None A list of unit_id to restrict the computation - peak_sign : "neg" | "pos" | "both" - Sign of the template to find extremum channels - mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index" - Where the amplitude is computed - * "extremum" : take the peak value (max or min depending on `peak_sign`) - * "at_index" : take value at `nbefore` index - * "peak_to_peak" : take the peak-to-peak amplitude Returns ------- unit_locations: np.ndarray 2d """ - extremum_channels_index = get_template_extremum_channel( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=mode, outputs="index" - ) + extremum_channels_index = templates_or_sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + contact_locations = templates_or_sorting_analyzer.get_channel_locations() if unit_ids is None: unit_ids = templates_or_sorting_analyzer.unit_ids diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 0495e2c56e..45d8a6d0ae 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -2,7 +2,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from 
spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type @@ -15,28 +15,25 @@ class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): Parameters ---------- - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute extremum channel used to retrieve spike amplitudes. + """ extension_name = "spike_amplitudes" depend_on = ["templates"] nodepipeline_variables = ["amplitudes"] - def _set_params(self, peak_sign="neg"): - return super()._set_params(peak_sign=peak_sign) + def _set_params(self): + return super()._set_params() def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] return_in_uV = self.sorting_analyzer.return_in_uV - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) - peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign) + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + + peak_shifts = get_template_main_channel_peak_shift(self.sorting_analyzer, with_dict=True) spike_retriever_node = SpikeRetriever( sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d11d4eb2c7..5fb15a6ea1 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -1,5 +1,4 @@ from spikeinterface.core.sortinganalyzer import register_result_extension -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension @@ -14,8 +13,6 @@ class 
ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - peak_sign : "neg" | "pos" | "both", default: "neg" - The peak sign to use when looking for the template extremum channel. spike_retriever_kwargs : dict Arguments to control the spike retriever behavior. See `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. @@ -38,14 +35,12 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs if "spike_retriver_kwargs" in self.params: - self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg") self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( self, ms_before=0.5, ms_after=0.5, - peak_sign="neg", spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, @@ -55,7 +50,6 @@ def _set_params( return super()._set_params( ms_before=ms_before, ms_after=ms_after, - peak_sign=peak_sign, spike_retriever_kwargs=spike_retriever_kwargs, method=method, method_kwargs=method_kwargs, @@ -66,10 +60,8 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] - extremum_channels_indices = get_template_extremum_channel( - self.sorting_analyzer, peak_sign=peak_sign, outputs="index" - ) + + extremum_channels_indices = self.sorting_analyzer.get_main_channels(outputs="index", with_dict=True) retriever_kwargs = { "channel_from_template": True, diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index e3c45fe8ef..16650a09ec 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ 
b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -7,7 +7,7 @@ create_sorting_analyzer, generate_ground_truth_recording, set_global_job_kwargs, - get_template_extremum_amplitude, + get_template_main_channel_amplitude, ) from spikeinterface.core.generate import inject_some_split_units @@ -85,7 +85,7 @@ def get_dataset_to_merge(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(get_template_main_channel_amplitude(analyzer_raw, with_dict=False))[::-1] split_ids = sorting.unit_ids[sort_by_amp][:3] sorting_with_splits, split_unit_ids = inject_some_split_units( @@ -116,7 +116,7 @@ def get_dataset_to_split(): analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) analyzer_raw.compute(["random_spikes", "templates"]) # select 3 largest templates to split - sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + sort_by_amp = np.argsort(list(get_template_main_channel_amplitude(analyzer_raw).values()))[::-1] large_units = sorting.unit_ids[sort_by_amp][:2] return recording, sorting, large_units diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3e0eb0b632..7ba5614e38 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -53,13 +53,11 @@ def __init__( num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: - from spikeinterface.core.template_tools import get_template_extremum_channel + main_channels = self.templates.get_main_channels(main_channel_peak_sign=self.peak_sign, outputs="index", with_dict=False) - best_channels = 
get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") - best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2 + channel_locations[:, None] - channel_locations[main_channels][np.newaxis, :], axis=2 ) self.neighborhood_mask = template_distances <= neighborhood_radius_um else: diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 37c13b9395..2fc49528fb 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -3,7 +3,6 @@ import numpy as np from spikeinterface.core import ( get_channel_distances, - get_template_extremum_channel, ) from spikeinterface.sortingcomponents.peak_detection.method_list import ( @@ -221,12 +220,12 @@ def __init__( self.sparse_templates_array_static = templates.templates_array self.dtype = self.sparse_templates_array_static.dtype - extremum_chan = get_template_extremum_channel(templates, peak_sign=peak_sign, outputs="index") + # as numpy vector - self.extremum_channel = np.array([extremum_chan[unit_id] for unit_id in unit_ids], dtype="int64") + self.main_channels = templates.get_main_channels(main_channel_peak_sign=peak_sign, outputs="index", with_dict=False) channel_locations = templates.probe.contact_positions - unit_locations = channel_locations[self.extremum_channel] + unit_locations = channel_locations[self.main_channels] self.channel_locations = channel_locations # distance between units diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 73a14bdee7..c81302a7cb 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -13,7 
+13,7 @@ from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from spikeinterface.core.sparsity import ChannelSparsity from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift +from spikeinterface.core.template_tools import get_template_main_channel_peak_shift from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain @@ -627,7 +627,7 @@ def clean_templates( if max_jitter_ms is not None: max_jitter = int(max_jitter_ms * templates.sampling_frequency / 1000.0) n_before = len(templates.unit_ids) - shifts = get_template_extremum_channel_peak_shift(templates) + shifts = get_template_main_channel_peak_shift(templates, with_dict=True) to_select = [] for unit_id in templates.unit_ids: if np.abs(shifts[unit_id]) <= max_jitter: diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index d2b1a21fdd..791058940c 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -6,7 +6,6 @@ from .utils import get_unit_colors from .traces import TracesWidget from spikeinterface.core import ChannelSparsity -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from spikeinterface.core.baserecording import BaseRecording from spikeinterface.core.basesorting import BaseSorting @@ -119,9 +118,9 @@ def __init__( sparsity = sorting_analyzer.sparsity else: if sparsity is None: - # in this case, we construct a sparsity dictionary only with the best channel - extremum_channel_ids = get_template_extremum_channel(sorting_analyzer) - unit_id_to_channel_ids = {u: [ch] for u, ch in extremum_channel_ids.items()} + # in this case, we construct a sparsity dictionary only with the main channel + main_channels = 
sorting_analyzer.get_main_channels(outputs="id", with_dict=True) + unit_id_to_channel_ids = {u: [ch] for u, ch in main_channels.items()} sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, unit_ids=sorting_analyzer.unit_ids, diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index ba2f939b80..378948817b 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -4,7 +4,7 @@ from .utils import get_unit_colors -from spikeinterface.core.template_tools import get_template_extremum_amplitude +from spikeinterface.core.template_tools import get_template_main_channel_amplitude class UnitDepthsWidget(BaseWidget): @@ -20,12 +20,10 @@ class UnitDepthsWidget(BaseWidget): by matplotlib. If None, default colors are chosen using the `get_some_colors` function. depth_axis : int, default: 1 The dimension of unit_locations that is depth - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of peak for amplitudes """ def __init__( - self, sorting_analyzer, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs + self, sorting_analyzer, unit_colors=None, depth_axis=1, backend=None, **backend_kwargs ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) @@ -42,7 +40,7 @@ def __init__( unit_locations = ulc.get_data(outputs="numpy") unit_depths = unit_locations[:, depth_axis] - unit_amplitudes = get_template_extremum_amplitude(sorting_analyzer, peak_sign=peak_sign) + unit_amplitudes = get_template_main_channel_amplitude(sorting_analyzer) unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids]) num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit(outputs="array") diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 6a212b1d0e..7d88a523d4 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ 
b/src/spikeinterface/widgets/unit_locations.py @@ -2,7 +2,6 @@ import numpy as np from probeinterface import ProbeGroup -from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sortinganalyzer import SortingAnalyzer from .base import BaseWidget, to_attr @@ -84,10 +83,10 @@ def __init__( if np.any(np.isnan(all_unit_locations[sorting.ids_to_indices(unit_ids)])): warnings.warn("Some unit locations contain NaN values. Replacing with extremum channel location.") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) for unit_id in unit_ids: if np.any(np.isnan(unit_locations[unit_id])): - unit_locations[unit_id] = channel_locations[extremum_channel_indices[unit_id]] + unit_locations[unit_id] = channel_locations[main_channels[unit_id]] data_plot = dict( all_unit_ids=sorting.unit_ids, diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index ec781b5470..0ab6800462 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -3,7 +3,6 @@ import warnings import numpy as np -from spikeinterface.core.template_tools import get_template_extremum_channel from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -135,12 +134,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): col_counter += 1 unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") - extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, outputs="index") + main_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) + unit_location = unit_locations[unit_id] x, y = unit_location[0], unit_location[1] if np.isnan(x) or np.isnan(y): warnings.warn(f"Unit {unit_id} location contains NaN values. 
Replacing NaN extremum channel location.") - x, y = sorting_analyzer.get_channel_locations()[extremum_channel_indices[unit_id]] + x, y = sorting_analyzer.get_channel_locations()[main_channels[unit_id]] ax_unit_locations.set_xlim(x - 80, x + 80) ax_unit_locations.set_ylim(y - 250, y + 250) diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 24994bb570..fb9050a9b7 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -3,7 +3,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from spikeinterface.core import ChannelSparsity, get_template_extremum_channel +from spikeinterface.core import ChannelSparsity class UnitWaveformDensityMapWidget(BaseWidget): @@ -23,8 +23,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored use_max_channel : bool, default: False Use only the max channel - peak_sign : "neg" | "pos" | "both", default: "neg" - Used to detect max channel only when use_max_channel=True unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. 
@@ -41,7 +39,6 @@ def __init__( sparsity=None, same_axis=False, use_max_channel=False, - peak_sign="neg", unit_colors=None, backend=None, **backend_kwargs, @@ -59,9 +56,7 @@ def __init__( if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" - max_channels = get_template_extremum_channel( - sorting_analyzer, mode="extremum", peak_sign=peak_sign, outputs="index" - ) + max_channels = sorting_analyzer.get_main_channels(outputs="index", with_dict=True) # sparsity is done on all the units even if unit_ids is a few ones because some backends need them all if sorting_analyzer.is_sparse():