From 67aac797e08da30d402fc53cec63ebd3f09417eb Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Mon, 23 Mar 2026 11:01:43 -0400 Subject: [PATCH 01/15] bombcell full pipeline, edits to allow most config options native bombcell has etc --- .gitignore | 7 + .../how_to/full_pipeline_with_bombcell.py | 233 + in_container_params.json | 3 + in_container_recording.json | 15497 ++++++++++++++++ in_container_sorter_script.py | 28 + src/spikeinterface/core/sortinganalyzer.py | 43 +- .../curation/bombcell_curation.py | 70 +- .../metrics/quality/misc_metrics.py | 33 +- .../metrics/template/metrics.py | 1419 +- .../metrics/template/template_metrics.py | 194 +- 10 files changed, 16656 insertions(+), 871 deletions(-) create mode 100644 examples/how_to/full_pipeline_with_bombcell.py create mode 100644 in_container_params.json create mode 100644 in_container_recording.json create mode 100644 in_container_sorter_script.py diff --git a/.gitignore b/.gitignore index 2baa4b4f92..bfb43fcae1 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,10 @@ test_folder/ .DS_Store test_data.json uv.lock + + +# Julie's mess +playground.ipynb +playground2.ipynb +examples/how_to/full_pipeline_with_bombcell.ipynb +examples/how_to/compare_bombcell_unitrefine.ipynb diff --git a/examples/how_to/full_pipeline_with_bombcell.py b/examples/how_to/full_pipeline_with_bombcell.py new file mode 100644 index 0000000000..6ae6704cdd --- /dev/null +++ b/examples/how_to/full_pipeline_with_bombcell.py @@ -0,0 +1,233 @@ +""" +Full pipeline: preprocessing, spike sorting, and bombcell quality control. + +Neuropixels analysis pipeline: load SpikeGLX recording, preprocess, +run Kilosort4, compute quality/template metrics, and run bombcell to +label units as good, MUA, noise, or non-somatic. 
+""" + +import json +import matplotlib.pyplot as plt +from pathlib import Path +from pprint import pprint + +import spikeinterface.full as si +import spikeinterface.curation as sc +import spikeinterface.widgets as sw + +# %% Paths — edit these + +spikeglx_folder = Path("/path/to/your/spikeglx/recording") +base_folder = spikeglx_folder + +preprocessed_folder = base_folder / "preprocessed" +kilosort_folder = base_folder / "kilosort4_output" +analyzer_folder = base_folder / "sorting_analyzer.zarr" + +preprocessed_exists = (preprocessed_folder / "si_folder.json").exists() + +job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True) + +# %% 1. Load recording + +raw_rec = si.read_spikeglx(spikeglx_folder, stream_name="imec0.ap", load_sync_channel=False) +print(raw_rec) + +# %% 2. Preprocess +# Highpass → bad channel removal → phase shift → common median reference. +# All lazy until saved. + +if not preprocessed_exists: + rec_filtered = si.highpass_filter(raw_rec, freq_min=300.0) + + bad_channel_ids, channel_labels = si.detect_bad_channels(rec_filtered) + print(f"Bad channels detected: {bad_channel_ids}") + rec_clean = rec_filtered.remove_channels(bad_channel_ids) + + # Save bad channel info + preprocessed_folder.mkdir(parents=True, exist_ok=True) + with open(preprocessed_folder / "bad_channels.json", "w") as f: + json.dump({"bad_channel_ids": [str(ch) for ch in bad_channel_ids]}, f, indent=2) + + rec_shifted = si.phase_shift(rec_clean) + rec_cmr = si.common_reference(rec_shifted, reference="global", operator="median") + + # Save to disk (Kilosort needs binary) + rec_preprocessed = rec_cmr.save(folder=preprocessed_folder, format="binary", **job_kwargs) +else: + print(f"Loading preprocessed recording from {preprocessed_folder}") + rec_preprocessed = si.load(preprocessed_folder) + +print(rec_preprocessed) + +# %% 3. 
Run Kilosort4 + +if kilosort_folder.exists(): + print(f"Loading existing Kilosort4 output from {kilosort_folder}") + # register_recording=False: avoids errors when the original recording + # path no longer exists (e.g. different mount point) + sorting = si.read_sorter_folder(kilosort_folder, register_recording=False) +else: + sorting = si.run_sorter( + sorter_name="kilosort4", + recording=rec_preprocessed, + folder=kilosort_folder, + remove_existing_folder=True, + verbose=True, + skip_kilosort_preprocessing=True, + do_CAR=False, + ) +print(f"Kilosort4 found {len(sorting.unit_ids)} units") + +# %% 4. Create SortingAnalyzer and compute extensions + +if analyzer_folder.exists(): + analyzer = si.load_sorting_analyzer(analyzer_folder) + if not analyzer.has_recording(): + analyzer.set_temporary_recording(rec_preprocessed) +else: + analyzer = si.create_sorting_analyzer( + sorting=sorting, + recording=rec_preprocessed, + sparse=True, + format="zarr", + folder=analyzer_folder, + return_in_uV=True, + ) + +# Core extensions +if not analyzer.has_extension("random_spikes"): + analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500) +if not analyzer.has_extension("waveforms"): + analyzer.compute("waveforms", ms_before=3.0, ms_after=3.0, **job_kwargs) +if not analyzer.has_extension("templates"): + analyzer.compute("templates", operators=["average", "median", "std"]) +if not analyzer.has_extension("noise_levels"): + analyzer.compute("noise_levels") + +# Quality metric prerequisites +if not analyzer.has_extension("spike_amplitudes"): + analyzer.compute("spike_amplitudes", **job_kwargs) +if not analyzer.has_extension("unit_locations"): + analyzer.compute("unit_locations") +if not analyzer.has_extension("spike_locations"): + analyzer.compute("spike_locations", **job_kwargs) + +# Template metrics (include_multi_channel_metrics for exp_decay) +if not analyzer.has_extension("template_metrics"): + analyzer.compute("template_metrics", 
include_multi_channel_metrics=True) + +# %% 5. Configure and compute quality metrics + +# Toggle options +compute_distance_metrics = False # needs PCA; best for stable/chronic recordings +compute_drift = True +label_non_somatic = True +split_non_somatic_good_mua = False + +# RPV method: "sliding_rp" (default, sweeps RP range) or "llobet" (single RP value) +rp_violation_method = "sliding_rp" + +qm_params = { + "presence_ratio": {"bin_duration_s": 60}, + "rp_violation": {"refractory_period_ms": 2.0, "censored_period_ms": 0.1}, + "sliding_rp_violation": { + "exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10.0, + "confidence_threshold": 0.9, + }, + "drift": {"interval_s": 60, "min_spikes_per_interval": 100}, +} + +metric_names = ["amplitude_median", "snr", "amplitude_cutoff", "num_spikes", "presence_ratio", "firing_rate"] + +if rp_violation_method == "sliding_rp": + metric_names.append("sliding_rp_violation") +else: + metric_names.append("rp_violation") + +if compute_drift: + metric_names.append("drift") + +if compute_distance_metrics: + metric_names.append("mahalanobis") # produces isolation_distance and l_ratio + if not analyzer.has_extension("principal_components"): + analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + +# To add more metrics, append here and add a threshold below: +# metric_names.append("silhouette") # requires principal_components +# metric_names.append("d_prime") # requires principal_components + +if analyzer.has_extension("quality_metrics"): + analyzer.delete_extension("quality_metrics") +analyzer.compute("quality_metrics", metric_names=metric_names, metric_params=qm_params, **job_kwargs) + +# %% 6. Run bombcell +# +# The thresholds dict has three sections: "noise", "mua", "non-somatic". +# Each entry is {"greater": val, "less": val} (use None to disable one side). +# +# You can add any metric from the analyzer's DataFrame to any section. 
+# Custom metrics in "non-somatic" are OR'd with the built-in waveform shape logic. +# Metrics that haven't been computed are skipped with a warning. + +thresholds = sc.bombcell_get_default_thresholds() + +# Adjust existing thresholds +thresholds["mua"]["rpv"]["less"] = 0.1 +thresholds["mua"]["presence_ratio"]["greater"] = 0.7 + +# Add custom metrics — uncomment any of these: +# thresholds["mua"]["firing_rate"] = {"greater": 0.1, "less": None} +# thresholds["mua"]["silhouette"] = {"greater": 0.4, "less": None} +# thresholds["noise"]["half_width"] = {"greater": 0.05e-3, "less": 0.6e-3} +# thresholds["non-somatic"]["velocity"] = {"greater": 2.0, "less": None} + +# Disable a threshold: +# thresholds["mua"]["drift_ptp"] = {"greater": None, "less": None} + +pprint(thresholds) + +bombcell_labels = sc.bombcell_label_units( + sorting_analyzer=analyzer, + thresholds=thresholds, + label_non_somatic=label_non_somatic, + split_non_somatic_good_mua=split_non_somatic_good_mua, +) + +print(f"\nLabeled {len(bombcell_labels)} units") +print(bombcell_labels["bombcell_label"].value_counts()) + +# %% 7. Visualize + +sw.plot_unit_labels(analyzer, bombcell_labels["bombcell_label"], ylims=(-300, 100)) +sw.plot_metric_histograms(analyzer, thresholds, figsize=(15, 10)) +sw.plot_bombcell_labels_upset( + analyzer, + unit_labels=bombcell_labels["bombcell_label"], + thresholds=thresholds, + unit_labels_to_plot=["noise", "mua"], # add "non_soma" to see non-somatic patterns +) +plt.show() + +# %% 8. 
Remove noise units + +analyzer_clean_folder = base_folder / "sorting_analyzer_clean.zarr" + +if analyzer_clean_folder.exists(): + analyzer_clean = si.load_sorting_analyzer(analyzer_clean_folder) +else: + non_noise = bombcell_labels["bombcell_label"] != "noise" + analyzer_clean = analyzer.select_units( + analyzer.unit_ids[non_noise], + folder=analyzer_clean_folder, + format="zarr", + ) +print(f"Kept {len(analyzer_clean.unit_ids)} / {len(analyzer.unit_ids)} units after removing noise") + +# %% Notes on parameter tuning by recording type +# +# Chronic: set compute_distance_metrics=True, increase/disable drift threshold +# Acute: keep compute_distance_metrics=False, keep drift threshold strict +# Cerebellum: relax num_positive_peaks (complex spikes), shorter peak_to_trough_duration +# Striatum: lower spike count and presence ratio thresholds for MSNs diff --git a/in_container_params.json b/in_container_params.json new file mode 100644 index 0000000000..462dc67ed3 --- /dev/null +++ b/in_container_params.json @@ -0,0 +1,3 @@ +{ + "output_folder": "/Users/jf5479/Downloads/AL031_2019-12-02/spikeinterface_output/kilosort4_output" +} \ No newline at end of file diff --git a/in_container_recording.json b/in_container_recording.json new file mode 100644 index 0000000000..64f8f88c42 --- /dev/null +++ b/in_container_recording.json @@ -0,0 +1,15497 @@ +{ + "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + "recording": { + "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + "recording": { + "class": "spikeinterface.core.channelslice.ChannelSliceRecording", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + "parent_recording": { + "class": "spikeinterface.preprocessing.filter.HighpassFilterRecording", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + 
"recording": { + "class": "spikeinterface.core.channelslice.ChannelSliceRecording", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + "parent_recording": { + "class": "spikeinterface.core.binaryrecordingextractor.BinaryRecordingExtractor", + "module": "spikeinterface", + "version": "0.103.3", + "kwargs": { + "file_paths": [ + "/Users/jf5479/Downloads/AL031_2019-12-02/AL031_2019-12-02_bank1_NatIm_g0_t0_bc_decompressed.imec0.ap.bin" + ], + "sampling_frequency": 30000.0, + "t_starts": null, + "num_channels": 385, + "dtype": " np.ndarray: """ @@ -2648,9 +2668,7 @@ def run(self, save=True, **kwargs): self._save_run_info() self._save_data() if self.format == "zarr": - import zarr - - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + _safe_zarr_consolidate(self.sorting_analyzer._get_zarr_root().store) def save(self): self._save_params() @@ -2659,9 +2677,7 @@ def save(self): self._save_data() if self.format == "zarr": - import zarr - - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + _safe_zarr_consolidate(self.sorting_analyzer._get_zarr_root().store) def _save_data(self): if self.format == "memory": @@ -2760,8 +2776,11 @@ def _reset_extension_folder(self): import zarr zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") - _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) - zarr.consolidate_metadata(zarr_root.store) + # Use init_group directly instead of create_group to avoid the + # read-back that can fail on network/synced filesystems. 
+ ext_path = "/".join([zarr_root["extensions"].path, self.extension_name]) + zarr.storage.init_group(zarr_root.store, path=ext_path, overwrite=True) + _safe_zarr_consolidate(zarr_root.store) def _delete_extension_folder(self): """ @@ -2773,12 +2792,10 @@ def _delete_extension_folder(self): shutil.rmtree(extension_folder) elif self.format == "zarr": - import zarr - zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") if self.extension_name in zarr_root["extensions"]: del zarr_root["extensions"][self.extension_name] - zarr.consolidate_metadata(zarr_root.store) + _safe_zarr_consolidate(zarr_root.store) def delete(self): """ diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 888f5964ca..0cfdba1164 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -11,6 +11,7 @@ from __future__ import annotations import operator +import warnings from pathlib import Path import json import numpy as np @@ -32,11 +33,16 @@ "snr", "amplitude_cutoff", "num_spikes", - "rp_contamination", + "rpv", # maps to rp_contamination or sliding_rp_violation "presence_ratio", "drift_ptp", + "isolation_distance", + "l_ratio", ] +# RPV metric column names (bombcell accepts "rpv" as threshold key and maps to whichever exists) +RPV_METRIC_COLUMNS = ["rp_contamination", "sliding_rp_violation"] + DEFAULT_NON_SOMATIC_METRICS = [ "peak_before_to_trough_ratio", "peak_before_width", @@ -45,6 +51,15 @@ "main_peak_to_trough_ratio", ] +# Metrics belonging to the built-in non-somatic groups. +# The compound logic is: (width_group AND ratio_group) OR main_peak_group +# Any metric in the "non-somatic" threshold section that is NOT listed here +# is treated as a standalone condition OR'd into the final result. 
+_NON_SOMATIC_WIDTH_GROUP = {"peak_before_width", "trough_width"} +_NON_SOMATIC_RATIO_GROUP = {"peak_before_to_trough_ratio", "peak_before_to_peak_after_ratio"} +_NON_SOMATIC_MAIN_PEAK_GROUP = {"main_peak_to_trough_ratio"} +_NON_SOMATIC_BUILTIN_METRICS = _NON_SOMATIC_WIDTH_GROUP | _NON_SOMATIC_RATIO_GROUP | _NON_SOMATIC_MAIN_PEAK_GROUP + def bombcell_get_default_thresholds() -> dict: """ @@ -68,9 +83,11 @@ def bombcell_get_default_thresholds() -> dict: "snr": {"greater": 5, "less": None}, "amplitude_cutoff": {"greater": None, "less": 0.2}, "num_spikes": {"greater": 300, "less": None}, - "rp_contamination": {"greater": None, "less": 0.1}, + "rpv": {"greater": None, "less": 0.1}, # applies to rp_contamination or sliding_rp_violation "presence_ratio": {"greater": 0.7, "less": None}, "drift_ptp": {"greater": None, "less": 100}, # um + "isolation_distance": {"greater": 20, "less": None}, + "l_ratio": {"greater": None, "less": 0.3}, }, "non-somatic": { "peak_before_to_trough_ratio": {"greater": None, "less": 3}, @@ -113,7 +130,11 @@ def bombcell_label_units( - Large main peak to trough ratio (using "main_peak_to_trough_ratio" metric) If units have a narrow peak and a large ratio OR a large main peak to trough ratio, - they are labeled as non-somatic. If `split_non_somatic_good_mua` is True, non-somatic units are further split + they are labeled as non-somatic. Custom metrics can also be added to the "non-somatic" + threshold section — any metric not part of the built-in groups (width, ratio, main_peak) + is treated as a standalone condition OR'd into the non-somatic detection. + + If `split_non_somatic_good_mua` is True, non-somatic units are further split into "non_soma_good" and "non_soma_mua", otherwise they are all labeled as "non_soma". 
Parameters @@ -173,6 +194,31 @@ def bombcell_label_units( else: raise ValueError("thresholds must be a dict, a JSON file path, or None") + # Map "rpv" threshold to actual column name (rp_contamination or sliding_rp_violation) + if "mua" in thresholds_dict and "rpv" in thresholds_dict["mua"]: + rpv_thresh = thresholds_dict["mua"].pop("rpv") + for col in RPV_METRIC_COLUMNS: + if col in combined_metrics.columns: + thresholds_dict["mua"][col] = rpv_thresh + break + + # Filter out threshold metrics that are not present in the metrics DataFrame. + # This allows optional metrics (e.g. isolation_distance, l_ratio) to be included + # in the default thresholds without crashing when they haven't been computed. + available_columns = set(combined_metrics.columns) + for section in ("noise", "mua", "non-somatic"): + if section not in thresholds_dict: + continue + missing = [m for m in thresholds_dict[section] if m not in available_columns] + if missing: + warnings.warn( + f"Bombcell thresholds reference metrics not found in the metrics DataFrame " + f"(section '{section}'): {missing}. These will be skipped. " + f"Compute them first if you want them included in the labeling." + ) + for m in missing: + del thresholds_dict[section][m] + n_units = len(combined_metrics) noise_thresholds = thresholds_dict.get("noise", {}) @@ -265,6 +311,24 @@ def bombcell_label_units( # (ratio AND width) OR standalone main_peak_to_trough is_non_somatic = (ratio_conditions & width_conditions) | large_main_peak + # Standalone custom metrics: any metric in non-somatic thresholds that is not + # part of the built-in groups is OR'd in as its own independent condition. 
+ standalone_metrics = { + m: non_somatic_thresholds[m] + for m in non_somatic_thresholds + if m not in _NON_SOMATIC_BUILTIN_METRICS + } + for metric_name, thresh in standalone_metrics.items(): + standalone_labels = threshold_metrics_label_units( + metrics=combined_metrics, + thresholds={metric_name: thresh}, + pass_label="pass", + fail_label="fail", + operator="and", + nan_policy="ignore", + ) + is_non_somatic = is_non_somatic | (standalone_labels["label"] == "fail") + if split_non_somatic_good_mua: good_mask = unit_labels["label"] == "good" mua_mask = unit_labels["label"] == "mua" diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 3d6911b383..67d32640e4 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -553,11 +553,12 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, + confidence_threshold=0.9, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes contamination by using a sliding refractory period. - This metric computes the minimum contamination with at least 90% confidence. + This metric computes the minimum contamination with at least ``confidence_threshold`` confidence. Parameters ---------- @@ -581,11 +582,14 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). + confidence_threshold : float, default: 0.9 + Confidence threshold (between 0 and 1) for determining the minimum contamination. + A higher value requires stronger statistical evidence. Default is 0.9 (90% confidence). Returns ------- contamination : dict of floats - The minimum contamination at 90% confidence. + The minimum contamination at the specified confidence level. 
References ---------- @@ -632,6 +636,7 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms, max_ref_period_ms, contamination_values, + confidence_threshold=confidence_threshold, ) return contamination @@ -647,10 +652,11 @@ class SlidingRPViolation(BaseMetric): "exclude_ref_period_below_ms": 0.5, "max_ref_period_ms": 10, "contamination_values": None, + "confidence_threshold": 0.9, } metric_columns = {"sliding_rp_violation": float} metric_descriptions = { - "sliding_rp_violation": "Minimum contamination at 90% confidence using sliding refractory period method." + "sliding_rp_violation": "Minimum contamination at specified confidence using sliding refractory period method." } supports_periods = True @@ -1697,6 +1703,7 @@ def slidingRP_violations( max_ref_period_ms=10, contamination_values=None, return_conf_matrix=False, + confidence_threshold=0.9, ): """ A metric developed by IBL which determines whether the refractory period violations @@ -1720,14 +1727,16 @@ def slidingRP_violations( The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100. return_conf_matrix : bool, default: False If True, the confidence matrix (n_contaminations, n_ref_periods) is returned. + confidence_threshold : float, default: 0.9 + Confidence threshold (between 0 and 1). Default is 0.9 (90% confidence). Code adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/master/python/slidingRP/metrics.py#L166 Returns ------- - min_cont_with_90_confidence : dict of floats - The minimum contamination with confidence > 90%. + min_contamination : dict of floats + The minimum contamination with confidence above the specified threshold. 
""" if contamination_values is None: contamination_values = np.arange(0.5, 35, 0.5) / 100 # vector of contamination values to test @@ -1768,17 +1777,17 @@ def slidingRP_violations( test_rp_centers_mask = rp_centers > exclude_ref_period_below_ms / 1000.0 # (in seconds) # only test for refractory period durations greater than 'exclude_ref_period_below_ms' - inds_confidence90 = np.row_stack(np.where(conf_matrix[:, test_rp_centers_mask] > 0.9)) + inds_above_threshold = np.row_stack(np.where(conf_matrix[:, test_rp_centers_mask] > confidence_threshold)) - if len(inds_confidence90[0]) > 0: - minI = np.min(inds_confidence90[0][0]) - min_cont_with_90_confidence = contamination_values[minI] + if len(inds_above_threshold[0]) > 0: + minI = np.min(inds_above_threshold[0][0]) + min_contamination = contamination_values[minI] else: - min_cont_with_90_confidence = np.nan + min_contamination = np.nan if return_conf_matrix: - return min_cont_with_90_confidence, conf_matrix + return min_contamination, conf_matrix else: - return min_cont_with_90_confidence + return min_contamination def _compute_rp_contamination_one_unit( diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index 973800624d..3d84dceadc 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -1,227 +1,258 @@ -from collections import namedtuple -import numpy as np +from __future__ import annotations +import numpy as np +from collections import namedtuple +from scipy.signal import find_peaks, savgol_filter from spikeinterface.core.analyzer_extension_core import BaseMetric -def detect_peaks_on_templates( - template, - extremum_name, - prominence, - start_search_index, - end_search_index, - width=0, -): - """Detect peaks on template. Three attempts are made to find a valid peak: - - 1. Use the specified prominence threshold to detect peaks. 
- If multiple are found, the most prominent is selected as the main extremum. - 2. If no peaks are found at the initial threshold, the threshold is halved and detection is attempted again. - If multiple "peaks" are found at the half threshold, the most prominent is selected as the main extremum. - 3. If still no peaks are found, a last resort method is used: use the global maximum in the search window. - - Parameters - ---------- - template : np.ndarray - The 1D template to analyze for peaks. - extremum_name : str - Name of the extremum type being detected (e.g., "peak_before", "peak_after", "trough") for naming outputs. - prominence : float - Minimum prominence for peak detection. - width : float or int, optional - Required width of peaks in samples. Default is 0 (no width requirement). - start_search_index : int, optional - Sample index to start searching for peaks. Default is 0. - end_search_index : int or None, optional - Sample index to end searching for peaks. If None, search until the end of the template. Default is None. 
- - - Returns - ------- - peaks_info : dict - Dictionary containing information about detected peaks - """ - from scipy.signal import find_peaks - - template_for_search = template[start_search_index:end_search_index] - - peaks_info = {} - if end_search_index > start_search_index and start_search_index > 0: - locs_main, props_main = find_peaks(template_for_search, prominence=prominence, width=width) - locs_main = locs_main + start_search_index - - if len(locs_main) == 0: - lower_prominence = 0.5 * prominence - locs_half, props_half = find_peaks(template_for_search, prominence=lower_prominence, width=width) - locs_half = locs_half + start_search_index - - if len(locs_half) == 1: # Exactly 1 peak found at half threshold, use it - locs = locs_half - props = props_half - elif len(locs_half) > 1: # Multiple peaks found at half threshold, select most prominent - prominences = props_half.get("prominences", np.zeros(len(locs_half))) - best_idx = np.nanargmax(prominences) - locs = np.array([locs_half[best_idx]]) - props = {k: np.array([v[best_idx]]) for k, v in props_half.items()} - else: # No peaks found at half threshold, use global maximum - locs = np.array([start_search_index + np.argmax(template_for_search)], dtype=int) - props = {} - else: - locs = locs_main - props = props_main - - prominences = props.get("prominences") - if prominences is not None and len(prominences) > 0 and not np.all(np.isnan(prominences)): - extremum_idx = np.nanargmax(prominences) - else: - extremum_idx = 0 - - peaks_info[f"{extremum_name}_sample_indices"] = locs_main - peaks_info[f"{extremum_name}_prominences"] = props_main.get("prominences", np.full(len(locs_main), np.nan)) - peaks_info[f"{extremum_name}_widths"] = props_main.get("widths", np.full(len(locs_main), np.nan)) - peaks_info[f"{extremum_name}_index"] = locs[extremum_idx] - peaks_info[f"{extremum_name}_width"] = props.get("widths", np.array([np.nan]))[extremum_idx] - peaks_info[f"{extremum_name}_width_left"] = 
int(np.round(props.get("left_ips", np.array([-1]))[extremum_idx])) - peaks_info[f"{extremum_name}_width_right"] = int(np.round(props.get("right_ips", np.array([-1]))[extremum_idx])) - else: - peaks_info[f"{extremum_name}_sample_indices"] = np.array([], dtype=int) - peaks_info[f"{extremum_name}_prominences"] = np.array([], dtype=float) - peaks_info[f"{extremum_name}_widths"] = np.array([], dtype=float) - peaks_info[f"{extremum_name}_index"] = -1 - peaks_info[f"{extremum_name}_width"] = np.nan - peaks_info[f"{extremum_name}_width_left"] = -1 - peaks_info[f"{extremum_name}_width_right"] = -1 - - return peaks_info - - -def get_trough_and_peak_idx( - template, - sampling_frequency, - min_thresh_detect_peaks_troughs=0.4, - edge_exclusion_ms=0.1, - min_peak_trough_distance_ratio=0.2, - min_extremum_distance_samples=3, -): +def get_trough_and_peak_idx(template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3): """ Detect troughs and peaks in a template waveform and return detailed information about each detected feature. - Trough are defined as "minimum" points (negative peaks) and peaks as "maximum" points (positive peaks). - - The function will detect troughs first (by inverting the template and using find_peaks), then peaks before and after - the main trough. For each detection, three attempts are made: - - 1. Use the specified prominence threshold to detect peaks/troughs. - If multiple are found, the most prominent is selected as the main extremum. - 2. If no peaks/troughs are found at the initial threshold, the threshold is halved and detection is attempted again. - If multiple "peaks" are found at the half threshold, the most prominent is selected as the main extremum. - 3. If still no peaks/troughs are found, a last resort method is used: the global extremum (max for peaks, - min for troughs) in the search window is selected as the main extremum. 
- - Extremum are filtered to ensure a minimum distance from each other and from the edges of the template, to prevent - spurious detections. Parameters ---------- template : numpy.ndarray The 1D template waveform - sampling_frequency : float - The sampling frequency in Hz - min_thresh_detect_peaks_troughs : float, default: 0.3 + min_thresh_detect_peaks_troughs : float, default: 0.4 Minimum prominence threshold as a fraction of the template's absolute max value - edge_exclusion_ms : float, default: 0.1 - Duration in ms to exclude from the start and end of the template - when detecting peaks/troughs. Prevents spurious edge detections. - min_peak_trough_distance_ratio : float, default: 0.2 - Minimum peak-trough distance as a fraction of trough half-width. Used to filter out peaks too close to trough. - min_extremum_distance_samples : int, default: 3 - Minimum distance between consecutive extrema (peaks and troughs) and between extrema and edges in samples. + smooth : bool, default: True + Whether to apply smoothing before peak detection + smooth_window_frac : float, default: 0.1 + Smoothing window length as a fraction of template length (0.05-0.2 recommended) + smooth_polyorder : int, default: 3 + Polynomial order for Savitzky-Golay filter (must be < window_length) Returns ------- - peaks_info : dict - Dictionary containing various information about detected extrema in "peak_before", "trough", "peak_after": - - "{extremum}_sample_indices": array of all extrema sample indices using main prominence threshold - - "{extremum}_prominences": array of all extrema prominences using main prominence threshold - - "{extremum}_widths": array of all extrema widths using main prominence threshold - - "{extremum}_index": sample index of the main extremum in template - - "{extremum}_width": width of the main extremum in samples - - "{extremum}_width_left": sample index of the left intersection point of the main with its prominence level - - "{extremum}_width_right": sample index of 
the right intersection point of the main with its prominence level - - "{extremum}_half_width_left": sample index of the left intersection point of the main with half of amplitude - - "{extremum}_half_width_right": sample index of the right intersection point of the main with half of amplitude + troughs : dict + Dictionary containing: + - "indices": array of all trough indices + - "values": array of all trough values + - "prominences": array of all trough prominences + - "widths": array of all trough widths + - "main_idx": index of the main trough (most prominent) + - "main_loc": location (sample index) of the main trough in template + peaks_before : dict + Dictionary containing peaks detected before the main trough (initial peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template + peaks_after : dict + Dictionary containing peaks detected after the main trough (repolarization peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template """ - from scipy.signal import find_peaks - assert template.ndim == 1 - peaks_info = {} + # Save original for plotting + template_original = template.copy() + + # Smooth template to reduce noise while preserving peaks using Savitzky-Golay filter + if smooth: + window_length = int(len(template) * smooth_window_frac) // 2 * 2 + 1 + window_length = max(smooth_polyorder + 2, window_length) # Must be > polyorder + template = savgol_filter(template, window_length=window_length, 
polyorder=smooth_polyorder) + + # Initialize empty result dictionaries + empty_dict = { + "indices": np.array([], dtype=int), + "values": np.array([]), + "prominences": np.array([]), + "widths": np.array([]), + "main_idx": None, + "main_loc": None, + } + # Get min prominence to detect peaks and troughs relative to template abs max value min_prominence = min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) - # Compute edge exclusion zone in samples - num_samples = len(template) - edge_samples = int(edge_exclusion_ms / 1000 * sampling_frequency) if edge_exclusion_ms > 0 else 0 - # Interior region for edge exclusion - left_edge = edge_samples if edge_samples > 0 else 0 - right_edge = num_samples - edge_samples if edge_samples > 0 else num_samples - - # --- Detect troughs --- - peaks_info = {} - peaks_info_trough = detect_peaks_on_templates( - template=-template, - extremum_name="trough", - start_search_index=left_edge, - end_search_index=right_edge, - prominence=min_prominence, - ) - main_trough_sample_index = peaks_info_trough["trough_index"] - peaks_info.update(peaks_info_trough) - - # Prevents detecting spurious peaks right next to the trough - _, hw_l, hw_r = _compute_halfwidth(-template, main_trough_sample_index, sampling_frequency) - if hw_l >= 0 and hw_r >= 0: - trough_hw_samples = hw_r - hw_l - min_peak_trough_dist = max( - min_extremum_distance_samples, int(min_peak_trough_distance_ratio * trough_hw_samples) - ) + # --- Find troughs (by inverting waveform and using find_peaks) --- + trough_locs, trough_props = find_peaks(-template, prominence=min_prominence, width=0) + + if len(trough_locs) == 0: + # Fallback: use global minimum + trough_locs = np.array([np.nanargmin(template)]) + trough_props = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Determine main trough (most prominent, or first if no valid prominences) + trough_prominences = trough_props.get("prominences", np.array([])) + if len(trough_prominences) > 0 and not 
np.all(np.isnan(trough_prominences)): + main_trough_idx = np.nanargmax(trough_prominences) else: - min_peak_trough_dist = min_extremum_distance_samples + main_trough_idx = 0 + + main_trough_loc = trough_locs[main_trough_idx] + + troughs = { + "indices": trough_locs, + "values": template[trough_locs], + "prominences": trough_props.get("prominences", np.full(len(trough_locs), np.nan)), + "widths": trough_props.get("widths", np.full(len(trough_locs), np.nan)), + "main_idx": main_trough_idx, + "main_loc": main_trough_loc, + } # --- Find peaks before the main trough --- - peaks_info_before = detect_peaks_on_templates( - template=template, - extremum_name="peak_before", - prominence=min_prominence, - start_search_index=left_edge, - end_search_index=main_trough_sample_index - min_peak_trough_dist, - ) - peaks_info.update(peaks_info_before) + if main_trough_loc > 3: + template_before = template[:main_trough_loc] - # --- Find peaks after the main trough (repolarization peaks) --- - peaks_info_after = detect_peaks_on_templates( - template=template, - extremum_name="peak_after", - prominence=min_prominence, - start_search_index=main_trough_sample_index + min_peak_trough_dist, - end_search_index=right_edge, - ) - peaks_info.update(peaks_info_after) + # Try with original prominence + peak_locs_before, peak_props_before = find_peaks( + template_before, prominence=min_prominence, width=0 + ) - # Compute half-widths - for k in ("peak_before", "trough", "peak_after"): - sample_index = peaks_info[f"{k}_index"] - sign = -1 if k == "trough" else 1 - _, l, r = _compute_halfwidth(sign * template, sample_index, sampling_frequency) - peaks_info[f"{k}_half_width_left"] = l - peaks_info[f"{k}_half_width_right"] = r + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_before) == 0: + lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_before, peak_props_before = find_peaks( + template_before, 
prominence=lower_prominence, width=0 + ) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_before) > 1: + prominences = peak_props_before.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_before = np.array([peak_locs_before[max_idx]]) + peak_props_before = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_before.get("widths", np.array([np.nan]))[max_idx]]), + } + + # If still no peaks found, fall back to argmax + if len(peak_locs_before) == 0: + peak_locs_before = np.array([np.nanargmax(template_before)]) + peak_props_before = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + peak_prominences_before = peak_props_before.get("prominences", np.array([])) + if len(peak_prominences_before) > 0 and not np.all(np.isnan(peak_prominences_before)): + main_peak_before_idx = np.nanargmax(peak_prominences_before) + else: + main_peak_before_idx = 0 + + peaks_before = { + "indices": peak_locs_before, + "values": template[peak_locs_before], + "prominences": peak_props_before.get("prominences", np.full(len(peak_locs_before), np.nan)), + "widths": peak_props_before.get("widths", np.full(len(peak_locs_before), np.nan)), + "main_idx": main_peak_before_idx, + "main_loc": peak_locs_before[main_peak_before_idx], + } + else: + peaks_before = empty_dict.copy() - return peaks_info + # --- Find peaks after the main trough (repolarization peaks) --- + if main_trough_loc < len(template) - 3: + template_after = template[main_trough_loc:] + # Try with original prominence + peak_locs_after, peak_props_after = find_peaks( + template_after, prominence=min_prominence, width=0 + ) -def get_main_to_next_extremum_duration(template, peaks_info, sampling_frequency, **kwargs): + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_after) == 0: + lower_prominence = 0.075 * 
min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_after, peak_props_after = find_peaks( + template_after, prominence=lower_prominence, width=0 + ) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_after) > 1: + prominences = peak_props_after.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_after = np.array([peak_locs_after[max_idx]]) + peak_props_after = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_after.get("widths", np.array([np.nan]))[max_idx]]), + } + + # If still no peaks found, fall back to argmax + if len(peak_locs_after) == 0: + peak_locs_after = np.array([np.nanargmax(template_after)]) + peak_props_after = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Convert to original template coordinates + peak_locs_after_abs = peak_locs_after + main_trough_loc + + peak_prominences_after = peak_props_after.get("prominences", np.array([])) + if len(peak_prominences_after) > 0 and not np.all(np.isnan(peak_prominences_after)): + main_peak_after_idx = np.nanargmax(peak_prominences_after) + else: + main_peak_after_idx = 0 + + peaks_after = { + "indices": peak_locs_after_abs, + "values": template[peak_locs_after_abs], + "prominences": peak_props_after.get("prominences", np.full(len(peak_locs_after), np.nan)), + "widths": peak_props_after.get("widths", np.full(len(peak_locs_after), np.nan)), + "main_idx": main_peak_after_idx, + "main_loc": peak_locs_after_abs[main_peak_after_idx], + } + else: + peaks_after = empty_dict.copy() + + # Quick visualization (set to True for debugging) + _plot = False + if _plot: + import matplotlib.pyplot as plt + + # Old simple method for comparison (argmin/argmax) + old_trough_idx = np.nanargmin(template) + old_peak_idx = np.nanargmax(template[old_trough_idx:]) + old_trough_idx + + fig, ax = plt.subplots(figsize=(10, 5)) + 
ax.plot(template_original, color="lightgray", lw=1, label="original (noisy)") + ax.plot(template, "k-", lw=1.5, label="smoothed") + + # Plot old method (simple argmin/argmax) + ax.axvline(old_trough_idx, color="gray", ls="--", alpha=0.5, label="old trough (argmin)") + ax.axvline(old_peak_idx, color="gray", ls=":", alpha=0.5, label="old peak (argmax after trough)") + + # Plot all detected troughs + ax.scatter(troughs["indices"], troughs["values"], c="blue", s=50, marker="v", zorder=5, label="troughs") + if troughs["main_loc"] is not None: + ax.scatter(troughs["main_loc"], template[troughs["main_loc"]], c="blue", s=150, marker="v", + edgecolors="red", linewidths=2, zorder=6, label="main trough") + + # Plot all peaks before + if len(peaks_before["indices"]) > 0: + ax.scatter(peaks_before["indices"], peaks_before["values"], c="green", s=50, marker="^", + zorder=5, label="peaks before") + if peaks_before["main_loc"] is not None: + ax.scatter(peaks_before["main_loc"], template[peaks_before["main_loc"]], c="green", s=150, + marker="^", edgecolors="red", linewidths=2, zorder=6, label="main peak before") + + # Plot all peaks after + if len(peaks_after["indices"]) > 0: + ax.scatter(peaks_after["indices"], peaks_after["values"], c="orange", s=50, marker="^", + zorder=5, label="peaks after") + if peaks_after["main_loc"] is not None: + ax.scatter(peaks_after["main_loc"], template[peaks_after["main_loc"]], c="orange", s=150, + marker="^", edgecolors="red", linewidths=2, zorder=6, label="main peak after") + + ax.axhline(0, color="gray", ls="-", alpha=0.3) + ax.set_xlabel("Sample") + ax.set_ylabel("Amplitude") + ax.legend(loc="best", fontsize=8) + ax.set_title(f"Trough/Peak Detection (prominence threshold: {min_thresh_detect_peaks_troughs})") + plt.tight_layout() + plt.show() + + return troughs, peaks_before, peaks_after + + +def get_waveform_duration(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Calculate duration from the main extremum to the 
next extremum. + Calculate waveform duration from the main extremum to the next extremum. The duration is measured from the largest absolute feature (main trough or main peak) to the next extremum. For typical negative-first waveforms, this is trough-to-peak. @@ -231,72 +262,76 @@ def get_main_to_next_extremum_duration(template, peaks_info, sampling_frequency, ---------- template : numpy.ndarray The 1D template waveform - peaks_info : dict - Peaks and troughs detection results from get_trough_and_peak_idx sampling_frequency : float The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- - main_to_next_extremum_duration : float - Duration in seconds from main extremum to next extremum + waveform_duration_us : float + Waveform duration in microseconds """ # Get main locations and values - trough_index = peaks_info["trough_index"] - trough_val = template[trough_index] if trough_index is not None else None + trough_loc = troughs["main_loc"] + trough_val = template[trough_loc] if trough_loc is not None else None - peak_before_index = peaks_info["peak_before_index"] - peak_before_val = template[peak_before_index] if peak_before_index is not None else None + peak_before_loc = peaks_before["main_loc"] + peak_before_val = template[peak_before_loc] if peak_before_loc is not None else None - peak_after_index = peaks_info["peak_after_index"] - peak_after_val = template[peak_after_index] if peak_after_index is not None else None + peak_after_loc = peaks_after["main_loc"] + peak_after_val = template[peak_after_loc] if peak_after_loc is not None else None # Find the main extremum (largest absolute value) candidates = [] - if trough_index is not None and trough_val is not None: - candidates.append(("trough", trough_index, abs(trough_val))) - if peak_before_index is 
not None and peak_before_val is not None: - candidates.append(("peak_before", peak_before_index, abs(peak_before_val))) - if peak_after_index is not None and peak_after_val is not None: - candidates.append(("peak_after", peak_after_index, abs(peak_after_val))) + if trough_loc is not None and trough_val is not None: + candidates.append(("trough", trough_loc, abs(trough_val))) + if peak_before_loc is not None and peak_before_val is not None: + candidates.append(("peak_before", peak_before_loc, abs(peak_before_val))) + if peak_after_loc is not None and peak_after_val is not None: + candidates.append(("peak_after", peak_after_loc, abs(peak_after_val))) if len(candidates) == 0: return np.nan # Sort by absolute value to find main extremum candidates.sort(key=lambda x: x[2], reverse=True) - main_type, main_index, _ = candidates[0] + main_type, main_loc, _ = candidates[0] # Find the next extremum after the main one if main_type == "trough": # Main is trough, next is peak_after - if peak_after_index is not None: - duration_samples = abs(peak_after_index - main_index) - elif peak_before_index is not None: - duration_samples = abs(main_index - peak_before_index) + if peak_after_loc is not None: + duration_samples = abs(peak_after_loc - main_loc) + elif peak_before_loc is not None: + duration_samples = abs(main_loc - peak_before_loc) else: return np.nan elif main_type == "peak_before": # Main is peak before, next is trough - if trough_index is not None: - duration_samples = abs(trough_index - main_index) + if trough_loc is not None: + duration_samples = abs(trough_loc - main_loc) else: return np.nan else: # peak_after # Main is peak after, previous is trough - if trough_index is not None: - duration_samples = abs(main_index - trough_index) + if trough_loc is not None: + duration_samples = abs(main_loc - trough_loc) else: return np.nan - # Convert to seconds - main_to_next_extremum_duration = duration_samples / sampling_frequency + # Convert to microseconds + 
waveform_duration_us = (duration_samples / sampling_frequency) * 1e6 - return main_to_next_extremum_duration + return waveform_duration_us -def get_waveform_ratios(template, peaks_info, **kwargs): +def get_waveform_ratios(template, troughs, peaks_before, peaks_after, **kwargs): """ Calculate various waveform amplitude ratios. @@ -304,8 +339,12 @@ def get_waveform_ratios(template, peaks_info, **kwargs): ---------- template : numpy.ndarray The 1D template waveform - peaks_info : dict - Peaks and troughs detection results from get_trough_and_peak_idx + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- @@ -317,13 +356,9 @@ def get_waveform_ratios(template, peaks_info, **kwargs): - "main_peak_to_trough_ratio": ratio of larger peak to trough amplitude """ # Get absolute amplitudes - trough_amp = abs(template[peaks_info["trough_index"]]) if peaks_info["trough_index"] is not None else np.nan - peak_before_amp = ( - abs(template[peaks_info["peak_before_index"]]) if peaks_info["peak_before_index"] is not None else np.nan - ) - peak_after_amp = ( - abs(template[peaks_info["peak_after_index"]]) if peaks_info["peak_after_index"] is not None else np.nan - ) + trough_amp = abs(template[troughs["main_loc"]]) if troughs["main_loc"] is not None else np.nan + peak_before_amp = abs(template[peaks_before["main_loc"]]) if peaks_before["main_loc"] is not None else np.nan + peak_after_amp = abs(template[peaks_after["main_loc"]]) if peaks_after["main_loc"] is not None else np.nan def safe_ratio(a, b): if np.isnan(a) or np.isnan(b) or b == 0: @@ -334,14 +369,7 @@ def safe_ratio(a, b): "peak_before_to_trough_ratio": safe_ratio(peak_before_amp, trough_amp), "peak_after_to_trough_ratio": safe_ratio(peak_after_amp, trough_amp), "peak_before_to_peak_after_ratio": safe_ratio(peak_before_amp, 
peak_after_amp), - "main_peak_to_trough_ratio": safe_ratio( - ( - max(peak_before_amp, peak_after_amp) - if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) - else np.nan - ), - trough_amp, - ), + "main_peak_to_trough_ratio": safe_ratio(max(peak_before_amp, peak_after_amp) if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) else np.nan, trough_amp), } return ratios @@ -351,11 +379,9 @@ def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs): """ Compute the baseline flatness of the waveform. - This metric measures the max deviation of the baseline from its own mean, - relative to the max deviation of the whole waveform from the baseline mean. - A lower value indicates a flat baseline (expected for good units). - Referenced to baseline mean so it works with both zero-centered and - DC-offset data. + This metric measures the ratio of the max absolute amplitude in the baseline + window to the max absolute amplitude of the whole waveform. A lower value + indicates a flat baseline (expected for good units). Parameters ---------- @@ -370,8 +396,7 @@ def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs): Returns ------- baseline_flatness : float - Ratio of max(abs(baseline - baseline_mean)) / max(abs(template - baseline_mean)). - Lower = flatter baseline. + Ratio of max(abs(baseline)) / max(abs(waveform)). Lower = flatter baseline. 
""" baseline_window_ms = kwargs.get("baseline_window_ms", (0.0, 0.5)) @@ -394,49 +419,64 @@ def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs): if len(baseline_segment) == 0: return np.nan - baseline_mean = np.nanmean(baseline_segment) - max_baseline_dev = np.nanmax(np.abs(baseline_segment - baseline_mean)) - max_waveform_dev = np.nanmax(np.abs(template - baseline_mean)) + max_baseline = np.nanmax(np.abs(baseline_segment)) + max_waveform = np.nanmax(np.abs(template)) - if max_waveform_dev == 0 or np.isnan(max_waveform_dev): + if max_waveform == 0 or np.isnan(max_waveform): return np.nan - baseline_flatness = max_baseline_dev / max_waveform_dev + baseline_flatness = max_baseline / max_waveform return baseline_flatness -def get_waveform_widths(peaks_info, sampling_frequency, **kwargs): +def get_waveform_widths(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Get the widths of the main trough and peaks in seconds. + Get the widths of the main trough and peaks in microseconds. 
Parameters ---------- - peaks_info : dict - Peaks and troughs detection results from get_trough_and_peak_idx + template : numpy.ndarray + The 1D template waveform sampling_frequency : float The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- widths : dict Dictionary containing: - - "trough_width": width of main trough in seconds - - "peak_before_width": width of main peak before trough in seconds - - "peak_after_width": width of main peak after trough in seconds + - "trough_width_us": width of main trough in microseconds + - "peak_before_width_us": width of main peak before trough in microseconds + - "peak_after_width_us": width of main peak after trough in microseconds """ + def get_main_width(feature_dict): + if feature_dict["main_idx"] is None: + return np.nan + widths = feature_dict.get("widths", np.array([])) + if len(widths) == 0: + return np.nan + main_idx = feature_dict["main_idx"] + if main_idx < len(widths): + return widths[main_idx] + return np.nan - # Convert from samples to seconds - samples_to_seconds = 1.0 / sampling_frequency + # Convert from samples to microseconds + samples_to_us = 1e6 / sampling_frequency - trough_width = peaks_info["trough_width"] - peak_before_width = peaks_info["peak_before_width"] - peak_after_width = peaks_info["peak_after_width"] + trough_width = get_main_width(troughs) + peak_before_width = get_main_width(peaks_before) + peak_after_width = get_main_width(peaks_after) widths = { - "trough_width": trough_width * samples_to_seconds if not np.isnan(trough_width) else np.nan, - "peak_before_width": peak_before_width * samples_to_seconds if not np.isnan(peak_before_width) else np.nan, - "peak_after_width": peak_after_width * samples_to_seconds if not np.isnan(peak_after_width) else np.nan, + 
"trough_width_us": trough_width * samples_to_us if not np.isnan(trough_width) else np.nan, + "peak_before_width_us": peak_before_width * samples_to_us if not np.isnan(peak_before_width) else np.nan, + "peak_after_width_us": peak_after_width * samples_to_us if not np.isnan(peak_after_width) else np.nan, } return widths @@ -444,180 +484,181 @@ def get_waveform_widths(peaks_info, sampling_frequency, **kwargs): ######################################################################################### # Single-channel metrics -def get_peak_to_trough_duration(peaks_info, sampling_frequency, **kwargs) -> float: +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: """ - Return the duration in seconds between the main trough and the main peak after the trough of input waveforms. - - The function assumes that the trough comes before the peak. + Return the peak to valley duration in seconds of input waveforms. Parameters ---------- - peaks_info : dict - Peaks and troughs detection results from get_trough_and_peak_idx + template_single: numpy.ndarray + The 1D template waveform sampling_frequency : float The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak Returns ------- - pt_duration: float - The duration in seconds between the main trough and the main peak after the trough + ptv: float + The peak to valley duration in seconds """ - if peaks_info["trough_index"] is None or peaks_info["peak_after_index"] is None: + if trough_idx is None or peak_idx is None: + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] + if trough_idx is None or peak_idx is None: return np.nan - pt_duration = (peaks_info["peak_after_index"] - peaks_info["trough_index"]) / sampling_frequency - return pt_duration + ptv = (peak_idx - trough_idx) / 
sampling_frequency + return ptv -def _compute_halfwidth(template, extremum_index, sampling_frequency): - """Compute the halfwidth of a positive peak. +def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the peak to trough ratio of input waveforms. Parameters ---------- - template : numpy.ndarray + template_single: numpy.ndarray The 1D template waveform - extremum_index: int - The index of the extremum sampling_frequency : float The sampling frequency of the template + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak Returns ------- - hw: float - The half width in seconds + ptratio: float + The peak to trough ratio """ - extremum_val = template[extremum_index] - # threshold is half of peak height (assuming baseline is 0) - threshold = 0.5 * extremum_val - - # Find where the template crosses the threshold before and after the extremum - threshold_crossings = np.where(np.diff(template >= threshold))[0] - crossings_before_extremum = threshold_crossings[threshold_crossings < extremum_index] - crossings_after_extremum = threshold_crossings[threshold_crossings >= extremum_index] - - if len(crossings_before_extremum) == 0 or len(crossings_after_extremum) == 0: - return np.nan, -1, -1 - else: - last_crossing_before_extremum = crossings_before_extremum[-1] - first_crossing_after_extremum = crossings_after_extremum[0] - - hw = (first_crossing_after_extremum - last_crossing_before_extremum) / sampling_frequency - - return hw, last_crossing_before_extremum, first_crossing_after_extremum + if trough_idx is None or peak_idx is None: + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] + if trough_idx is None or peak_idx is None: + return np.nan + ptratio = template_single[peak_idx] / template_single[trough_idx] + return ptratio -def 
get_half_widths(main_channel_template, sampling_frequency, peaks_info, **kwargs) -> tuple[float, float]: +def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: """ - Return the half width of the main trough and main peak in seconds. + Return the half width of input waveforms in seconds. Parameters ---------- - main_channel_template: numpy.ndarray + template_single: numpy.ndarray The 1D template waveform sampling_frequency : float The sampling frequency of the template - peaks_info : dict - Peaks and troughs detection results from get_trough_and_peak_idx + trough_idx: int, default: None + The index of the trough + peak_idx: int, default: None + The index of the peak Returns ------- - hw: tuple[float, float] - The half width in seconds of (trough, peak) + hw: float + The half width in seconds """ - # Compute the trough half width - if peaks_info["trough_index"] is None: - # Edge case: template is flat - trough_hw = np.nan - else: - # for the trough, we invert the waveform to compute halfwidth as for a peak - trough_hw, _, _ = _compute_halfwidth(-main_channel_template, peaks_info["trough_index"], sampling_frequency) - # Compute the peak half width - if peaks_info["peak_after_index"] is None and peaks_info["peak_before_index"] is None: - # Edge case: template is flat - peak_hw = np.nan + if trough_idx is None or peak_idx is None: + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] + + if peak_idx is None or peak_idx == 0: + return np.nan + + trough_val = template_single[trough_idx] + # threshold is half of peak height (assuming baseline is 0) + threshold = 0.5 * trough_val + + (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) + (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) + + if len(cpre_idx) == 0 or len(cpost_idx) == 0: + hw = np.nan + else: - # find largest peak - if 
peaks_info["peak_after_index"] is not None and peaks_info["peak_before_index"] is not None: - peak_after_val = main_channel_template[peaks_info["peak_after_index"]] - peak_before_val = main_channel_template[peaks_info["peak_before_index"]] - if peak_after_val >= peak_before_val: - peak_index = peaks_info["peak_after_index"] - else: - peak_index = peaks_info["peak_before_index"] - elif peaks_info["peak_after_index"] is not None: - peak_index = peaks_info["peak_after_index"] - else: - peak_index = peaks_info["peak_before_index"] - peak_hw, _, _ = _compute_halfwidth(main_channel_template, peak_index, sampling_frequency) + # last occurence of template lower than thr, before peak + cross_pre_pk = cpre_idx[0] - 1 + # first occurence of template lower than peak, after peak + cross_post_pk = cpost_idx[-1] + 1 + trough_idx - return trough_hw, peak_hw + hw = (cross_post_pk - cross_pre_pk) / sampling_frequency + return hw -def get_repolarization_slope(main_channel_template, sampling_frequency, peaks_info, **kwargs): +def get_repolarization_slope(template_single, sampling_frequency, trough_idx=None, **kwargs): """ - Return slope of repolarization period between trough and baseline. + Return slope of repolarization period between trough and baseline After reaching it's maximum polarization, the neuron potential will recover. The repolarization slope is defined as the dV/dT of the action potential between trough and baseline. The returned slope is in units of (unit of template) - per second. By default traces are scaled to units of µV, controlled + per second. By default traces are scaled to units of uV, controlled by `sorting_analyzer.return_in_uV`. In this case this function returns the slope - in µV/s. + in uV/s. 
Parameters ---------- - main_channel_template: numpy.ndarray + template_single: numpy.ndarray The 1D template waveform sampling_frequency : float The sampling frequency of the template - peaks_info: dict - Peaks and troughs detection results from get_trough_and_peak_idx + trough_idx: int, default: None + The index of the trough Returns ------- slope: float The repolarization slope """ - import scipy.stats + if trough_idx is None: + troughs, _, _ = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] - times = np.arange(main_channel_template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency - trough_index = peaks_info["trough_index"] - if trough_index is None or trough_index == 0: + if trough_idx is None or trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(main_channel_template[trough_index:] >= 0) + (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) if len(rtrn_idx) == 0: return np.nan # first time after trough, where template is at baseline - return_to_base_idx = rtrn_idx[0] + trough_index - if return_to_base_idx - trough_index < 3: + return_to_base_idx = rtrn_idx[0] + trough_idx + + if return_to_base_idx - trough_idx < 3: return np.nan - res = scipy.stats.linregress( - times[trough_index:return_to_base_idx], main_channel_template[trough_index:return_to_base_idx] - ) + import scipy.stats + + res = scipy.stats.linregress(times[trough_idx:return_to_base_idx], template_single[trough_idx:return_to_base_idx]) return res.slope -def get_recovery_slope(main_channel_template, sampling_frequency, peaks_info, **kwargs): +def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwargs): """ - Return the recovery slope between the main peak after the trough and baseline. - - After repolarization, the neuron hyperpolarizes until it peaks. The recovery slope is the + Return the recovery slope of input waveforms. 
After repolarization, + the neuron hyperpolarizes until it peaks. The recovery slope is the slope of the action potential after the peak, returning to the baseline in dV/dT. The returned slope is in units of (unit of template) - per second. By default traces are scaled to units of µV, controlled + per second. By default traces are scaled to units of uV, controlled by `sorting_analyzer.return_in_uV`. In this case this function returns the slope - in µV/s. The slope is computed within a user-defined window after the peak. + in uV/s. The slope is computed within a user-defined window after the peak. Parameters ---------- - main_channel_template: numpy.ndarray + template_single: numpy.ndarray The 1D template waveform sampling_frequency : float The sampling frequency of the template - peaks_info: dict - The index of the peak after the trough + peak_idx: int, default: None + The index of the peak **kwargs: Required kwargs: - recovery_window_ms: the window in ms after the peak to compute the recovery_slope @@ -630,31 +671,40 @@ def get_recovery_slope(main_channel_template, sampling_frequency, peaks_info, ** assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" recovery_window_ms = kwargs["recovery_window_ms"] + if peak_idx is None: + _, _, peaks_after = get_trough_and_peak_idx(template_single) + peak_idx = peaks_after["main_loc"] - times = np.arange(main_channel_template.shape[0]) / sampling_frequency + times = np.arange(template_single.shape[0]) / sampling_frequency - peak_after_trough_index = peaks_info["peak_after_index"] - if peak_after_trough_index is None or peak_after_trough_index == 0: + if peak_idx is None or peak_idx == 0: return np.nan - max_idx = int(peak_after_trough_index + ((recovery_window_ms / 1000) * sampling_frequency)) - max_idx = np.min([max_idx, main_channel_template.shape[0]]) + max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) + max_idx = np.min([max_idx, template_single.shape[0]]) - res = 
scipy.stats.linregress( - times[peak_after_trough_index:max_idx], main_channel_template[peak_after_trough_index:max_idx] - ) + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope -def get_number_of_peaks(peaks_info, **kwargs): +def get_number_of_peaks(template_single, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ Count the total number of peaks (positive) and troughs (negative) in the template. - Uses the pre-computed peak/trough detection from get_trough_and_peak_idx. + Uses the pre-computed peak/trough detection from get_trough_and_peak_idx which + applies smoothing for more robust detection. Parameters ---------- - peaks_info: dict - Peaks and troughs detection results from get_trough_and_peak_idx + template_single: numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency of the template + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- @@ -664,12 +714,13 @@ def get_number_of_peaks(peaks_info, **kwargs): The number of negative peaks (troughs) """ # Count peaks (positive) from peaks_before and peaks_after - num_peaks_before = len(peaks_info["peak_before_sample_indices"]) - num_peaks_after = len(peaks_info["peak_after_sample_indices"]) + num_peaks_before = len(peaks_before["indices"]) + num_peaks_after = len(peaks_after["indices"]) num_positive = num_peaks_before + num_peaks_after # Count troughs (negative) - num_negative = len(peaks_info["trough_sample_indices"]) + num_negative = len(troughs["indices"]) + return num_positive, num_negative @@ -700,32 +751,26 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y" return template[:, sort_indices], channel_locations[sort_indices, :] -def fit_line_robust(x, y, eps=1e-12): +def 
fit_velocity(peak_times, channel_dist): """ - Fit line using robust Theil-Sen estimator (median of pairwise slopes). + Fit velocity from peak times and channel distances using robust Theilsen estimator. """ - import itertools - - # Calculate slope and bias using Theil-Sen estimator - slopes = [] - for (x0, y0), (x1, y1) in itertools.combinations(zip(x, y), 2): - if np.abs(x1 - x0) > eps: - slopes.append((y1 - y0) / (x1 - x0)) - if len(slopes) == 0: # all x are identical - return np.nan, -np.inf - slope = np.median(slopes) - bias = np.median(y - slope * x) + # from scipy.stats import linregress + # slope, intercept, _, _, _ = linregress(peak_times, channel_dist) - # Calculate R2 score - y_pred = slope * x + bias - r2_score = 1 - ((y - y_pred) ** 2).sum() / (((y - y.mean()) ** 2).sum() + eps) + from sklearn.linear_model import TheilSenRegressor - return slope, r2_score + theil = TheilSenRegressor() + theil.fit(peak_times.reshape(-1, 1), channel_dist) + slope = theil.coef_[0] + intercept = theil.intercept_ + score = theil.score(peak_times.reshape(-1, 1), channel_dist) + return slope, intercept, score def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs): """ - Compute both velocity above and below the max channel of the template in units µm/ms. + Compute both velocity above and below the max channel of the template in units um/ms. 
Parameters ---------- @@ -739,7 +784,7 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - min_channels: the minimum number of channels above or below to compute velocity - min_r2: the minimum r2 to accept the velocity fit - - column_range: the range in µm in the x-direction to consider channels for velocity + - column_range: the range in um in the x-direction to consider channels for velocity Returns ------- @@ -776,10 +821,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_above = channel_locations[channels_above] peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above]) - inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above) - if score > min_r2 and inv_velocity_above != 0: - velocity_above = 1 / inv_velocity_above - else: + velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above) + if score < min_r2: velocity_above = np.nan # Compute velocity below @@ -791,10 +834,8 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) channel_locations_below = channel_locations[channels_below] peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below]) - inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below) - if score > min_r2 and inv_velocity_below != 0: - velocity_below = 1 / inv_velocity_below - else: + velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below) + if score < min_r2: velocity_below = np.nan return velocity_above, velocity_below @@ -802,11 +843,7 @@ def 
get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ - Compute the spatial decay of the template amplitude over distance. - - Can fit either an exponential decay (with offset) or a linear decay model. Channels are first - filtered by x-distance tolerance from the max channel, then the closest channels - in y-distance are used for fitting. + Compute the exponential decay of the template amplitude over distance in units um/s. Parameters ---------- @@ -817,18 +854,13 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs sampling_frequency : float The sampling frequency of the template **kwargs: Required kwargs: - - peak_function: the function to use to compute the peak amplitude ("ptp" or "min") - - min_r2: the minimum r2 to accept the fit - - linear_fit: bool, if True use linear fit, otherwise exponential fit - - channel_tolerance: max x-distance (um) from max channel to include channels - - min_channels_for_fit: minimum number of valid channels required for fitting - - num_channels_for_fit: number of closest channels to use for fitting - - normalize_decay: bool, if True normalize amplitudes to max before fitting + - peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") + - min_r2: the minimum r2 to accept the exp decay fit Returns ------- exp_decay_value : float - The spatial decay slope (decay constant for exp fit, negative slope for linear fit) + The exponential decay of the template amplitude """ from scipy.optimize import curve_fit from sklearn.metrics import r2_score @@ -836,117 +868,41 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs def exp_decay(x, decay, amp0, offset): return amp0 * np.exp(-decay * x) + offset - def linear_fit_func(x, a, b): - return a * x + b - - # Extract parameters assert "peak_function" in kwargs, "peak_function must 
be given as kwarg" peak_function = kwargs["peak_function"] assert "min_r2" in kwargs, "min_r2 must be given as kwarg" min_r2 = kwargs["min_r2"] - - use_linear_fit = kwargs.get("linear_fit", False) - channel_tolerance = kwargs.get("channel_tolerance", None) - normalize_decay = kwargs.get("normalize_decay", False) - - # Set defaults based on fit type if not specified - min_channels_for_fit = kwargs.get("min_channels_for_fit") - if min_channels_for_fit is None: - min_channels_for_fit = 5 if use_linear_fit else 8 - - num_channels_for_fit = kwargs.get("num_channels_for_fit") - if num_channels_for_fit is None: - num_channels_for_fit = 6 if use_linear_fit else 10 - - # Compute peak amplitudes per channel + # exp decay fit if peak_function == "ptp": fun = np.ptp elif peak_function == "min": fun = np.min - else: - fun = np.ptp - peak_amplitudes = np.abs(fun(template, axis=0)) - max_channel_idx = np.argmax(peak_amplitudes) - max_channel_location = channel_locations[max_channel_idx] - - # Channel selection based on tolerance (new bombcell-style) or use all channels (old style) - if channel_tolerance is not None: - # Calculate x-distances from max channel - x_dist = np.abs(channel_locations[:, 0] - max_channel_location[0]) - - # Find channels within x-distance tolerance - valid_x_channels = np.argwhere(x_dist <= channel_tolerance).flatten() - - if len(valid_x_channels) < min_channels_for_fit: - return np.nan - - # Calculate y-distances for channel selection - y_dist = np.abs(channel_locations[:, 1] - max_channel_location[1]) - - # Set y distances to max for channels outside x tolerance (so they won't be selected) - y_dist_masked = y_dist.copy() - y_dist_masked[~np.isin(np.arange(len(y_dist)), valid_x_channels)] = y_dist.max() + 1 - - # Select the closest channels in y-distance - use_these_channels = np.argsort(y_dist_masked)[:num_channels_for_fit] - - # Calculate distances from max channel for selected channels - channel_distances = np.sqrt( - 
np.sum(np.square(channel_locations[use_these_channels] - max_channel_location), axis=1) - ) + max_channel_location = channel_locations[np.argmax(peak_amplitudes)] + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) - # Get amplitudes for selected channels - spatial_decay_points = np.max(np.abs(template[:, use_these_channels]), axis=0) - - # Sort by distance - sort_idx = np.argsort(channel_distances) - channel_distances_sorted = channel_distances[sort_idx] - peak_amplitudes_sorted = spatial_decay_points[sort_idx] - - # Normalize if requested - if normalize_decay: - peak_amplitudes_sorted = peak_amplitudes_sorted / np.max(peak_amplitudes_sorted) - - # Ensure float64 for numerical stability - channel_distances_sorted = np.float64(channel_distances_sorted) - peak_amplitudes_sorted = np.float64(peak_amplitudes_sorted) - - else: - # Old style: use all channels sorted by distance - channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) - distances_sort_indices = np.argsort(channel_distances) - - # longdouble is float128 when the platform supports it, otherwise it is float64 - channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + # longdouble is float128 when the platform supports it, otherwise it is float64 + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) try: - if use_linear_fit: - # Linear fit: y = a*x + b - popt, _ = curve_fit(linear_fit_func, channel_distances_sorted, peak_amplitudes_sorted) - predicted = linear_fit_func(channel_distances_sorted, *popt) - r2 = r2_score(peak_amplitudes_sorted, predicted) - exp_decay_value = -popt[0] # Negative of slope - else: 
- # Exponential fit with offset: y = amp0 * exp(-decay * x) + offset - amp0 = peak_amplitudes_sorted[0] - offset0 = np.min(peak_amplitudes_sorted) - - popt, _ = curve_fit( - exp_decay, - channel_distances_sorted, - peak_amplitudes_sorted, - bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), - p0=[1e-3, peak_amplitudes_sorted[0], offset0], - ) - r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) - exp_decay_value = popt[0] + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] if r2 < min_r2: exp_decay_value = np.nan - - except Exception: + except: exp_decay_value = np.nan return exp_decay_value @@ -954,7 +910,7 @@ def linear_fit_func(x, a, b): def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> float: """ - Compute the spread of the template amplitude over distance in units µm/s. + Compute the spread of the template amplitude over distance in units um/s. 
Parameters ---------- @@ -967,7 +923,7 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> flo **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - spread_threshold: the threshold to compute the spread - - column_range: the range in µm in the x-direction to consider channels for velocity + - column_range: the range in um in the x-direction to consider channels for velocity Returns ------- @@ -1006,51 +962,85 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> flo return spread -class PeakToTroughDuration(BaseMetric): - metric_name = "peak_to_trough_duration" +def single_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + troughs = tmp_data.get("troughs", None) + peaks = tmp_data.get("peaks", None) + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + trough_idx = troughs[unit_id] if troughs is not None else None + peak_idx = peaks[unit_id] if peaks is not None else None + metric_params["trough_idx"] = trough_idx + metric_params["peak_idx"] = peak_idx + value = unit_function(template_single, sampling_frequency, **metric_params) + result[unit_id] = value + return result + + +class PeakToValley(BaseMetric): + metric_name = "peak_to_valley" + metric_params = {} + metric_columns = {"peak_to_valley": float} + metric_descriptions = { + "peak_to_valley": "Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform." 
+ } + needs_tmp_data = True + + @staticmethod + def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_peak_to_valley, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) + + metric_function = _peak_to_valley_metric_function + + +class PeakToTroughRatio(BaseMetric): + metric_name = "peak_trough_ratio" metric_params = {} - metric_columns = {"peak_to_trough_duration": float} + metric_columns = {"peak_trough_ratio": float} metric_descriptions = { - "peak_to_trough_duration": "Duration in seconds between the trough (minimum) and the next peak (maximum) of the template." + "peak_trough_ratio": "Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform." } needs_tmp_data = True - deprecated_names = ["peak_to_valley"] @staticmethod - def _peak_to_trough_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - pt_durations = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - peaks_info = tmp_data["peaks_info"][unit_index] - pt_durations[unit_id] = get_peak_to_trough_duration(peaks_info, sampling_frequency, **metric_params) - return pt_durations + def _peak_to_trough_ratio_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + return single_channel_metric( + unit_function=get_peak_trough_ratio, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) - metric_function = _peak_to_trough_duration_metric_function + metric_function = _peak_to_trough_ratio_metric_function class HalfWidth(BaseMetric): metric_name = "half_width" metric_params = {} - metric_columns = {"trough_half_width": float, "peak_half_width": float} + metric_columns = {"half_width": float} metric_descriptions = { - "trough_half_width": "Duration in s at half the amplitude of the trough (minimum) 
of the template.", - "peak_half_width": "Duration in s at half the amplitude of the peak (maximum) of the template.", + "half_width": "Duration in s at half the amplitude of the trough (minimum) of the spike waveform." } needs_tmp_data = True @staticmethod def _half_width_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - half_width_result = namedtuple("HalfWidthResult", ["trough_half_width", "peak_half_width"]) - trough_half_widths = {} - peak_half_widths = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - peaks_info = tmp_data["peaks_info"][unit_index] - main_channel_template = tmp_data["main_channel_templates"][unit_index] - trough_half_widths[unit_id], peak_half_widths[unit_id] = get_half_widths( - main_channel_template, sampling_frequency, peaks_info, **metric_params - ) - return half_width_result(trough_half_width=trough_half_widths, peak_half_width=peak_half_widths) + return single_channel_metric( + unit_function=get_half_width, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) metric_function = _half_width_metric_function @@ -1060,21 +1050,19 @@ class RepolarizationSlope(BaseMetric): metric_params = {} metric_columns = {"repolarization_slope": float} metric_descriptions = { - "repolarization_slope": "Slope of the repolarization phase of the template, between the trough (minimum) and return to baseline in µV/s." + "repolarization_slope": "Slope of the repolarization phase of the spike waveform, between the trough (minimum) and return to baseline in uV/s." 
} needs_tmp_data = True @staticmethod def _repolarization_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - repolarization_slopes = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - main_channel_template = tmp_data["main_channel_templates"][unit_index] - peaks_info = tmp_data["peaks_info"][unit_index] - repolarization_slopes[unit_id] = get_repolarization_slope( - main_channel_template, sampling_frequency, peaks_info, **metric_params - ) - return repolarization_slopes + return single_channel_metric( + unit_function=get_repolarization_slope, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) metric_function = _repolarization_slope_metric_function @@ -1084,27 +1072,47 @@ class RecoverySlope(BaseMetric): metric_params = {"recovery_window_ms": 0.7} metric_columns = {"recovery_slope": float} metric_descriptions = { - "recovery_slope": "Slope of the recovery phase of the template, after the peak (maximum) returning to baseline in µV/s." + "recovery_slope": "Slope of the recovery phase of the spike waveform, after the peak (maximum) returning to baseline in uV/s." 
} needs_tmp_data = True @staticmethod def _recovery_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - recovery_slopes = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - main_channel_template = tmp_data["main_channel_templates"][unit_index] - peaks_info = tmp_data["peaks_info"][unit_index] - recovery_slopes[unit_id] = get_recovery_slope( - main_channel_template, sampling_frequency, peaks_info, **metric_params - ) - return recovery_slopes + return single_channel_metric( + unit_function=get_recovery_slope, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) metric_function = _recovery_slope_metric_function +def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"]) + num_positive_peaks_dict = {} + num_negative_peaks_dict = {} + sampling_frequency = tmp_data["sampling_frequency"] + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + num_positive, num_negative = get_number_of_peaks( + template_single, sampling_frequency, + troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], + **metric_params + ) + num_positive_peaks_dict[unit_id] = num_positive + num_negative_peaks_dict[unit_id] = num_negative + return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict) + + class NumberOfPeaks(BaseMetric): metric_name = "number_of_peaks" + metric_function = _number_of_peaks_metric_function metric_params = {} metric_columns = {"num_positive_peaks": int, "num_negative_peaks": int} metric_descriptions = { 
@@ -1113,52 +1121,71 @@ class NumberOfPeaks(BaseMetric): } needs_tmp_data = True - @staticmethod - def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"]) - num_positive_peaks_dict = {} - num_negative_peaks_dict = {} - for unit_index, unit_id in enumerate(unit_ids): - peaks_info = tmp_data["peaks_info"][unit_index] - num_positive, num_negative = get_number_of_peaks( - peaks_info, - **metric_params, - ) - num_positive_peaks_dict[unit_id] = num_positive - num_negative_peaks_dict[unit_id] = num_negative - return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict) - metric_function = _number_of_peaks_metric_function +def _waveform_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + value = get_waveform_duration( + template_single, sampling_frequency, + troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], + **metric_params + ) + result[unit_id] = value + return result -class MainToNextExtremumDuration(BaseMetric): - metric_name = "main_to_next_extremum_duration" +class WaveformDuration(BaseMetric): + metric_name = "waveform_duration" + metric_function = _waveform_duration_metric_function metric_params = {} - metric_columns = {"main_to_next_extremum_duration": float} - metric_descriptions = {"main_to_next_extremum_duration": "Duration in seconds from main extremum to next extremum."} + metric_columns = {"waveform_duration": float} + metric_descriptions = { + 
"waveform_duration": "Waveform duration in microseconds from main extremum to next extremum." + } needs_tmp_data = True - @staticmethod - def _main_to_next_extremum_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - result = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - main_channel_template = tmp_data["main_channel_templates"][unit_index] - peaks_info = tmp_data["peaks_info"][unit_index] - value = get_main_to_next_extremum_duration( - main_channel_template, - peaks_info, - sampling_frequency, - **metric_params, - ) - result[unit_id] = value - return result - metric_function = _main_to_next_extremum_duration_metric_function +def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + waveform_ratios_result = namedtuple("WaveformRatiosResult", [ + "peak_before_to_trough_ratio", "peak_after_to_trough_ratio", + "peak_before_to_peak_after_ratio", "main_peak_to_trough_ratio" + ]) + peak_before_to_trough = {} + peak_after_to_trough = {} + peak_before_to_peak_after = {} + main_peak_to_trough = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + ratios = get_waveform_ratios( + template_single, + troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], + **metric_params + ) + peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"] + peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"] + peak_before_to_peak_after[unit_id] = ratios["peak_before_to_peak_after_ratio"] + main_peak_to_trough[unit_id] = ratios["main_peak_to_trough_ratio"] + return waveform_ratios_result( + peak_before_to_trough_ratio=peak_before_to_trough, + 
peak_after_to_trough_ratio=peak_after_to_trough, + peak_before_to_peak_after_ratio=peak_before_to_peak_after, + main_peak_to_trough_ratio=main_peak_to_trough + ) class WaveformRatios(BaseMetric): metric_name = "waveform_ratios" + metric_function = _waveform_ratios_metric_function metric_params = {} metric_columns = { "peak_before_to_trough_ratio": float, @@ -1173,48 +1200,40 @@ class WaveformRatios(BaseMetric): "main_peak_to_trough_ratio": "Ratio of main peak amplitude to trough amplitude", } needs_tmp_data = True - deprecated_names = ["peak_trough_ratio"] - @staticmethod - def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - waveform_ratios_result = namedtuple( - "WaveformRatiosResult", - [ - "peak_before_to_trough_ratio", - "peak_after_to_trough_ratio", - "peak_before_to_peak_after_ratio", - "main_peak_to_trough_ratio", - ], - ) - peak_before_to_trough = {} - peak_after_to_trough = {} - peak_before_to_peak_after = {} - main_peak_to_trough = {} - - for unit_index, unit_id in enumerate(unit_ids): - main_channel_template = tmp_data["main_channel_templates"][unit_index] - peaks_info = tmp_data["peaks_info"][unit_index] - ratios = get_waveform_ratios( - main_channel_template, - peaks_info, - **metric_params, - ) - peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"] - peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"] - peak_before_to_peak_after[unit_id] = ratios["peak_before_to_peak_after_ratio"] - main_peak_to_trough[unit_id] = ratios["main_peak_to_trough_ratio"] - return waveform_ratios_result( - peak_before_to_trough_ratio=peak_before_to_trough, - peak_after_to_trough_ratio=peak_after_to_trough, - peak_before_to_peak_after_ratio=peak_before_to_peak_after, - main_peak_to_trough_ratio=main_peak_to_trough, - ) - metric_function = _waveform_ratios_metric_function +def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + waveform_widths_result = 
namedtuple("WaveformWidthsResult", [ + "trough_width", "peak_before_width", "peak_after_width" + ]) + trough_width_dict = {} + peak_before_width_dict = {} + peak_after_width_dict = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + widths = get_waveform_widths( + template_single, sampling_frequency, + troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], + **metric_params + ) + trough_width_dict[unit_id] = widths["trough_width_us"] + peak_before_width_dict[unit_id] = widths["peak_before_width_us"] + peak_after_width_dict[unit_id] = widths["peak_after_width_us"] + return waveform_widths_result( + trough_width=trough_width_dict, + peak_before_width=peak_before_width_dict, + peak_after_width=peak_after_width_dict + ) class WaveformWidths(BaseMetric): metric_name = "waveform_widths" + metric_function = _waveform_widths_metric_function metric_params = {} metric_columns = { "trough_width": float, @@ -1222,71 +1241,43 @@ class WaveformWidths(BaseMetric): "peak_after_width": float, } metric_descriptions = { - "trough_width": "Width of the main trough in seconds", - "peak_before_width": "Width of the main peak before trough in seconds", - "peak_after_width": "Width of the main peak after trough in seconds", + "trough_width": "Width of the main trough in microseconds", + "peak_before_width": "Width of the main peak before trough in microseconds", + "peak_after_width": "Width of the main peak after trough in microseconds", } needs_tmp_data = True - @staticmethod - def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - waveform_widths_result = namedtuple( - "WaveformWidthsResult", ["trough_width", "peak_before_width", 
"peak_after_width"] - ) - trough_width_dict = {} - peak_before_width_dict = {} - peak_after_width_dict = {} - - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - peaks_info = tmp_data["peaks_info"][unit_index] - widths = get_waveform_widths( - peaks_info, - sampling_frequency, - **metric_params, - ) - trough_width_dict[unit_id] = widths["trough_width"] - peak_before_width_dict[unit_id] = widths["peak_before_width"] - peak_after_width_dict[unit_id] = widths["peak_after_width"] - return waveform_widths_result( - trough_width=trough_width_dict, - peak_before_width=peak_before_width_dict, - peak_after_width=peak_after_width_dict, - ) - metric_function = _waveform_widths_metric_function +def _waveform_baseline_flatness_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + value = get_waveform_baseline_flatness(template_single, sampling_frequency, **metric_params) + result[unit_id] = value + return result class WaveformBaselineFlatness(BaseMetric): metric_name = "waveform_baseline_flatness" + metric_function = _waveform_baseline_flatness_metric_function metric_params = {"baseline_window_ms": (0.0, 0.5)} metric_columns = {"waveform_baseline_flatness": float} metric_descriptions = { "waveform_baseline_flatness": "Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline." 
} needs_tmp_data = True - deprecated_names = ["num_positive_peaks", "num_negative_peaks"] - - @staticmethod - def _waveform_baseline_flatness_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - result = {} - sampling_frequency = tmp_data["sampling_frequency"] - for unit_index, unit_id in enumerate(unit_ids): - main_channel_template = tmp_data["main_channel_templates"][unit_index] - value = get_waveform_baseline_flatness(main_channel_template, sampling_frequency, **metric_params) - result[unit_id] = value - return result - - metric_function = _waveform_baseline_flatness_metric_function single_channel_metrics = [ - PeakToTroughDuration, + PeakToValley, + PeakToTroughRatio, HalfWidth, RepolarizationSlope, RecoverySlope, NumberOfPeaks, - MainToNextExtremumDuration, + WaveformDuration, WaveformRatios, WaveformWidths, WaveformBaselineFlatness, @@ -1297,13 +1288,13 @@ def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **m velocity_above_result = namedtuple("Velocities", ["velocity_above", "velocity_below"]) velocity_above_dict = {} velocity_below_dict = {} - multi_channel_templates = tmp_data["multi_channel_templates"] + templates_multi = tmp_data["templates_multi"] channel_locations_multi = tmp_data["channel_locations_multi"] sampling_frequency = tmp_data["sampling_frequency"] metric_params["depth_direction"] = tmp_data["depth_direction"] for unit_index, unit_id in enumerate(unit_ids): channel_locations = channel_locations_multi[unit_index] - template = multi_channel_templates[unit_index] + template = templates_multi[unit_index] vel_above, vel_below = get_velocity_fits(template, channel_locations, sampling_frequency, **metric_params) velocity_above_dict[unit_id] = vel_above velocity_below_dict[unit_id] = vel_below @@ -1320,57 +1311,55 @@ class VelocityFits(BaseMetric): } metric_columns = {"velocity_above": float, "velocity_below": float} metric_descriptions = { - "velocity_above": "Velocity of the spike propagation above 
the max channel in µm/ms", - "velocity_below": "Velocity of the spike propagation below the max channel in µm/ms", + "velocity_above": "Velocity of the spike propagation above the max channel in um/ms", + "velocity_below": "Velocity of the spike propagation below the max channel in um/ms", } needs_tmp_data = True - deprecated_names = ["velocity_above", "velocity_below"] + + +def multi_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_multi = tmp_data["templates_multi"] + channel_locations_multi = tmp_data["channel_locations_multi"] + sampling_frequency = tmp_data["sampling_frequency"] + metric_params["depth_direction"] = tmp_data["depth_direction"] + for unit_index, unit_id in enumerate(unit_ids): + channel_locations = channel_locations_multi[unit_index] + template = templates_multi[unit_index] + value = unit_function(template, channel_locations, sampling_frequency, **metric_params) + result[unit_id] = value + return result class ExpDecay(BaseMetric): metric_name = "exp_decay" - metric_params = { - "peak_function": "ptp", - "min_r2": 0.2, - "linear_fit": False, - "channel_tolerance": None, # None uses old style (all channels), set to e.g. 33 for bombcell-style - "min_channels_for_fit": None, # None means use default based on linear_fit (5 for linear, 8 for exp) - "num_channels_for_fit": None, # None means use default based on linear_fit (6 for linear, 10 for exp) - "normalize_decay": False, - } + metric_params = {"peak_function": "ptp", "min_r2": 0.2} metric_columns = {"exp_decay": float} metric_descriptions = { - "exp_decay": ( - "Spatial decay of the template amplitude over distance from the extremum channel (1/um). " - "Uses exponential or linear fit based on linear_fit parameter." 
- ) + "exp_decay": ("Exponential decay of the template amplitude over distance from the extremum channel (1/um).") } needs_tmp_data = True @staticmethod def _exp_decay_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - exp_decays = {} - multi_channel_templates = tmp_data["multi_channel_templates"] - channel_locations_multi = tmp_data["channel_locations_multi"] - sampling_frequency = tmp_data["sampling_frequency"] - metric_params["depth_direction"] = tmp_data["depth_direction"] - for unit_index, unit_id in enumerate(unit_ids): - channel_locations = channel_locations_multi[unit_index] - template = multi_channel_templates[unit_index] - value = get_exp_decay(template, channel_locations, sampling_frequency, **metric_params) - exp_decays[unit_id] = value - return exp_decays + return multi_channel_metric( + unit_function=get_exp_decay, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) metric_function = _exp_decay_metric_function class Spread(BaseMetric): metric_name = "spread" - metric_params = {"spread_threshold": 0.2, "spread_smooth_um": 20, "column_range": None} + metric_params = {"spread_threshold": 0.5, "spread_smooth_um": 20, "column_range": None} metric_columns = {"spread": float} metric_descriptions = { "spread": ( - "Spread of the template amplitude in µm, calculated as the distance between channels whose " + "Spread of the template amplitude in um, calculated as the distance between channels whose " "templates exceed the spread_threshold." 
) } @@ -1378,17 +1367,13 @@ class Spread(BaseMetric): @staticmethod def _spread_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - spreads = {} - multi_channel_templates = tmp_data["multi_channel_templates"] - channel_locations_multi = tmp_data["channel_locations_multi"] - sampling_frequency = tmp_data["sampling_frequency"] - metric_params["depth_direction"] = tmp_data["depth_direction"] - for unit_index, unit_id in enumerate(unit_ids): - channel_locations = channel_locations_multi[unit_index] - template = multi_channel_templates[unit_index] - value = get_spread(template, channel_locations, sampling_frequency, **metric_params) - spreads[unit_id] = value - return spreads + return multi_channel_metric( + unit_function=get_spread, + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + tmp_data=tmp_data, + **metric_params, + ) metric_function = _spread_metric_function diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 96eda404dd..1a52b14013 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -1,15 +1,22 @@ -import warnings +""" +Functions based on +https://github.com/AllenInstitute/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py +22/04/2020 +""" + +from __future__ import annotations + import numpy as np +import warnings +from copy import deepcopy +from scipy.signal import find_peaks from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array -from .metrics import ( - get_trough_and_peak_idx, - single_channel_metrics, - multi_channel_metrics, -) +from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics + 
MIN_SPARSE_CHANNELS_FOR_MULTI_CHANNEL_WARNING = 10 MIN_CHANNELS_FOR_MULTI_CHANNEL_METRICS = 64 @@ -28,8 +35,6 @@ def get_template_metric_list(): def get_template_metric_names(): - import warnings - warnings.warn( "get_template_metric_names is deprecated and will be removed in a version 0.105.0. " "Please use get_template_metric_list instead.", @@ -42,18 +47,17 @@ def get_template_metric_names(): class ComputeTemplateMetrics(BaseMetricExtension): """ Compute template metrics including: - * peak_to_trough_duration - * main_to_next_extremum_duration - * half_width + * peak_to_valley + * peak_trough_ratio + * halfwidth * repolarization_slope * recovery_slope - * number_of_peaks - * waveform_ratios - * waveform_widths - * waveform_baseline_flatness + * num_positive_peaks + * num_negative_peaks Optionally, the following multi-channel metrics can be computed (when include_multi_channel_metrics=True): - * velocity_fits + * velocity_above + * velocity_below * exp_decay * spread @@ -68,16 +72,12 @@ class ComputeTemplateMetrics(BaseMetricExtension): metric_params : dict of dicts or None, default: None Dictionary with parameters for template metrics calculation. Default parameters can be obtained with: `si.metrics.template_metrics.get_default_template_metrics_params()` - peak_sign : {"neg", "pos", "both"}, default: "both" - Whether to use the positive ("pos"), negative ("neg"), or both ("both") peaks to estimate extremum channels. + peak_sign : {"neg", "pos"}, default: "neg" + Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. 
upsampling_factor : int, default: 10 The upsampling factor to upsample the templates include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics - min_thresh_detect_peaks_troughs : float, default: 0.3 - Minimum prominence threshold as a fraction of the template's absolute max value - edge_exclusion_ms : float, default: 0.09 - Duration in milliseconds to exclude from template edges during peak/trough detection. Returns ------- @@ -95,11 +95,8 @@ class ComputeTemplateMetrics(BaseMetricExtension): depend_on = ["templates"] need_backward_compatibility_on_load = True metric_list = single_channel_metrics + multi_channel_metrics - tmp_data_to_save = ["peaks_data", "main_channel_templates"] def _handle_backward_compatibility_on_load(self): - from copy import deepcopy - # For backwards compatibility - this reformats metrics_kwargs as metric_params if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: @@ -111,63 +108,24 @@ def _handle_backward_compatibility_on_load(self): del self.params["metrics_kwargs"] # handle metric names change: + # num_positive_peaks/num_negative_peaks merged into number_of_peaks if "num_positive_peaks" in self.params["metric_names"]: self.params["metric_names"].remove("num_positive_peaks") if "number_of_peaks" not in self.params["metric_names"]: self.params["metric_names"].append("number_of_peaks") - if "num_positive_peaks" in self.params["metric_params"]: - del self.params["metric_params"]["num_positive_peaks"] if "num_negative_peaks" in self.params["metric_names"]: self.params["metric_names"].remove("num_negative_peaks") if "number_of_peaks" not in self.params["metric_names"]: self.params["metric_names"].append("number_of_peaks") - if "num_negative_peaks" in self.params["metric_params"]: - del self.params["metric_params"]["num_negative_peaks"] # velocity_above/velocity_below merged into velocity_fits if "velocity_above" in self.params["metric_names"]: 
self.params["metric_names"].remove("velocity_above") if "velocity_fits" not in self.params["metric_names"]: self.params["metric_names"].append("velocity_fits") - self.params["metric_params"]["velocity_fits"] = self.params["metric_params"]["velocity_above"] - self.params["metric_params"]["velocity_fits"]["min_channels"] = self.params["metric_params"][ - "velocity_above" - ]["min_channels_for_velocity"] - self.params["metric_params"]["velocity_fits"]["min_r2"] = self.params["metric_params"]["velocity_above"][ - "min_r2_velocity" - ] - del self.params["metric_params"]["velocity_above"] if "velocity_below" in self.params["metric_names"]: self.params["metric_names"].remove("velocity_below") if "velocity_fits" not in self.params["metric_names"]: self.params["metric_names"].append("velocity_fits") - # parameters are already updated from velocity_above - if "velocity_below" in self.params["metric_params"]: - del self.params["metric_params"]["velocity_below"] - # exp_decay parameters changes - if "exp_decay" in self.params["metric_names"]: - if "exp_peak_function" in self.params["metric_params"]["exp_decay"]: - self.params["metric_params"]["exp_decay"]["peak_function"] = self.params["metric_params"]["exp_decay"][ - "exp_peak_function" - ] - if "min_r2_exp_decay" in self.params["metric_params"]["exp_decay"]: - self.params["metric_params"]["exp_decay"]["min_r2"] = self.params["metric_params"]["exp_decay"][ - "min_r2_exp_decay" - ] - if "depth_direction" not in self.params: - self.params["depth_direction"] = "y" - - # peak_to_valley -> peak_to_trough_duration - if "peak_to_valley" in self.params["metric_names"]: - self.params["metric_names"].remove("peak_to_valley") - if "peak_to_trough_duration" not in self.params["metric_names"]: - self.params["metric_names"].append("peak_to_trough_duration") - # peak_trough ratio -> main peak to trough ratio - # note that the new implementation correctly uses the absolute peak values, - # which is different from the old implementation. 
- if "peak_trough_ratio" in self.params["metric_names"]: - self.params["metric_names"].remove("peak_trough_ratio") - if "waveform_ratios" not in self.params["metric_names"]: - self.params["metric_names"].append("waveform_ratios") def _set_params( self, @@ -175,17 +133,15 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, - periods=None, # common extension kwargs - peak_sign="both", - template_operator="average", + peak_sign="neg", upsampling_factor=10, include_multi_channel_metrics=False, depth_direction="y", - min_thresh_detect_peaks_troughs=0.3, - edge_exclusion_ms=0.09, - min_peak_trough_distance_ratio=0.2, - min_extremum_distance_samples=3, + min_thresh_detect_peaks_troughs=0.4, + smooth=True, + smooth_window_frac=0.1, + smooth_polyorder=3, ): # Auto-detect if multi-channel metrics should be included based on number of channels num_channels = self.sorting_analyzer.get_num_channels() @@ -210,34 +166,26 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, - periods=periods, # template metrics do not use periods peak_sign=peak_sign, upsampling_factor=upsampling_factor, - template_operator=template_operator, include_multi_channel_metrics=include_multi_channel_metrics, depth_direction=depth_direction, min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs, - edge_exclusion_ms=edge_exclusion_ms, - min_peak_trough_distance_ratio=min_peak_trough_distance_ratio, - min_extremum_distance_samples=min_extremum_distance_samples, + smooth=smooth, + smooth_window_frac=smooth_window_frac, + smooth_polyorder=smooth_polyorder, ) def _prepare_data(self, sorting_analyzer, unit_ids): - import warnings - import pandas as pd from scipy.signal import resample_poly - # compute main_channel_templates and multi_channel_templates (if include_multi_channel_metrics is True) + # compute templates_single and 
templates_multi (if include_multi_channel_metrics is True) tmp_data = {} if unit_ids is None: unit_ids = sorting_analyzer.unit_ids peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] - min_thresh_detect_peaks_troughs = self.params["min_thresh_detect_peaks_troughs"] - edge_exclusion_ms = self.params.get("edge_exclusion_ms", 0.1) - min_peak_trough_distance_ratio = self.params.get("min_peak_trough_distance_ratio", 0.2) - min_extremum_distance_samples = self.params.get("min_extremum_distance_samples", 3) sampling_frequency = sorting_analyzer.sampling_frequency if self.params["upsampling_factor"] > 1: sampling_frequency_up = upsampling_factor * sampling_frequency @@ -249,17 +197,18 @@ def _prepare_data(self, sorting_analyzer, unit_ids): m in get_multi_channel_template_metric_names() for m in self.params["metrics_to_compute"] ) - operator = self.params["template_operator"] - extremum_channel_indices = get_template_extremum_channel( - sorting_analyzer, peak_sign=peak_sign, outputs="index", operator=operator - ) - all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True, operator=operator) + extremum_channel_indices = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign, outputs="index") + all_templates = get_dense_templates_array(sorting_analyzer, return_in_uV=True) channel_locations = sorting_analyzer.get_channel_locations() - main_channel_templates = [] - peaks_info = [] - multi_channel_templates = [] + templates_single = [] + troughs = {} + peaks = {} + troughs_info = {} + peaks_before_info = {} + peaks_after_info = {} + templates_multi = [] channel_locations_multi = [] for unit_id in unit_ids: unit_index = sorting_analyzer.sorting.id_to_index(unit_id) @@ -271,18 +220,23 @@ def _prepare_data(self, sorting_analyzer, unit_ids): template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1) else: template_upsampled = template_single - - peaks_info_unit = get_trough_and_peak_idx( + 
sampling_frequency_up = sampling_frequency + troughs_dict, peaks_before_dict, peaks_after_dict = get_trough_and_peak_idx( template_upsampled, - sampling_frequency_up, - min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs, - edge_exclusion_ms=edge_exclusion_ms, - min_peak_trough_distance_ratio=min_peak_trough_distance_ratio, - min_extremum_distance_samples=min_extremum_distance_samples, + min_thresh_detect_peaks_troughs=self.params['min_thresh_detect_peaks_troughs'], + smooth=self.params['smooth'], + smooth_window_frac=self.params['smooth_window_frac'], + smooth_polyorder=self.params['smooth_polyorder'], ) - main_channel_templates.append(template_upsampled) - peaks_info.append(peaks_info_unit) + templates_single.append(template_upsampled) + # Store main locations for backward compatibility + troughs[unit_id] = troughs_dict["main_loc"] + peaks[unit_id] = peaks_after_dict["main_loc"] + # Store full dicts for new metrics + troughs_info[unit_id] = troughs_dict + peaks_before_info[unit_id] = peaks_before_dict + peaks_after_info[unit_id] = peaks_after_dict if include_multi_channel_metrics: if sorting_analyzer.is_sparse(): @@ -302,36 +256,26 @@ def _prepare_data(self, sorting_analyzer, unit_ids): template_multi_upsampled = resample_poly(template_multi, up=upsampling_factor, down=1, axis=0) else: template_multi_upsampled = template_multi - multi_channel_templates.append(template_multi_upsampled) + templates_multi.append(template_multi_upsampled) channel_locations_multi.append(channel_location_multi) - tmp_data["peaks_info"] = peaks_info - tmp_data["main_channel_templates"] = np.array(main_channel_templates) + tmp_data["troughs"] = troughs + tmp_data["peaks"] = peaks + tmp_data["troughs_info"] = troughs_info + tmp_data["peaks_before_info"] = peaks_before_info + tmp_data["peaks_after_info"] = peaks_after_info + tmp_data["templates_single"] = np.array(templates_single) if include_multi_channel_metrics: - # multi_channel_templates is a list of 2D arrays of shape 
(n_times, n_channels) - tmp_data["multi_channel_templates"] = multi_channel_templates + # templates_multi is a list of 2D arrays of shape (n_times, n_channels) + tmp_data["templates_multi"] = templates_multi tmp_data["channel_locations_multi"] = channel_locations_multi tmp_data["depth_direction"] = self.params["depth_direction"] - # Add peaks_info and preprocessed templates to self.data for storage in extension - columns = [] - for k in ("trough", "peak_before", "peak_after"): - for suffix in ( - "index", - "width_left", - "width_right", - "half_width_left", - "half_width_right", - ): - columns.append(f"{k}_{suffix}") - tmp_data["peaks_data"] = pd.DataFrame( - index=unit_ids, - data=peaks_info, - columns=columns, - dtype=int, - ) - + # store trough and peak dicts for GUI use + self.data['troughs_info'] = troughs_dict + self.data['peaks_before_info'] = peaks_before_dict + self.data['peaks_after_info'] = troughs_dict, peaks_before_dict, peaks_after_dict return tmp_data @@ -358,8 +302,6 @@ def get_default_tm_params(metric_names=None): metric_params : dict Dictionary with default parameters for template metrics. """ - import warnings - warnings.warn( "get_default_tm_params is deprecated and will be removed in a version 0.105.0. 
" "Please use get_default_template_metrics_params instead.", From 51f085e1913fc27769f7cbb490f5523edd4c7ea1 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Mon, 23 Mar 2026 11:02:41 -0400 Subject: [PATCH 02/15] bombcell full pipeline, edits to allow most config options native bombcell has etc --- .gitignore | 3 + in_container_params.json | 3 - in_container_recording.json | 15497 -------------------------------- in_container_sorter_script.py | 28 - 4 files changed, 3 insertions(+), 15528 deletions(-) delete mode 100644 in_container_params.json delete mode 100644 in_container_recording.json delete mode 100644 in_container_sorter_script.py diff --git a/.gitignore b/.gitignore index bfb43fcae1..1de05bd8c5 100644 --- a/.gitignore +++ b/.gitignore @@ -154,3 +154,6 @@ playground.ipynb playground2.ipynb examples/how_to/full_pipeline_with_bombcell.ipynb examples/how_to/compare_bombcell_unitrefine.ipynb +in_container_params.json +in_container_recording.json +in_container_sorter_script.py diff --git a/in_container_params.json b/in_container_params.json deleted file mode 100644 index 462dc67ed3..0000000000 --- a/in_container_params.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "output_folder": "/Users/jf5479/Downloads/AL031_2019-12-02/spikeinterface_output/kilosort4_output" -} \ No newline at end of file diff --git a/in_container_recording.json b/in_container_recording.json deleted file mode 100644 index 64f8f88c42..0000000000 --- a/in_container_recording.json +++ /dev/null @@ -1,15497 +0,0 @@ -{ - "class": "spikeinterface.preprocessing.common_reference.CommonReferenceRecording", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "recording": { - "class": "spikeinterface.preprocessing.phase_shift.PhaseShiftRecording", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "recording": { - "class": "spikeinterface.core.channelslice.ChannelSliceRecording", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "parent_recording": { - 
"class": "spikeinterface.preprocessing.filter.HighpassFilterRecording", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "recording": { - "class": "spikeinterface.core.channelslice.ChannelSliceRecording", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "parent_recording": { - "class": "spikeinterface.core.binaryrecordingextractor.BinaryRecordingExtractor", - "module": "spikeinterface", - "version": "0.103.3", - "kwargs": { - "file_paths": [ - "/Users/jf5479/Downloads/AL031_2019-12-02/AL031_2019-12-02_bank1_NatIm_g0_t0_bc_decompressed.imec0.ap.bin" - ], - "sampling_frequency": 30000.0, - "t_starts": null, - "num_channels": 385, - "dtype": " Date: Mon, 23 Mar 2026 15:11:35 +0000 Subject: [PATCH 03/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../curation/bombcell_curation.py | 4 +- .../metrics/template/metrics.py | 149 ++++++++++++------ .../metrics/template/template_metrics.py | 19 ++- 3 files changed, 113 insertions(+), 59 deletions(-) diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 0cfdba1164..b85498cda2 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -314,9 +314,7 @@ def bombcell_label_units( # Standalone custom metrics: any metric in non-somatic thresholds that is not # part of the built-in groups is OR'd in as its own independent condition. 
standalone_metrics = { - m: non_somatic_thresholds[m] - for m in non_somatic_thresholds - if m not in _NON_SOMATIC_BUILTIN_METRICS + m: non_somatic_thresholds[m] for m in non_somatic_thresholds if m not in _NON_SOMATIC_BUILTIN_METRICS } for metric_name, thresh in standalone_metrics.items(): standalone_labels = threshold_metrics_label_units( diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index 3d84dceadc..470ffe193a 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -6,7 +6,9 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric -def get_trough_and_peak_idx(template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3): +def get_trough_and_peak_idx( + template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3 +): """ Detect troughs and peaks in a template waveform and return detailed information about each detected feature. 
@@ -106,16 +108,12 @@ def get_trough_and_peak_idx(template, min_thresh_detect_peaks_troughs=0.4, smoot template_before = template[:main_trough_loc] # Try with original prominence - peak_locs_before, peak_props_before = find_peaks( - template_before, prominence=min_prominence, width=0 - ) + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=min_prominence, width=0) # If no peaks found, try with lower prominence (keep only max peak) if len(peak_locs_before) == 0: lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) - peak_locs_before, peak_props_before = find_peaks( - template_before, prominence=lower_prominence, width=0 - ) + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=lower_prominence, width=0) # Keep only the most prominent peak when using lower threshold if len(peak_locs_before) > 1: prominences = peak_props_before.get("prominences", np.array([])) @@ -154,16 +152,12 @@ def get_trough_and_peak_idx(template, min_thresh_detect_peaks_troughs=0.4, smoot template_after = template[main_trough_loc:] # Try with original prominence - peak_locs_after, peak_props_after = find_peaks( - template_after, prominence=min_prominence, width=0 - ) + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=min_prominence, width=0) # If no peaks found, try with lower prominence (keep only max peak) if len(peak_locs_after) == 0: lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) - peak_locs_after, peak_props_after = find_peaks( - template_after, prominence=lower_prominence, width=0 - ) + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=lower_prominence, width=0) # Keep only the most prominent peak when using lower threshold if len(peak_locs_after) > 1: prominences = peak_props_after.get("prominences", np.array([])) @@ -220,24 +214,65 @@ def get_trough_and_peak_idx(template, min_thresh_detect_peaks_troughs=0.4, 
smoot # Plot all detected troughs ax.scatter(troughs["indices"], troughs["values"], c="blue", s=50, marker="v", zorder=5, label="troughs") if troughs["main_loc"] is not None: - ax.scatter(troughs["main_loc"], template[troughs["main_loc"]], c="blue", s=150, marker="v", - edgecolors="red", linewidths=2, zorder=6, label="main trough") + ax.scatter( + troughs["main_loc"], + template[troughs["main_loc"]], + c="blue", + s=150, + marker="v", + edgecolors="red", + linewidths=2, + zorder=6, + label="main trough", + ) # Plot all peaks before if len(peaks_before["indices"]) > 0: - ax.scatter(peaks_before["indices"], peaks_before["values"], c="green", s=50, marker="^", - zorder=5, label="peaks before") + ax.scatter( + peaks_before["indices"], + peaks_before["values"], + c="green", + s=50, + marker="^", + zorder=5, + label="peaks before", + ) if peaks_before["main_loc"] is not None: - ax.scatter(peaks_before["main_loc"], template[peaks_before["main_loc"]], c="green", s=150, - marker="^", edgecolors="red", linewidths=2, zorder=6, label="main peak before") + ax.scatter( + peaks_before["main_loc"], + template[peaks_before["main_loc"]], + c="green", + s=150, + marker="^", + edgecolors="red", + linewidths=2, + zorder=6, + label="main peak before", + ) # Plot all peaks after if len(peaks_after["indices"]) > 0: - ax.scatter(peaks_after["indices"], peaks_after["values"], c="orange", s=50, marker="^", - zorder=5, label="peaks after") + ax.scatter( + peaks_after["indices"], + peaks_after["values"], + c="orange", + s=50, + marker="^", + zorder=5, + label="peaks after", + ) if peaks_after["main_loc"] is not None: - ax.scatter(peaks_after["main_loc"], template[peaks_after["main_loc"]], c="orange", s=150, - marker="^", edgecolors="red", linewidths=2, zorder=6, label="main peak after") + ax.scatter( + peaks_after["main_loc"], + template[peaks_after["main_loc"]], + c="orange", + s=150, + marker="^", + edgecolors="red", + linewidths=2, + zorder=6, + label="main peak after", + ) ax.axhline(0, 
color="gray", ls="-", alpha=0.3) ax.set_xlabel("Sample") @@ -369,7 +404,14 @@ def safe_ratio(a, b): "peak_before_to_trough_ratio": safe_ratio(peak_before_amp, trough_amp), "peak_after_to_trough_ratio": safe_ratio(peak_after_amp, trough_amp), "peak_before_to_peak_after_ratio": safe_ratio(peak_before_amp, peak_after_amp), - "main_peak_to_trough_ratio": safe_ratio(max(peak_before_amp, peak_after_amp) if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) else np.nan, trough_amp), + "main_peak_to_trough_ratio": safe_ratio( + ( + max(peak_before_amp, peak_after_amp) + if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) + else np.nan + ), + trough_amp, + ), } return ratios @@ -455,6 +497,7 @@ def get_waveform_widths(template, sampling_frequency, troughs, peaks_before, pea - "peak_before_width_us": width of main peak before trough in microseconds - "peak_after_width_us": width of main peak after trough in microseconds """ + def get_main_width(feature_dict): if feature_dict["main_idx"] is None: return np.nan @@ -1101,9 +1144,12 @@ def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **met for unit_index, unit_id in enumerate(unit_ids): template_single = templates_single[unit_index] num_positive, num_negative = get_number_of_peaks( - template_single, sampling_frequency, - troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], - **metric_params + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, ) num_positive_peaks_dict[unit_id] = num_positive num_negative_peaks_dict[unit_id] = num_negative @@ -1132,9 +1178,12 @@ def _waveform_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **m for unit_index, unit_id in enumerate(unit_ids): template_single = templates_single[unit_index] value = get_waveform_duration( - template_single, sampling_frequency, - troughs_info[unit_id], peaks_before_info[unit_id], 
peaks_after_info[unit_id], - **metric_params + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, ) result[unit_id] = value return result @@ -1152,10 +1201,15 @@ class WaveformDuration(BaseMetric): def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - waveform_ratios_result = namedtuple("WaveformRatiosResult", [ - "peak_before_to_trough_ratio", "peak_after_to_trough_ratio", - "peak_before_to_peak_after_ratio", "main_peak_to_trough_ratio" - ]) + waveform_ratios_result = namedtuple( + "WaveformRatiosResult", + [ + "peak_before_to_trough_ratio", + "peak_after_to_trough_ratio", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", + ], + ) peak_before_to_trough = {} peak_after_to_trough = {} peak_before_to_peak_after = {} @@ -1168,8 +1222,10 @@ def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **met template_single = templates_single[unit_index] ratios = get_waveform_ratios( template_single, - troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], - **metric_params + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, ) peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"] peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"] @@ -1179,7 +1235,7 @@ def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **met peak_before_to_trough_ratio=peak_before_to_trough, peak_after_to_trough_ratio=peak_after_to_trough, peak_before_to_peak_after_ratio=peak_before_to_peak_after, - main_peak_to_trough_ratio=main_peak_to_trough + main_peak_to_trough_ratio=main_peak_to_trough, ) @@ -1203,9 +1259,9 @@ class WaveformRatios(BaseMetric): def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - waveform_widths_result = namedtuple("WaveformWidthsResult", [ - 
"trough_width", "peak_before_width", "peak_after_width" - ]) + waveform_widths_result = namedtuple( + "WaveformWidthsResult", ["trough_width", "peak_before_width", "peak_after_width"] + ) trough_width_dict = {} peak_before_width_dict = {} peak_after_width_dict = {} @@ -1217,17 +1273,18 @@ def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **met for unit_index, unit_id in enumerate(unit_ids): template_single = templates_single[unit_index] widths = get_waveform_widths( - template_single, sampling_frequency, - troughs_info[unit_id], peaks_before_info[unit_id], peaks_after_info[unit_id], - **metric_params + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, ) trough_width_dict[unit_id] = widths["trough_width_us"] peak_before_width_dict[unit_id] = widths["peak_before_width_us"] peak_after_width_dict[unit_id] = widths["peak_after_width_us"] return waveform_widths_result( - trough_width=trough_width_dict, - peak_before_width=peak_before_width_dict, - peak_after_width=peak_after_width_dict + trough_width=trough_width_dict, peak_before_width=peak_before_width_dict, peak_after_width=peak_after_width_dict ) diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 1a52b14013..ccd0309882 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -9,7 +9,7 @@ import numpy as np import warnings from copy import deepcopy -from scipy.signal import find_peaks +from scipy.signal import find_peaks from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension @@ -17,7 +17,6 @@ from .metrics import get_trough_and_peak_idx, single_channel_metrics, multi_channel_metrics - MIN_SPARSE_CHANNELS_FOR_MULTI_CHANNEL_WARNING = 10 
 MIN_CHANNELS_FOR_MULTI_CHANNEL_METRICS = 64
@@ -223,10 +222,10 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
                 sampling_frequency_up = sampling_frequency
             troughs_dict, peaks_before_dict, peaks_after_dict = get_trough_and_peak_idx(
                 template_upsampled,
-                min_thresh_detect_peaks_troughs=self.params['min_thresh_detect_peaks_troughs'],
-                smooth=self.params['smooth'],
-                smooth_window_frac=self.params['smooth_window_frac'],
-                smooth_polyorder=self.params['smooth_polyorder'],
+                min_thresh_detect_peaks_troughs=self.params["min_thresh_detect_peaks_troughs"],
+                smooth=self.params["smooth"],
+                smooth_window_frac=self.params["smooth_window_frac"],
+                smooth_polyorder=self.params["smooth_polyorder"],
             )
 
             templates_single.append(template_upsampled)
@@ -272,10 +271,10 @@ def _prepare_data(self, sorting_analyzer, unit_ids):
             tmp_data["channel_locations_multi"] = channel_locations_multi
             tmp_data["depth_direction"] = self.params["depth_direction"]
 
-        # store trough and peak dicts for GUI use
-        self.data['troughs_info'] = troughs_dict
-        self.data['peaks_before_info'] = peaks_before_dict
-        self.data['peaks_after_info'] = troughs_dict, peaks_before_dict, peaks_after_dict
+        # store the per-unit trough and peak dicts (keyed by unit_id) for GUI use
+        self.data["troughs_info"] = troughs_info
+        self.data["peaks_before_info"] = peaks_before_info
+        self.data["peaks_after_info"] = peaks_after_info
 
         return tmp_data

From eb97c34e0d7ebaed05ca5ec2b2855c3d3f1f30a3 Mon Sep 17 00:00:00 2001
From: Julie-Fabre
Date: Thu, 26 Mar 2026 11:38:55 -0400
Subject: [PATCH 04/15] bombcell full pipeline, edits to allow most config
 options native bombcell has etc

---
 .../widgets/bombcell_curation.py              |   2 +-
 .../widgets/template_peak_trough.py           | 336 ++++++++++++++++++
 src/spikeinterface/widgets/widget_list.py     |   3 +
 3 files changed, 340 insertions(+), 1 deletion(-)
 create mode 100644 src/spikeinterface/widgets/template_peak_trough.py

diff --git a/src/spikeinterface/widgets/bombcell_curation.py 
b/src/spikeinterface/widgets/bombcell_curation.py index 1a1212ba5c..03d57c73d8 100644 --- a/src/spikeinterface/widgets/bombcell_curation.py +++ b/src/spikeinterface/widgets/bombcell_curation.py @@ -105,7 +105,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.text( 0.5, 0.5, - "UpSet plots require 'upsetplot' package.\n\npip install upsetplot", + "UpSet plots require 'upsetplot' package.\n\npip install upsetplot-bombcell", ha="center", va="center", fontsize=14, diff --git a/src/spikeinterface/widgets/template_peak_trough.py b/src/spikeinterface/widgets/template_peak_trough.py new file mode 100644 index 0000000000..cf30f53c0f --- /dev/null +++ b/src/spikeinterface/widgets/template_peak_trough.py @@ -0,0 +1,336 @@ +"""Widget for visualizing template and mean raw waveforms with detected peaks and troughs.""" + +from __future__ import annotations + +import numpy as np + +from .base import BaseWidget, to_attr + + +class TemplatePeakTroughWidget(BaseWidget): + """Plot template and mean raw waveform side by side for each unit, with detected + peaks and troughs overlaid on the peak channel. + + For each unit, two panels are shown: + - **Left**: template (average) waveform with detected peak_before, trough, and peak_after + markers on the peak channel. + - **Right**: mean raw waveform (average of extracted waveforms). + + Multiple channels can be displayed (peak channel + ``n_channels_around`` neighbors). + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object with ``templates`` extension computed. + If ``show_mean_waveform`` is True, the ``waveforms`` extension must also be computed. + unit_ids : list or None, default: None + Unit IDs to plot. If None, plots all units (up to ``max_units``). + max_units : int, default: 16 + Maximum number of units to plot when ``unit_ids`` is None. + n_channels_around : int, default: 0 + Number of channels above and below the peak channel to display. + 0 means only the peak channel. 
+ max_columns : int, default: 4 + Maximum number of units per row. Each unit takes 2 subplot columns + (template + mean waveform). + show_mean_waveform : bool, default: True + Whether to show the mean raw waveform panel. Requires ``waveforms`` extension. + unit_labels : np.ndarray or None, default: None + Optional array of labels (e.g. bombcell labels) to show in subplot titles. + Must be the same length as ``sorting_analyzer.unit_ids``. + min_thresh_detect_peaks_troughs : float, default: 0.4 + Prominence threshold passed to ``get_trough_and_peak_idx``. + min_peak_before_ratio : float or None, default: 0.5 + Minimum ratio of ``abs(peak_before) / abs(trough)`` for the peak_before marker + to be displayed. If None, peak_before is always shown when detected. + A value of 0.5 means peak_before must be at least 50% of the trough amplitude. + """ + + def __init__( + self, + sorting_analyzer, + unit_ids=None, + max_units: int = 16, + n_channels_around: int = 0, + max_columns: int = 4, + show_mean_waveform: bool = True, + unit_labels: np.ndarray | None = None, + min_thresh_detect_peaks_troughs: float = 0.4, + min_peak_before_ratio: float | None = 0.5, + backend=None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "templates") + + if show_mean_waveform: + wf_ext = sorting_analyzer.get_extension("waveforms") + if wf_ext is None: + raise ValueError( + "show_mean_waveform=True requires the 'waveforms' extension. " + "Either compute it or set show_mean_waveform=False." 
+ ) + + all_unit_ids = list(sorting_analyzer.unit_ids) + if unit_ids is None: + unit_ids = all_unit_ids[:max_units] + else: + unit_ids = list(unit_ids) + + plot_data = dict( + sorting_analyzer=sorting_analyzer, + unit_ids=unit_ids, + n_channels_around=n_channels_around, + max_columns=max_columns, + show_mean_waveform=show_mean_waveform, + unit_labels=unit_labels, + min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs, + min_peak_before_ratio=min_peak_before_ratio, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from spikeinterface.metrics.template.metrics import get_trough_and_peak_idx + + dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer + unit_ids = dp.unit_ids + n_around = dp.n_channels_around + thresh = dp.min_thresh_detect_peaks_troughs + peak_before_ratio = dp.min_peak_before_ratio + show_mean = dp.show_mean_waveform + + templates_ext = sorting_analyzer.get_extension("templates") + templates = templates_ext.get_templates(operator="average") + all_unit_ids = list(sorting_analyzer.unit_ids) + channel_locations = sorting_analyzer.get_channel_locations() + + # Compute mean raw waveforms if needed + mean_waveforms = {} + if show_mean: + wf_ext = sorting_analyzer.get_extension("waveforms") + for uid in unit_ids: + wfs = wf_ext.get_waveforms_one_unit(uid, force_dense=True) + mean_waveforms[uid] = np.mean(wfs, axis=0) + + n_units = len(unit_ids) + panels_per_unit = 2 if show_mean else 1 + ncols_units = min(n_units, dp.max_columns) + ncols = ncols_units * panels_per_unit + nrows = int(np.ceil(n_units / ncols_units)) + + figsize = backend_kwargs.pop("figsize", None) + if figsize is None: + figsize = (3.5 * ncols, 3 * nrows) + fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False) + + for i, uid in enumerate(unit_ids): + row = i // ncols_units + col_base = (i % ncols_units) * panels_per_unit + + unit_idx 
= all_unit_ids.index(uid) + template = templates[unit_idx] # (n_samples, n_channels) + + # Find peak channel + best_chan = int(np.argmax(np.max(np.abs(template), axis=0))) + + # Get channel indices to display + if n_around > 0: + # Use spatial proximity: find closest channels + best_loc = channel_locations[best_chan] + dists = np.linalg.norm(channel_locations - best_loc, axis=1) + n_total = 1 + 2 * n_around + chan_inds = np.argsort(dists)[:n_total] + # Sort by y-position (depth) for display order + chan_inds = chan_inds[np.argsort(channel_locations[chan_inds, 1])[::-1]] + else: + chan_inds = np.array([best_chan]) + + peak_chan_pos_in_display = int(np.where(chan_inds == best_chan)[0][0]) + + # Build title + title = f"Unit {uid}" + if dp.unit_labels is not None: + label_idx = all_unit_ids.index(uid) + if label_idx < len(dp.unit_labels): + title += f" ({dp.unit_labels[label_idx]})" + + # --- Left panel: template with peak/trough markers --- + ax_template = axes[row, col_base] + self._plot_multichannel( + ax_template, + template, + chan_inds, + best_chan, + peak_chan_pos_in_display, + thresh, + title=f"{title}\ntemplate", + show_markers=True, + min_peak_before_ratio=peak_before_ratio, + ) + + # --- Right panel: mean raw waveform --- + if show_mean: + ax_mean = axes[row, col_base + 1] + self._plot_multichannel( + ax_mean, + mean_waveforms[uid], + chan_inds, + best_chan, + peak_chan_pos_in_display, + thresh, + title=f"{title}\nmean waveform", + show_markers=True, + min_peak_before_ratio=peak_before_ratio, + ) + + # Hide unused axes + for row_idx in range(nrows): + for col_idx in range(ncols): + unit_i = row_idx * ncols_units + col_idx // panels_per_unit + if unit_i >= n_units: + axes[row_idx, col_idx].set_visible(False) + + # Legend from first axis + handles, labels = axes[0, 0].get_legend_handles_labels() + if handles: + fig.legend(handles, labels, loc="lower right", fontsize=7, ncol=3) + + fig.tight_layout() + self.figure = fig + self.axes = axes + self.ax = axes[0, 0] 
+ + @staticmethod + def _plot_multichannel( + ax, data, chan_inds, best_chan, peak_chan_pos, thresh, title="", show_markers=True, min_peak_before_ratio=None + ): + """Plot waveform data on multiple channels with optional peak/trough markers. + + Parameters + ---------- + ax : matplotlib.axes.Axes + data : np.ndarray, shape (n_samples, n_channels) + chan_inds : np.ndarray of channel indices to plot + best_chan : int, the peak channel index in data + peak_chan_pos : int, position of peak channel in display order + thresh : float, prominence threshold for peak/trough detection + title : str + show_markers : bool + min_peak_before_ratio : float or None + """ + n_samples = data.shape[0] + n_chans = len(chan_inds) + + spacer = 0.0 + if n_chans == 1: + # Single channel: simple plot + waveform = data[:, best_chan] + ax.plot(waveform, color="k", lw=1) + + if show_markers: + _overlay_peak_trough_markers(ax, waveform, thresh, min_peak_before_ratio=min_peak_before_ratio) + else: + # Multi-channel: offset vertically + traces = data[:, chan_inds] # (n_samples, n_display_chans) + spacer = np.max(np.ptp(traces, axis=0)) * 1.3 + if spacer == 0: + spacer = 1.0 + + for j in range(n_chans): + offset = j * -spacer + waveform = traces[:, j] + is_peak = (chan_inds[j] == best_chan) + color = "k" if is_peak else "gray" + lw = 1.2 if is_peak else 0.7 + alpha = 1.0 if is_peak else 0.5 + ax.plot(waveform + offset, color=color, lw=lw, alpha=alpha) + + if show_markers and is_peak: + _overlay_peak_trough_markers( + ax, waveform, thresh, y_offset=offset, min_peak_before_ratio=min_peak_before_ratio + ) + + ax.axhline(-peak_chan_pos * spacer, color="gray", ls="-", alpha=0.15) + ax.set_title(title, fontsize=8) + for spine in ax.spines.values(): + spine.set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + +def _overlay_peak_trough_markers(ax, waveform, thresh, y_offset=0.0, min_peak_before_ratio=None): + """Scatter peak_before, trough, and peak_after markers onto an axis. 
+ + Parameters + ---------- + min_peak_before_ratio : float or None + If set, only show peak_before when abs(peak_before) / abs(trough) >= this value. + """ + from spikeinterface.metrics.template.metrics import get_trough_and_peak_idx + + troughs, peaks_before, peaks_after = get_trough_and_peak_idx( + waveform, min_thresh_detect_peaks_troughs=thresh + ) + + # Check whether peak_before passes the ratio threshold + show_peak_before = True + if min_peak_before_ratio is not None and peaks_before["main_loc"] is not None and troughs["main_loc"] is not None: + trough_val = np.abs(waveform[troughs["main_loc"]]) + peak_before_val = np.abs(waveform[peaks_before["main_loc"]]) + if trough_val > 0: + show_peak_before = (peak_before_val / trough_val) >= min_peak_before_ratio + else: + show_peak_before = False + + # Troughs — secondary (exclude main to avoid double-plotting) + if len(troughs["indices"]) > 0: + secondary_mask = np.ones(len(troughs["indices"]), dtype=bool) + if troughs["main_idx"] is not None: + secondary_mask[troughs["main_idx"]] = False + if secondary_mask.any(): + ax.scatter( + troughs["indices"][secondary_mask], troughs["values"][secondary_mask] + y_offset, + c="blue", s=30, marker="v", zorder=5, label="trough", + ) + if troughs["main_loc"] is not None: + ax.scatter( + troughs["main_loc"], waveform[troughs["main_loc"]] + y_offset, + c="blue", s=100, marker="v", edgecolors="red", linewidths=1.5, zorder=6, + label="trough" if not secondary_mask.any() else None, + ) + + # Peaks before — secondary (exclude main), skip if below ratio threshold + if show_peak_before and len(peaks_before["indices"]) > 0: + secondary_mask = np.ones(len(peaks_before["indices"]), dtype=bool) + if peaks_before["main_idx"] is not None: + secondary_mask[peaks_before["main_idx"]] = False + if secondary_mask.any(): + ax.scatter( + peaks_before["indices"][secondary_mask], peaks_before["values"][secondary_mask] + y_offset, + c="green", s=30, marker="^", zorder=5, label="peak before", + ) + if 
peaks_before["main_loc"] is not None: + ax.scatter( + peaks_before["main_loc"], waveform[peaks_before["main_loc"]] + y_offset, + c="green", s=100, marker="^", edgecolors="red", linewidths=1.5, zorder=6, + label="peak before" if not secondary_mask.any() else None, + ) + + # Peaks after — secondary (exclude main) + if len(peaks_after["indices"]) > 0: + secondary_mask = np.ones(len(peaks_after["indices"]), dtype=bool) + if peaks_after["main_idx"] is not None: + secondary_mask[peaks_after["main_idx"]] = False + if secondary_mask.any(): + ax.scatter( + peaks_after["indices"][secondary_mask], peaks_after["values"][secondary_mask] + y_offset, + c="orange", s=30, marker="^", zorder=5, label="peak after", + ) + if peaks_after["main_loc"] is not None: + ax.scatter( + peaks_after["main_loc"], waveform[peaks_after["main_loc"]] + y_offset, + c="orange", s=100, marker="^", edgecolors="red", linewidths=1.5, zorder=6, + label="peak after" if not secondary_mask.any() else None, + ) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6cd6bbfd23..8d5931fb11 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -39,6 +39,7 @@ from .unit_labels import WaveformOverlayByLabelWidget from .unit_valid_periods import ValidUnitPeriodsWidget from .bombcell_curation import BombcellUpsetPlotWidget, plot_bombcell_unit_labeling_all +from .template_peak_trough import TemplatePeakTroughWidget widget_list = [ AgreementMatrixWidget, @@ -81,6 +82,7 @@ UnitWaveformsWidget, ValidUnitPeriodsWidget, WaveformOverlayByLabelWidget, + TemplatePeakTroughWidget, StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, @@ -171,6 +173,7 @@ plot_study_agreement_matrix = StudyAgreementMatrix plot_study_summary = StudySummary plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget +plot_template_peak_trough = TemplatePeakTroughWidget def plot_timeseries(*args, **kwargs): 
From 898ab12ebad91f9b7cfdf8eab52133a51b868467 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 11:40:25 -0400 Subject: [PATCH 05/15] sliding RPV priority --- src/spikeinterface/curation/bombcell_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 0cfdba1164..e6a8de011b 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -41,7 +41,7 @@ ] # RPV metric column names (bombcell accepts "rpv" as threshold key and maps to whichever exists) -RPV_METRIC_COLUMNS = ["rp_contamination", "sliding_rp_violation"] +RPV_METRIC_COLUMNS = ["sliding_rp_violation", "rp_contamination"] DEFAULT_NON_SOMATIC_METRICS = [ "peak_before_to_trough_ratio", From b05ec5de67fbf8d1a3893f25d319df255a0a90cc Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 12:04:08 -0400 Subject: [PATCH 06/15] sliding RPV priority --- .../core/analyzer_extension_core.py | 53 ++++++++--- .../metrics/quality/misc_metrics.py | 90 ++++++++++++++----- .../metrics/quality/quality_metrics.py | 4 + .../quality/tests/test_metrics_functions.py | 15 ++-- 4 files changed, 118 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 53fe7be1f2..d41ed636a2 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -915,6 +915,9 @@ class BaseMetricExtension(AnalyzerExtension): need_job_kwargs = True need_backward_compatibility_on_load = False metric_list: list[BaseMetric] = None # list of BaseMetric + # dict mapping data key -> bool (True = per-unit array indexed on axis 0, + # False = global data passed through unchanged on select/merge/split). + # Set to None to disable. 
tmp_data_to_save = None def __init__(self, sorting_analyzer): @@ -1340,7 +1343,11 @@ def _run(self, **job_kwargs): if self.tmp_data_to_save is not None: for k in self.tmp_data_to_save: - self.data[k] = tmp_data[k] + if k in tmp_data: + self.data[k] = tmp_data[k] + elif extension is not None and k in extension.data: + # Propagate previously saved tmp_data for metrics not recomputed + self.data[k] = extension.data[k] def _get_data(self): # convert to correct dtype @@ -1358,10 +1365,22 @@ def _select_extension_data(self, unit_ids: list[int | str]): Returns ------- dict - Dictionary containing the selected metrics DataFrame. + Dictionary containing the selected metrics DataFrame and any tmp_data arrays. """ new_metrics = self.data["metrics"].loc[np.array(unit_ids)] - return dict(metrics=new_metrics) + result = dict(metrics=new_metrics) + + if self.tmp_data_to_save is not None: + keep_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + for k, is_per_unit in self.tmp_data_to_save.items(): + if k not in self.data: + continue + if is_per_unit: + result[k] = self.data[k][keep_indices] + else: + result[k] = self.data[k] + + return result def _merge_extension_data( self, @@ -1411,11 +1430,15 @@ def _merge_extension_data( new_data["metrics"] = self._cast_metrics(metrics) if self.tmp_data_to_save is not None: - for k in self.tmp_data_to_save: - new_arr = _update_data_after_merge_or_split( - self.sorting_analyzer, new_sorting_analyzer, self.data[k], new_tmp_data[k], new_unit_ids - ) - new_data[k] = new_arr + for k, is_per_unit in self.tmp_data_to_save.items(): + if k not in self.data or k not in new_tmp_data: + continue + if is_per_unit: + new_data[k] = _update_data_after_merge_or_split( + self.sorting_analyzer, new_sorting_analyzer, self.data[k], new_tmp_data[k], new_unit_ids + ) + else: + new_data[k] = new_tmp_data[k] return new_data @@ -1459,11 +1482,15 @@ def _split_extension_data( new_data["metrics"] = self._cast_metrics(metrics) if self.tmp_data_to_save is 
not None: - for k in self.tmp_data_to_save: - new_arr = _update_data_after_merge_or_split( - self.sorting_analyzer, new_sorting_analyzer, self.data[k], new_tmp_data[k], new_unit_ids_f - ) - new_data[k] = new_arr + for k, is_per_unit in self.tmp_data_to_save.items(): + if k not in self.data or k not in new_tmp_data: + continue + if is_per_unit: + new_data[k] = _update_data_after_merge_or_split( + self.sorting_analyzer, new_sorting_analyzer, self.data[k], new_tmp_data[k], new_unit_ids_f + ) + else: + new_data[k] = new_tmp_data[k] return new_data diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 67d32640e4..660fa9c7cb 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -546,6 +546,7 @@ class RPViolation(BaseMetric): def compute_sliding_rp_violations( sorting_analyzer, unit_ids=None, + tmp_data=None, periods=None, min_spikes=0, bin_size_ms=0.25, @@ -566,6 +567,9 @@ def compute_sliding_rp_violations( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the sliding RP violations. If None, all units are used. + tmp_data : dict or None + Shared data dict from the metric extension. When provided, per-tauR contamination + curves and rp_centers are written to it for persistent storage. periods : array of unit_period_dtype | None, default: None Periods (segment_index, start_sample_index, end_sample_index, unit_index) on which to compute the metric. If None, the entire recording duration is used. @@ -588,8 +592,11 @@ def compute_sliding_rp_violations( Returns ------- - contamination : dict of floats - The minimum contamination at the specified confidence level. + result : namedtuple + Named tuple with fields: + + - ``sliding_rp_violation`` : dict of floats — minimum contamination per unit. + - ``sliding_rp_estimated_tauR`` : dict of floats — estimated refractory period (s) per unit. 
References ---------- @@ -597,6 +604,8 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ + res = namedtuple("sliding_rp_result", ["sliding_rp_violation", "sliding_rp_estimated_tauR"]) + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) @@ -606,7 +615,18 @@ def compute_sliding_rp_violations( fs = sorting_analyzer.sampling_frequency + # Pre-compute rp_centers so we know the array size even if all units are skipped + if contamination_values is None: + contamination_values_arr = np.arange(0.5, 35, 0.5) / 100 + else: + contamination_values_arr = np.asarray(contamination_values) + rp_bin_size = bin_size_ms / 1000 + rp_edges = np.arange(0, max_ref_period_ms / 1000, rp_bin_size) + rp_centers = rp_edges + rp_bin_size / 2 + contamination = {} + estimated_tauR = {} + per_tauR_contam_list = [] spikes, slices = sorting.to_reordered_spike_vector( ["sample_index", "segment_index", "unit_index"], return_order=False @@ -622,13 +642,15 @@ def compute_sliding_rp_violations( unit_n_spikes = len(sub_spikes) if unit_n_spikes <= min_spikes: contamination[unit_id] = np.nan + estimated_tauR[unit_id] = np.nan + per_tauR_contam_list.append(np.full(len(rp_centers), np.nan)) continue duration = total_durations[unit_id] sub_sorting = NumpySorting(sub_spikes, fs, unit_ids=[unit_id]) - contamination[unit_id] = slidingRP_violations( + min_contam, est_tauR, contam_at_each_tauR, _ = slidingRP_violations( sub_sorting, duration, bin_size_ms, @@ -638,8 +660,16 @@ def compute_sliding_rp_violations( contamination_values, confidence_threshold=confidence_threshold, ) + contamination[unit_id] = min_contam + estimated_tauR[unit_id] = est_tauR + per_tauR_contam_list.append(contam_at_each_tauR) + + # Store per-tauR data in tmp_data for persistent storage by the extension + if tmp_data 
is not None: + tmp_data["sliding_rp_per_tauR_contamination"] = np.array(per_tauR_contam_list) + tmp_data["sliding_rp_rp_centers"] = rp_centers - return contamination + return res(contamination, estimated_tauR) class SlidingRPViolation(BaseMetric): @@ -654,10 +684,12 @@ class SlidingRPViolation(BaseMetric): "contamination_values": None, "confidence_threshold": 0.9, } - metric_columns = {"sliding_rp_violation": float} + metric_columns = {"sliding_rp_violation": float, "sliding_rp_estimated_tauR": float} metric_descriptions = { - "sliding_rp_violation": "Minimum contamination at specified confidence using sliding refractory period method." + "sliding_rp_violation": "Minimum contamination at specified confidence using sliding refractory period method.", + "sliding_rp_estimated_tauR": "Estimated refractory period (seconds) at which the minimum contamination was found.", } + needs_tmp_data = True supports_periods = True @@ -1702,7 +1734,6 @@ def slidingRP_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - return_conf_matrix=False, confidence_threshold=0.9, ): """ @@ -1713,8 +1744,10 @@ def slidingRP_violations( Parameters ---------- - spike_samples : ndarray_like or list (for multi-segment) - The spike times in samples. + sorting : BaseSorting + A sorting object (typically single-unit). + duration : float + Total duration in seconds. bin_size_ms : float The size (in ms) of binning for the autocorrelogram. window_size_s : float, default: 1 @@ -1725,8 +1758,6 @@ def slidingRP_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, if None it is set to np.arange(0.5, 35, 0.5) / 100. - return_conf_matrix : bool, default: False - If True, the confidence matrix (n_contaminations, n_ref_periods) is returned. confidence_threshold : float, default: 0.9 Confidence threshold (between 0 and 1). Default is 0.9 (90% confidence). 
@@ -1735,8 +1766,15 @@ def slidingRP_violations( Returns ------- - min_contamination : dict of floats + min_contamination : float The minimum contamination with confidence above the specified threshold. + estimated_tauR : float + The refractory period (in seconds) at which the minimum contamination was found. + contamination_at_each_tauR : np.ndarray + 1D array of length ``n_rp_centers``. For each tauR tested, the minimum + contamination with confidence above the threshold (NaN where none passes). + rp_centers : np.ndarray + 1D array of the refractory period centers tested (in seconds). """ if contamination_values is None: contamination_values = np.arange(0.5, 35, 0.5) / 100 # vector of contamination values to test @@ -1776,18 +1814,26 @@ def slidingRP_violations( ) test_rp_centers_mask = rp_centers > exclude_ref_period_below_ms / 1000.0 # (in seconds) - # only test for refractory period durations greater than 'exclude_ref_period_below_ms' - inds_above_threshold = np.row_stack(np.where(conf_matrix[:, test_rp_centers_mask] > confidence_threshold)) - - if len(inds_above_threshold[0]) > 0: - minI = np.min(inds_above_threshold[0][0]) - min_contamination = contamination_values[minI] + # For each tauR, find the minimum contamination where confidence exceeds threshold + contamination_at_each_tauR = np.full(len(rp_centers), np.nan) + for j in range(len(rp_centers)): + passing = np.where(conf_matrix[:, j] > confidence_threshold)[0] + if len(passing) > 0: + contamination_at_each_tauR[j] = contamination_values[passing[0]] + + # Only test for refractory period durations greater than 'exclude_ref_period_below_ms' + masked_contam = contamination_at_each_tauR.copy() + masked_contam[~test_rp_centers_mask] = np.nan + + if np.any(~np.isnan(masked_contam)): + best_idx = np.nanargmin(masked_contam) + min_contamination = masked_contam[best_idx] + estimated_tauR = rp_centers[best_idx] else: min_contamination = np.nan - if return_conf_matrix: - return min_contamination, conf_matrix - 
else: - return min_contamination + estimated_tauR = np.nan + + return min_contamination, estimated_tauR, contamination_at_each_tauR, rp_centers def _compute_rp_contamination_one_unit( diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index c6b539fdc1..3cba9b81b4 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -46,6 +46,10 @@ class ComputeQualityMetrics(BaseMetricExtension): need_job_kwargs = True need_backward_compatibility_on_load = True metric_list = misc_metrics_list + pca_metrics_list + tmp_data_to_save = { + "sliding_rp_per_tauR_contamination": True, # (n_units, n_rp_centers) + "sliding_rp_rp_centers": False, # (n_rp_centers,) — same for all units + } @classmethod def get_required_dependencies(cls, **params): diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index e267b176ce..87a8f7e86b 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -420,22 +420,19 @@ def test_calculate_isi_violations(sorting_analyzer_violations, periods_violation def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations - contaminations = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) + result = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) + contaminations = result.sliding_rp_violation periods = periods_violations - contaminations_periods = compute_sliding_rp_violations( + result_periods = compute_sliding_rp_violations( sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1 ) - assert contaminations == contaminations_periods + assert contaminations == 
result_periods.sliding_rp_violation empty_periods = np.empty(0, dtype=unit_period_dtype) - contaminations_periods_empty = compute_sliding_rp_violations( + result_empty = compute_sliding_rp_violations( sorting_analyzer, periods=empty_periods, bin_size_ms=0.25, window_size_s=1 ) - assert np.all(np.isnan(np.array(list(contaminations_periods_empty.values())))) - - # testing method accuracy with magic number is not a good pratcice, I remove this. - # contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} - # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) + assert np.all(np.isnan(np.array(list(result_empty.sliding_rp_violation.values())))) def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations): From 15bf09a8d07c0d629e3e1009ed1a269de5ee1dd3 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 13:42:00 -0400 Subject: [PATCH 07/15] draft doc updating - to check --- doc/references.rst | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/doc/references.rst b/doc/references.rst index 6a17cbb6dc..40b13f0aca 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -80,9 +80,13 @@ important for your research: - :code:`silhouette` [Rousseeuw]_ [Hruschka]_ If you use the :code:`metrics.template` module, i.e. 
you use the :code:`analyzer.compute("template_metrics")` method, -please following citations: +please include the following citations: -- [Jia]_ +- [Jia]_ [Fabre]_ +- :code:`half_width`, :code:`peak_to_trough_duration`, :code:`number_of_peaks` [Jia]_ [Fabre]_ +- :code:`main_to_next_extremum_duration`, :code:`waveform_ratios`, :code:`waveform_widths` [Fabre]_ +- :code:`repolarization_slope`, :code:`recovery_slope` [Jia]_ +- :code:`exp_decay` [Jia]_ [Fabre]_ Curation Module From f646bccc5a8bc611b87b7b52162027c4877461a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 17:44:48 +0000 Subject: [PATCH 08/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../quality/tests/test_metrics_functions.py | 4 +- .../widgets/template_peak_trough.py | 63 ++++++++++++++----- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 87a8f7e86b..dba34dee5c 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -423,9 +423,7 @@ def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_vi result = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) contaminations = result.sliding_rp_violation periods = periods_violations - result_periods = compute_sliding_rp_violations( - sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1 - ) + result_periods = compute_sliding_rp_violations(sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1) assert contaminations == result_periods.sliding_rp_violation empty_periods = np.empty(0, dtype=unit_period_dtype) diff --git a/src/spikeinterface/widgets/template_peak_trough.py 
b/src/spikeinterface/widgets/template_peak_trough.py index cf30f53c0f..69be47b275 100644 --- a/src/spikeinterface/widgets/template_peak_trough.py +++ b/src/spikeinterface/widgets/template_peak_trough.py @@ -241,7 +241,7 @@ def _plot_multichannel( for j in range(n_chans): offset = j * -spacer waveform = traces[:, j] - is_peak = (chan_inds[j] == best_chan) + is_peak = chan_inds[j] == best_chan color = "k" if is_peak else "gray" lw = 1.2 if is_peak else 0.7 alpha = 1.0 if is_peak else 0.5 @@ -270,9 +270,7 @@ def _overlay_peak_trough_markers(ax, waveform, thresh, y_offset=0.0, min_peak_be """ from spikeinterface.metrics.template.metrics import get_trough_and_peak_idx - troughs, peaks_before, peaks_after = get_trough_and_peak_idx( - waveform, min_thresh_detect_peaks_troughs=thresh - ) + troughs, peaks_before, peaks_after = get_trough_and_peak_idx(waveform, min_thresh_detect_peaks_troughs=thresh) # Check whether peak_before passes the ratio threshold show_peak_before = True @@ -291,13 +289,24 @@ def _overlay_peak_trough_markers(ax, waveform, thresh, y_offset=0.0, min_peak_be secondary_mask[troughs["main_idx"]] = False if secondary_mask.any(): ax.scatter( - troughs["indices"][secondary_mask], troughs["values"][secondary_mask] + y_offset, - c="blue", s=30, marker="v", zorder=5, label="trough", + troughs["indices"][secondary_mask], + troughs["values"][secondary_mask] + y_offset, + c="blue", + s=30, + marker="v", + zorder=5, + label="trough", ) if troughs["main_loc"] is not None: ax.scatter( - troughs["main_loc"], waveform[troughs["main_loc"]] + y_offset, - c="blue", s=100, marker="v", edgecolors="red", linewidths=1.5, zorder=6, + troughs["main_loc"], + waveform[troughs["main_loc"]] + y_offset, + c="blue", + s=100, + marker="v", + edgecolors="red", + linewidths=1.5, + zorder=6, label="trough" if not secondary_mask.any() else None, ) @@ -308,13 +317,24 @@ def _overlay_peak_trough_markers(ax, waveform, thresh, y_offset=0.0, min_peak_be secondary_mask[peaks_before["main_idx"]] 
= False if secondary_mask.any(): ax.scatter( - peaks_before["indices"][secondary_mask], peaks_before["values"][secondary_mask] + y_offset, - c="green", s=30, marker="^", zorder=5, label="peak before", + peaks_before["indices"][secondary_mask], + peaks_before["values"][secondary_mask] + y_offset, + c="green", + s=30, + marker="^", + zorder=5, + label="peak before", ) if peaks_before["main_loc"] is not None: ax.scatter( - peaks_before["main_loc"], waveform[peaks_before["main_loc"]] + y_offset, - c="green", s=100, marker="^", edgecolors="red", linewidths=1.5, zorder=6, + peaks_before["main_loc"], + waveform[peaks_before["main_loc"]] + y_offset, + c="green", + s=100, + marker="^", + edgecolors="red", + linewidths=1.5, + zorder=6, label="peak before" if not secondary_mask.any() else None, ) @@ -325,12 +345,23 @@ def _overlay_peak_trough_markers(ax, waveform, thresh, y_offset=0.0, min_peak_be secondary_mask[peaks_after["main_idx"]] = False if secondary_mask.any(): ax.scatter( - peaks_after["indices"][secondary_mask], peaks_after["values"][secondary_mask] + y_offset, - c="orange", s=30, marker="^", zorder=5, label="peak after", + peaks_after["indices"][secondary_mask], + peaks_after["values"][secondary_mask] + y_offset, + c="orange", + s=30, + marker="^", + zorder=5, + label="peak after", ) if peaks_after["main_loc"] is not None: ax.scatter( - peaks_after["main_loc"], waveform[peaks_after["main_loc"]] + y_offset, - c="orange", s=100, marker="^", edgecolors="red", linewidths=1.5, zorder=6, + peaks_after["main_loc"], + waveform[peaks_after["main_loc"]] + y_offset, + c="orange", + s=100, + marker="^", + edgecolors="red", + linewidths=1.5, + zorder=6, label="peak after" if not secondary_mask.any() else None, ) From 119396efa155e35199f41f6eb11c82511d71f84f Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 14:10:45 -0400 Subject: [PATCH 09/15] good time chunks --- .../how_to/full_pipeline_with_bombcell.py | 24 ++++- .../curation/bombcell_curation.py | 88 
++++++++++++++++--- .../metrics/template/metrics.py | 3 +- .../metrics/template/template_metrics.py | 1 - 4 files changed, 100 insertions(+), 16 deletions(-) diff --git a/examples/how_to/full_pipeline_with_bombcell.py b/examples/how_to/full_pipeline_with_bombcell.py index 6ae6704cdd..7d0f82a969 100644 --- a/examples/how_to/full_pipeline_with_bombcell.py +++ b/examples/how_to/full_pipeline_with_bombcell.py @@ -124,13 +124,17 @@ compute_drift = True label_non_somatic = True split_non_somatic_good_mua = False +use_valid_periods = False # compute quality metrics only on good time chunks # RPV method: "sliding_rp" (default, sweeps RP range) or "llobet" (single RP value) rp_violation_method = "sliding_rp" +refractory_period_ms = 2.0 +censored_period_ms = 0.1 + qm_params = { "presence_ratio": {"bin_duration_s": 60}, - "rp_violation": {"refractory_period_ms": 2.0, "censored_period_ms": 0.1}, + "rp_violation": {"refractory_period_ms": refractory_period_ms, "censored_period_ms": censored_period_ms}, "sliding_rp_violation": { "exclude_ref_period_below_ms": 0.5, "max_ref_period_ms": 10.0, @@ -139,6 +143,16 @@ "drift": {"interval_s": 60, "min_spikes_per_interval": 100}, } +# Valid time periods parameters (only used if use_valid_periods = True) +# fp_threshold and fn_threshold are auto-derived from bombcell thresholds +valid_periods_params = { + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + "period_mode": "absolute", + "period_duration_s_absolute": 30.0, + "minimum_valid_period_duration": 180, +} + metric_names = ["amplitude_median", "snr", "amplitude_cutoff", "num_spikes", "presence_ratio", "firing_rate"] if rp_violation_method == "sliding_rp": @@ -154,9 +168,8 @@ if not analyzer.has_extension("principal_components"): analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) -# To add more metrics, append here and add a threshold below: -# metric_names.append("silhouette") # requires 
principal_components -# metric_names.append("d_prime") # requires principal_components +if use_valid_periods and not analyzer.has_extension("amplitude_scalings"): + analyzer.compute("amplitude_scalings", **job_kwargs) if analyzer.has_extension("quality_metrics"): analyzer.delete_extension("quality_metrics") @@ -193,6 +206,9 @@ thresholds=thresholds, label_non_somatic=label_non_somatic, split_non_somatic_good_mua=split_non_somatic_good_mua, + use_valid_periods=use_valid_periods, + valid_periods_params=valid_periods_params if use_valid_periods else None, + **job_kwargs, ) print(f"\nLabeled {len(bombcell_labels)} units") diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index 2ae7f26c60..a0aabe39df 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -105,6 +105,10 @@ def bombcell_label_units( label_non_somatic: bool = True, split_non_somatic_good_mua: bool = False, external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None, + use_valid_periods: bool = False, + valid_periods_params: dict | None = None, + recompute_quality_metrics: bool = True, + **job_kwargs, ) -> "pd.DataFrame": """ Label units based on quality metrics and template metrics using Bombcell logic: @@ -152,6 +156,25 @@ def bombcell_label_units( If True, split non-somatic into "non_soma_good" and "non_soma_mua". external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None External metrics DataFrame(s) (index = unit_ids) to use instead of those from SortingAnalyzer. + use_valid_periods : bool, default: False + If True, compute valid time periods per unit and recompute quality metrics restricted to + those periods before labeling. This uses the ``valid_unit_periods`` extension to identify + chunks with acceptable false positive (refractory violations) and false negative (amplitude + cutoff) rates. 
The FP/FN thresholds are derived from the bombcell thresholds + (``rpv`` → ``fp_threshold``, ``amplitude_cutoff`` → ``fn_threshold``). + Requires ``amplitude_scalings`` extension and Numba. + valid_periods_params : dict or None, default: None + Additional parameters passed to the ``valid_unit_periods`` extension computation. + Use this to set ``refractory_period_ms``, ``censored_period_ms``, ``period_mode``, + ``period_duration_s_absolute``, etc. Parameters ``fp_threshold`` and ``fn_threshold`` + are automatically derived from bombcell thresholds if not explicitly provided here. + recompute_quality_metrics : bool, default: True + If ``use_valid_periods`` is True, whether to recompute quality metrics restricted to valid + periods. If False, the existing quality metrics are used as-is (useful if you already + computed them with ``use_valid_periods=True``). + **job_kwargs + Job keyword arguments (n_jobs, chunk_duration, progress_bar) passed to + ``valid_unit_periods`` and ``quality_metrics`` computation when ``use_valid_periods=True``. 
Returns ------- @@ -166,6 +189,61 @@ def bombcell_label_units( """ import pandas as pd + # Parse thresholds early so we can derive valid_periods params from them + if thresholds is None: + thresholds_dict = bombcell_get_default_thresholds() + elif isinstance(thresholds, (str, Path)): + with open(thresholds, "r") as f: + thresholds_dict = json.load(f) + elif isinstance(thresholds, dict): + thresholds_dict = thresholds + else: + raise ValueError("thresholds must be a dict, a JSON file path, or None") + + # Compute valid periods and recompute quality metrics if requested + if use_valid_periods: + if sorting_analyzer is None: + raise ValueError("use_valid_periods=True requires a sorting_analyzer") + + # Derive fp/fn thresholds from bombcell thresholds + vp_params = dict(valid_periods_params) if valid_periods_params is not None else {} + + if "fp_threshold" not in vp_params: + rpv_thresh = thresholds_dict.get("mua", {}).get("rpv", {}).get("less", None) + if rpv_thresh is not None: + vp_params["fp_threshold"] = rpv_thresh + + if "fn_threshold" not in vp_params: + ac_thresh = thresholds_dict.get("mua", {}).get("amplitude_cutoff", {}).get("less", None) + if ac_thresh is not None: + vp_params["fn_threshold"] = ac_thresh + + # Compute valid_unit_periods + if sorting_analyzer.has_extension("valid_unit_periods"): + sorting_analyzer.delete_extension("valid_unit_periods") + sorting_analyzer.compute("valid_unit_periods", **vp_params, **job_kwargs) + + # Recompute quality metrics restricted to valid periods + if recompute_quality_metrics: + # Preserve existing quality metric settings (metric_names, metric_params) + qm_ext = sorting_analyzer.get_extension("quality_metrics") + if qm_ext is not None: + existing_params = qm_ext.params.copy() + existing_params.pop("periods", None) + existing_params.pop("use_valid_periods", None) + sorting_analyzer.delete_extension("quality_metrics") + sorting_analyzer.compute( + "quality_metrics", + use_valid_periods=True, + **existing_params, + 
**job_kwargs, + ) + else: + raise ValueError( + "use_valid_periods=True with recompute_quality_metrics=True requires " + "quality_metrics to have been computed at least once." + ) + if sorting_analyzer is not None: combined_metrics = sorting_analyzer.get_metrics_extension_data() if combined_metrics.empty: @@ -184,16 +262,6 @@ def bombcell_label_units( else: combined_metrics = external_metrics - if thresholds is None: - thresholds_dict = bombcell_get_default_thresholds() - elif isinstance(thresholds, (str, Path)): - with open(thresholds, "r") as f: - thresholds_dict = json.load(f) - elif isinstance(thresholds, dict): - thresholds_dict = thresholds - else: - raise ValueError("thresholds must be a dict, a JSON file path, or None") - # Map "rpv" threshold to actual column name (rp_contamination or sliding_rp_violation) if "mua" in thresholds_dict and "rpv" in thresholds_dict["mua"]: rpv_thresh = thresholds_dict["mua"].pop("rpv") diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index 470ffe193a..b758ec2fb1 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -2,7 +2,6 @@ import numpy as np from collections import namedtuple -from scipy.signal import find_peaks, savgol_filter from spikeinterface.core.analyzer_extension_core import BaseMetric @@ -53,6 +52,8 @@ def get_trough_and_peak_idx( - "main_idx": index of the main peak (most prominent) - "main_loc": location (sample index) of the main peak in template """ + from scipy.signal import find_peaks, savgol_filter + assert template.ndim == 1 # Save original for plotting diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index ccd0309882..3d87e8d067 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -9,7 +9,6 @@ import numpy as np import warnings 
from copy import deepcopy -from scipy.signal import find_peaks from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension From 9f7eac709b3e0253305439d6331461eddb40d0e8 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 14:31:48 -0400 Subject: [PATCH 10/15] bombcell preprocessing (standard) wrapper script --- examples/how_to/preprocess_for_bombcell.py | 127 +++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 examples/how_to/preprocess_for_bombcell.py diff --git a/examples/how_to/preprocess_for_bombcell.py b/examples/how_to/preprocess_for_bombcell.py new file mode 100644 index 0000000000..be95e6b15f --- /dev/null +++ b/examples/how_to/preprocess_for_bombcell.py @@ -0,0 +1,127 @@ +""" +Preprocess recording and create SortingAnalyzer for BombCell. +Companion script: run_bombcell_labeling.py +""" + +import json +from pathlib import Path +import spikeinterface.full as si + +# %% Paths +recording_folder = Path("/path/to/your/recording") +preprocessed_folder = recording_folder / "preprocessed" +sorting_folder = recording_folder / "kilosort4_output" +analyzer_folder = recording_folder / "sorting_analyzer.zarr" + +# %% Rerun flags (set True to force recompute) +rerun_preprocessing = False +rerun_sorting = False +rerun_analyzer = False +rerun_extensions = False + +# %% Preprocessing parameters +preprocessing_params = dict( + highpass_freq_min=300.0, + detect_bad_channels=True, + apply_phase_shift=True, + apply_common_reference=True, + cmr_reference="global", + cmr_operator="median", +) + +# %% Sorter parameters +sorter_name = "kilosort4" +sorter_params = dict( + skip_kilosort_preprocessing=True, + do_CAR=False, +) + +# %% Extension parameters +extension_params = dict( + random_spikes=dict(method="uniform", max_spikes_per_unit=500), + waveforms=dict(ms_before=3.0, ms_after=3.0), + templates=dict(operators=["average", "median", "std"]), + 
template_metrics=dict(include_multi_channel_metrics=True), +) + +job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True) + +# %% 1. Load recording +raw_rec = si.read_spikeglx(recording_folder, stream_name="imec0.ap", load_sync_channel=False) +print(f"Loaded: {raw_rec.get_num_channels()} channels, {raw_rec.get_total_duration():.1f}s") + +# %% 2. Preprocess +if (preprocessed_folder / "si_folder.json").exists() and not rerun_preprocessing: + print(f"Loading preprocessed from {preprocessed_folder}") + rec_preprocessed = si.load(preprocessed_folder) +else: + pp = preprocessing_params + rec = si.highpass_filter(raw_rec, freq_min=pp["highpass_freq_min"]) + + if pp["detect_bad_channels"]: + bad_ids, labels = si.detect_bad_channels(rec) + print(f"Bad channels: {list(bad_ids)}") + rec = rec.remove_channels(bad_ids) + preprocessed_folder.mkdir(parents=True, exist_ok=True) + with open(preprocessed_folder / "bad_channels.json", "w") as f: + json.dump({"bad_channel_ids": [str(c) for c in bad_ids]}, f) + + if pp["apply_phase_shift"]: + rec = si.phase_shift(rec) + if pp["apply_common_reference"]: + rec = si.common_reference(rec, reference=pp["cmr_reference"], operator=pp["cmr_operator"]) + + rec_preprocessed = rec.save(folder=preprocessed_folder, format="binary", **job_kwargs) + +# %% 3. Spike sorting +if sorting_folder.exists() and not rerun_sorting: + print(f"Loading sorting from {sorting_folder}") + sorting = si.read_sorter_folder(sorting_folder, register_recording=False) +else: + sorting = si.run_sorter( + sorter_name=sorter_name, + recording=rec_preprocessed, + folder=sorting_folder, + remove_existing_folder=True, + verbose=True, + **sorter_params, + ) +print(f"Units: {len(sorting.unit_ids)}") + +# %% 4. 
Create SortingAnalyzer +if analyzer_folder.exists() and not rerun_analyzer: + print(f"Loading analyzer from {analyzer_folder}") + analyzer = si.load_sorting_analyzer(analyzer_folder) + if not analyzer.has_recording(): + analyzer.set_temporary_recording(rec_preprocessed) +else: + analyzer = si.create_sorting_analyzer( + sorting=sorting, + recording=rec_preprocessed, + sparse=True, + format="zarr", + folder=analyzer_folder, + return_in_uV=True, + ) + +# %% 5. Compute extensions +def compute_ext(name, **kwargs): + if analyzer.has_extension(name) and not rerun_extensions: + return + if analyzer.has_extension(name): + analyzer.delete_extension(name) + print(f"Computing {name}...") + analyzer.compute(name, **kwargs) + +compute_ext("random_spikes", **extension_params["random_spikes"]) +compute_ext("waveforms", **extension_params["waveforms"], **job_kwargs) +compute_ext("templates", **extension_params["templates"]) +compute_ext("noise_levels") +compute_ext("spike_amplitudes", **job_kwargs) +compute_ext("unit_locations") +compute_ext("spike_locations", **job_kwargs) +compute_ext("template_metrics", **extension_params["template_metrics"]) + +print(f"\nDone. 
Analyzer saved to {analyzer_folder}") +print(f"Extensions: {analyzer.get_loaded_extension_names()}") +print(f"Next: run run_bombcell_labeling.py") From 82d9dcb4ce9ade080df9c5168dc53b912f61b531 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 15:21:09 -0400 Subject: [PATCH 11/15] bombcell preprocessing (standard) wrapper script --- examples/how_to/preprocess_for_bombcell.py | 127 ----- src/spikeinterface/curation/__init__.py | 4 + .../curation/bombcell_pipeline.py | 445 ++++++++++++++++++ 3 files changed, 449 insertions(+), 127 deletions(-) delete mode 100644 examples/how_to/preprocess_for_bombcell.py create mode 100644 src/spikeinterface/curation/bombcell_pipeline.py diff --git a/examples/how_to/preprocess_for_bombcell.py b/examples/how_to/preprocess_for_bombcell.py deleted file mode 100644 index be95e6b15f..0000000000 --- a/examples/how_to/preprocess_for_bombcell.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Preprocess recording and create SortingAnalyzer for BombCell. -Companion script: run_bombcell_labeling.py -""" - -import json -from pathlib import Path -import spikeinterface.full as si - -# %% Paths -recording_folder = Path("/path/to/your/recording") -preprocessed_folder = recording_folder / "preprocessed" -sorting_folder = recording_folder / "kilosort4_output" -analyzer_folder = recording_folder / "sorting_analyzer.zarr" - -# %% Rerun flags (set True to force recompute) -rerun_preprocessing = False -rerun_sorting = False -rerun_analyzer = False -rerun_extensions = False - -# %% Preprocessing parameters -preprocessing_params = dict( - highpass_freq_min=300.0, - detect_bad_channels=True, - apply_phase_shift=True, - apply_common_reference=True, - cmr_reference="global", - cmr_operator="median", -) - -# %% Sorter parameters -sorter_name = "kilosort4" -sorter_params = dict( - skip_kilosort_preprocessing=True, - do_CAR=False, -) - -# %% Extension parameters -extension_params = dict( - random_spikes=dict(method="uniform", max_spikes_per_unit=500), - 
waveforms=dict(ms_before=3.0, ms_after=3.0), - templates=dict(operators=["average", "median", "std"]), - template_metrics=dict(include_multi_channel_metrics=True), -) - -job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=True) - -# %% 1. Load recording -raw_rec = si.read_spikeglx(recording_folder, stream_name="imec0.ap", load_sync_channel=False) -print(f"Loaded: {raw_rec.get_num_channels()} channels, {raw_rec.get_total_duration():.1f}s") - -# %% 2. Preprocess -if (preprocessed_folder / "si_folder.json").exists() and not rerun_preprocessing: - print(f"Loading preprocessed from {preprocessed_folder}") - rec_preprocessed = si.load(preprocessed_folder) -else: - pp = preprocessing_params - rec = si.highpass_filter(raw_rec, freq_min=pp["highpass_freq_min"]) - - if pp["detect_bad_channels"]: - bad_ids, labels = si.detect_bad_channels(rec) - print(f"Bad channels: {list(bad_ids)}") - rec = rec.remove_channels(bad_ids) - preprocessed_folder.mkdir(parents=True, exist_ok=True) - with open(preprocessed_folder / "bad_channels.json", "w") as f: - json.dump({"bad_channel_ids": [str(c) for c in bad_ids]}, f) - - if pp["apply_phase_shift"]: - rec = si.phase_shift(rec) - if pp["apply_common_reference"]: - rec = si.common_reference(rec, reference=pp["cmr_reference"], operator=pp["cmr_operator"]) - - rec_preprocessed = rec.save(folder=preprocessed_folder, format="binary", **job_kwargs) - -# %% 3. Spike sorting -if sorting_folder.exists() and not rerun_sorting: - print(f"Loading sorting from {sorting_folder}") - sorting = si.read_sorter_folder(sorting_folder, register_recording=False) -else: - sorting = si.run_sorter( - sorter_name=sorter_name, - recording=rec_preprocessed, - folder=sorting_folder, - remove_existing_folder=True, - verbose=True, - **sorter_params, - ) -print(f"Units: {len(sorting.unit_ids)}") - -# %% 4. 
Create SortingAnalyzer -if analyzer_folder.exists() and not rerun_analyzer: - print(f"Loading analyzer from {analyzer_folder}") - analyzer = si.load_sorting_analyzer(analyzer_folder) - if not analyzer.has_recording(): - analyzer.set_temporary_recording(rec_preprocessed) -else: - analyzer = si.create_sorting_analyzer( - sorting=sorting, - recording=rec_preprocessed, - sparse=True, - format="zarr", - folder=analyzer_folder, - return_in_uV=True, - ) - -# %% 5. Compute extensions -def compute_ext(name, **kwargs): - if analyzer.has_extension(name) and not rerun_extensions: - return - if analyzer.has_extension(name): - analyzer.delete_extension(name) - print(f"Computing {name}...") - analyzer.compute(name, **kwargs) - -compute_ext("random_spikes", **extension_params["random_spikes"]) -compute_ext("waveforms", **extension_params["waveforms"], **job_kwargs) -compute_ext("templates", **extension_params["templates"]) -compute_ext("noise_levels") -compute_ext("spike_amplitudes", **job_kwargs) -compute_ext("unit_locations") -compute_ext("spike_locations", **job_kwargs) -compute_ext("template_metrics", **extension_params["template_metrics"]) - -print(f"\nDone. 
Analyzer saved to {analyzer_folder}") -print(f"Extensions: {analyzer.get_loaded_extension_names()}") -print(f"Next: run run_bombcell_labeling.py") diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 16bdddd870..27c50e117a 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -30,3 +30,7 @@ bombcell_label_units, save_bombcell_results, ) +from .bombcell_pipeline import ( + preprocess_for_bombcell, + run_bombcell_qc, +) diff --git a/src/spikeinterface/curation/bombcell_pipeline.py b/src/spikeinterface/curation/bombcell_pipeline.py new file mode 100644 index 0000000000..ea8e048d82 --- /dev/null +++ b/src/spikeinterface/curation/bombcell_pipeline.py @@ -0,0 +1,445 @@ +""" +BombCell pipeline functions for preprocessing and quality control. + +Functions +--------- +preprocess_for_bombcell + Preprocess recording and create SortingAnalyzer with required extensions. +run_bombcell_qc + Compute quality metrics, run BombCell labeling, and generate plots. +""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np + + +def preprocess_for_bombcell( + recording, + sorting, + analyzer_folder: str | Path, + # Preprocessing + freq_min: float = 300.0, + detect_bad_channels: bool = True, + bad_channel_method: str = "coherence+psd", + apply_phase_shift: bool = True, + apply_cmr: bool = True, + cmr_reference: str = "global", + cmr_operator: str = "median", + # Analyzer + sparse: bool = True, + return_in_uV: bool = True, + # Extensions + max_spikes_per_unit: int = 500, + ms_before: float = 3.0, + ms_after: float = 3.0, + template_operators: list[str] = ("average", "median", "std"), + include_multi_channel_metrics: bool = True, + # Rerun flags + rerun_extensions: bool = False, + # Job + n_jobs: int = -1, + progress_bar: bool = True, +): + """ + Preprocess recording and create SortingAnalyzer with extensions for BombCell. 
+ + Parameters + ---------- + recording : BaseRecording + Raw recording to preprocess. + sorting : BaseSorting + Spike sorting result. + analyzer_folder : str or Path + Path to save the SortingAnalyzer (zarr format). + + Preprocessing Parameters + ------------------------ + freq_min : float, default: 300.0 + Highpass filter cutoff frequency in Hz. + detect_bad_channels : bool, default: True + Detect and remove bad channels. + bad_channel_method : str, default: "coherence+psd" + Method for bad channel detection. + apply_phase_shift : bool, default: True + Apply phase shift correction (for Neuropixels). + apply_cmr : bool, default: True + Apply common median reference. + cmr_reference : str, default: "global" + Reference type: "global", "local", or "single". + cmr_operator : str, default: "median" + Operator: "median" or "average". + + Analyzer Parameters + ------------------- + sparse : bool, default: True + Use sparse waveform representation. + return_in_uV : bool, default: True + Return waveforms in microvolts. + + Extension Parameters + -------------------- + max_spikes_per_unit : int, default: 500 + Number of spikes to extract per unit for waveforms. + ms_before : float, default: 3.0 + Milliseconds before spike peak for waveform extraction. + ms_after : float, default: 3.0 + Milliseconds after spike peak for waveform extraction. + template_operators : list, default: ("average", "median", "std") + Template statistics to compute. + include_multi_channel_metrics : bool, default: True + Include multi-channel template metrics (exp_decay, etc.). + rerun_extensions : bool, default: False + Force recomputation of existing extensions. + + Job Parameters + -------------- + n_jobs : int, default: -1 + Number of parallel jobs (-1 for all CPUs). + progress_bar : bool, default: True + Show progress bars. + + Returns + ------- + analyzer : SortingAnalyzer + SortingAnalyzer with computed extensions. + rec_preprocessed : BaseRecording + Preprocessed recording (lazy, not saved). 
+ bad_channel_ids : list or None + List of removed bad channel IDs, or None if detection disabled. + """ + import spikeinterface.full as si + + analyzer_folder = Path(analyzer_folder) + job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=progress_bar) + + # Preprocess + rec = si.highpass_filter(recording, freq_min=freq_min) + + bad_channel_ids = None + if detect_bad_channels: + bad_channel_ids, _ = si.detect_bad_channels(rec, method=bad_channel_method) + bad_channel_ids = list(bad_channel_ids) + if len(bad_channel_ids) > 0: + rec = rec.remove_channels(bad_channel_ids) + + if apply_phase_shift: + rec = si.phase_shift(rec) + + if apply_cmr: + rec = si.common_reference(rec, reference=cmr_reference, operator=cmr_operator) + + rec_preprocessed = rec + + # Create or load analyzer + if analyzer_folder.exists(): + analyzer = si.load_sorting_analyzer(analyzer_folder) + if not analyzer.has_recording(): + analyzer.set_temporary_recording(rec_preprocessed) + else: + analyzer = si.create_sorting_analyzer( + sorting=sorting, + recording=rec_preprocessed, + sparse=sparse, + format="zarr", + folder=analyzer_folder, + return_in_uV=return_in_uV, + ) + + # Compute extensions + def _compute(name, **kwargs): + if analyzer.has_extension(name) and not rerun_extensions: + return + if analyzer.has_extension(name): + analyzer.delete_extension(name) + analyzer.compute(name, **kwargs) + + _compute("random_spikes", method="uniform", max_spikes_per_unit=max_spikes_per_unit) + _compute("waveforms", ms_before=ms_before, ms_after=ms_after, **job_kwargs) + _compute("templates", operators=list(template_operators)) + _compute("noise_levels") + _compute("spike_amplitudes", **job_kwargs) + _compute("unit_locations") + _compute("spike_locations", **job_kwargs) + _compute("template_metrics", include_multi_channel_metrics=include_multi_channel_metrics) + + return analyzer, rec_preprocessed, bad_channel_ids + + +def run_bombcell_qc( + sorting_analyzer, + output_folder: str | Path | None = 
None, + # Quality metric options + compute_distance_metrics: bool = False, + compute_drift: bool = True, + rp_method: str = "sliding_rp", + # Quality metric parameters + qm_params: dict | None = None, + # BombCell options + label_non_somatic: bool = True, + split_non_somatic_good_mua: bool = False, + use_valid_periods: bool = False, + valid_periods_params: dict | None = None, + # Thresholds (None = use defaults) + thresholds: dict | None = None, + # Plotting + plot_histograms: bool = True, + plot_waveforms: bool = True, + plot_upset: bool = True, + waveform_ylims: tuple | None = (-300, 100), + figsize_histograms: tuple = (15, 10), + # Rerun + rerun_quality_metrics: bool = False, + rerun_pca: bool = False, + # Job + n_jobs: int = -1, + progress_bar: bool = True, +): + """ + Compute quality metrics and run BombCell unit labeling. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + Analyzer with template_metrics computed (from preprocess_for_bombcell). + output_folder : str, Path, or None, default: None + Folder to save results (CSV files and plots). If None, results not saved. + + Quality Metric Options + ---------------------- + compute_distance_metrics : bool, default: False + Compute isolation_distance and l_ratio (requires PCA, slow). + compute_drift : bool, default: True + Compute drift metrics. + rp_method : str, default: "sliding_rp" + Refractory period violation method: "sliding_rp" or "llobet". + + Quality Metric Parameters + ------------------------- + qm_params : dict or None, default: None + Override default quality metric parameters. Keys are metric names, + values are parameter dicts. 
Default parameters: + + - presence_ratio: {"bin_duration_s": 60} + - rp_violation: {"refractory_period_ms": 2.0, "censored_period_ms": 0.1} + - sliding_rp_violation: {"exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10.0, "confidence_threshold": 0.9} + - drift: {"interval_s": 60, "min_spikes_per_interval": 100} + + BombCell Options + ---------------- + label_non_somatic : bool, default: True + Detect non-somatic (axonal/dendritic) units. + split_non_somatic_good_mua : bool, default: False + Split non-somatic into "non_soma_good" and "non_soma_mua". + use_valid_periods : bool, default: False + Restrict metrics to valid time periods per unit. + valid_periods_params : dict or None, default: None + Parameters for valid_unit_periods extension. + + Classification Thresholds + ------------------------- + thresholds : dict or None, default: None + BombCell thresholds dict with "noise", "mua", "non-somatic" sections. + If None, uses bombcell_get_default_thresholds(). Default values: + + noise (waveform quality - any failure -> "noise"): + - num_positive_peaks: < 2 + - num_negative_peaks: < 1 + - peak_to_trough_duration: 0.1-1.15 ms + - waveform_baseline_flatness: < 0.5 + - peak_after_to_trough_ratio: < 0.8 + - exp_decay: 0.01-0.1 + + mua (spike quality - any failure -> "mua"): + - amplitude_median: > 30 uV (absolute value) + - snr: > 5 + - amplitude_cutoff: < 0.2 + - num_spikes: > 300 + - rpv: < 0.1 (refractory period violations) + - presence_ratio: > 0.7 + - drift_ptp: < 100 um + - isolation_distance: > 20 (if computed) + - l_ratio: < 0.3 (if computed) + + non-somatic (waveform shape): + - peak_before_to_trough_ratio: < 3 + - peak_before_width: > 0.15 ms + - trough_width: > 0.2 ms + - peak_before_to_peak_after_ratio: < 3 + - main_peak_to_trough_ratio: < 0.8 + + To modify thresholds: + thresholds = bombcell_get_default_thresholds() + thresholds["mua"]["rpv"]["less"] = 0.05 # stricter RPV + thresholds["mua"]["num_spikes"]["greater"] = 100 # lower spike count + + To 
disable a threshold: + thresholds["mua"]["drift_ptp"] = {"greater": None, "less": None} + + Plotting Options + ---------------- + plot_histograms : bool, default: True + Plot metric histograms with threshold lines. + plot_waveforms : bool, default: True + Plot waveforms grouped by label. + plot_upset : bool, default: True + Plot UpSet plots showing metric failure combinations. + waveform_ylims : tuple or None, default: (-300, 100) + Y-axis limits for waveform plots. + figsize_histograms : tuple, default: (15, 10) + Figure size for histogram plot. + + Rerun Options + ------------- + rerun_quality_metrics : bool, default: False + Force recomputation of quality metrics. + rerun_pca : bool, default: False + Force recomputation of PCA (only if compute_distance_metrics=True). + + Job Parameters + -------------- + n_jobs : int, default: -1 + Number of parallel jobs. + progress_bar : bool, default: True + Show progress bars. + + Returns + ------- + labels : pd.DataFrame + DataFrame with unit_ids as index and "bombcell_label" column. + Labels: "good", "mua", "noise", "non_soma" (or "non_soma_good"/"non_soma_mua"). + metrics : pd.DataFrame + Combined quality metrics and template metrics. + figures : dict + Dictionary of matplotlib figures: {"histograms", "waveforms", "upset"}. 
+ """ + import pandas as pd + + from .bombcell_curation import bombcell_get_default_thresholds, bombcell_label_units, save_bombcell_results + + job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=progress_bar) + + if output_folder is not None: + output_folder = Path(output_folder) + output_folder.mkdir(parents=True, exist_ok=True) + + # Default QM params + default_qm_params = { + "presence_ratio": {"bin_duration_s": 60}, + "rp_violation": {"refractory_period_ms": 2.0, "censored_period_ms": 0.1}, + "sliding_rp_violation": { + "exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10.0, + "confidence_threshold": 0.9, + }, + "drift": {"interval_s": 60, "min_spikes_per_interval": 100}, + } + if qm_params is not None: + for key, val in qm_params.items(): + default_qm_params[key] = val + qm_params = default_qm_params + + # Build metric names list + metric_names = [ + "amplitude_median", + "snr", + "amplitude_cutoff", + "num_spikes", + "presence_ratio", + "firing_rate", + ] + + if rp_method == "sliding_rp": + metric_names.append("sliding_rp_violation") + else: + metric_names.append("rp_violation") + + if compute_drift: + metric_names.append("drift") + + if compute_distance_metrics: + metric_names.append("mahalanobis") + if not sorting_analyzer.has_extension("principal_components") or rerun_pca: + if sorting_analyzer.has_extension("principal_components"): + sorting_analyzer.delete_extension("principal_components") + sorting_analyzer.compute( + "principal_components", n_components=5, mode="by_channel_local", **job_kwargs + ) + + if use_valid_periods and not sorting_analyzer.has_extension("amplitude_scalings"): + sorting_analyzer.compute("amplitude_scalings", **job_kwargs) + + # Compute quality metrics + if sorting_analyzer.has_extension("quality_metrics") and rerun_quality_metrics: + sorting_analyzer.delete_extension("quality_metrics") + + if not sorting_analyzer.has_extension("quality_metrics"): + sorting_analyzer.compute( + "quality_metrics", 
metric_names=metric_names, metric_params=qm_params, **job_kwargs + ) + + # Get thresholds + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + # Run BombCell labeling + labels = bombcell_label_units( + sorting_analyzer=sorting_analyzer, + thresholds=thresholds, + label_non_somatic=label_non_somatic, + split_non_somatic_good_mua=split_non_somatic_good_mua, + use_valid_periods=use_valid_periods, + valid_periods_params=valid_periods_params, + **job_kwargs, + ) + + metrics = sorting_analyzer.get_metrics_extension_data() + + # Generate plots + figures = {} + + if plot_histograms or plot_waveforms or plot_upset: + import spikeinterface.widgets as sw + + if plot_histograms: + w = sw.plot_metric_histograms(sorting_analyzer, thresholds, figsize=figsize_histograms) + figures["histograms"] = w.figure + + if plot_waveforms: + w = sw.plot_unit_labels(sorting_analyzer, labels["bombcell_label"], ylims=waveform_ylims) + figures["waveforms"] = w.figure + + if plot_upset: + w = sw.plot_bombcell_labels_upset( + sorting_analyzer, + unit_labels=labels["bombcell_label"], + thresholds=thresholds, + unit_labels_to_plot=["noise", "mua"], + ) + figures["upset"] = w.figures + + # Save results + if output_folder is not None: + save_bombcell_results( + metrics=metrics, + unit_label=labels["bombcell_label"].values, + thresholds=thresholds, + folder=output_folder, + ) + + # Save figures + if "histograms" in figures: + figures["histograms"].savefig(output_folder / "metric_histograms.png", dpi=150, bbox_inches="tight") + if "waveforms" in figures: + figures["waveforms"].savefig(output_folder / "waveforms_by_label.png", dpi=150, bbox_inches="tight") + if "upset" in figures: + for i, fig in enumerate(figures["upset"]): + fig.savefig(output_folder / f"upset_plot_{i}.png", dpi=150, bbox_inches="tight") + + print(f"Labeled {len(labels)} units:") + print(labels["bombcell_label"].value_counts().to_string()) + + return labels, metrics, figures From 
c6237f6e393e7615cc9f6e08fc85034f0ad29fb6 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 15:21:11 -0400 Subject: [PATCH 12/15] bombcell preprocessing (standard) wrapper script --- examples/how_to/bombcell_pipeline_example.py | 207 +++++ ...> full_detailed_pipeline_with_bombcell.py} | 0 src/spikeinterface/curation/__init__.py | 2 + .../curation/bombcell_pipeline.py | 742 ++++++++++++------ 4 files changed, 695 insertions(+), 256 deletions(-) create mode 100644 examples/how_to/bombcell_pipeline_example.py rename examples/how_to/{full_pipeline_with_bombcell.py => full_detailed_pipeline_with_bombcell.py} (100%) diff --git a/examples/how_to/bombcell_pipeline_example.py b/examples/how_to/bombcell_pipeline_example.py new file mode 100644 index 0000000000..871709a4d1 --- /dev/null +++ b/examples/how_to/bombcell_pipeline_example.py @@ -0,0 +1,207 @@ +""" +BombCell pipeline example. + +This example shows how to use the preprocessing and QC wrapper functions +with customizable parameters and thresholds. +""" + +from pathlib import Path +import spikeinterface.full as si +import spikeinterface.curation as sc + +# %% Paths - edit these to match your data +spikeglx_folder = Path("/path/to/your/spikeglx/recording") # folder containing .ap.bin and .ap.meta files +sorting_folder = spikeglx_folder / "kilosort4_output" # folder with spike sorter output (spike_times.npy, etc.) +analyzer_folder = spikeglx_folder / "sorting_analyzer.zarr" # where to save the SortingAnalyzer (will be created) +output_folder = spikeglx_folder / "bombcell" # where to save BombCell results (metrics, plots, labels) + +# %% Load data +# NOTE: Recording and sorting are always needed - even when loading an existing analyzer, +# because the analyzer may need the recording for some computations. 
+recording = si.read_spikeglx(spikeglx_folder, stream_name="imec0.ap", load_sync_channel=False) # load raw recording +sorting = si.read_sorter_folder(sorting_folder, register_recording=False) # load spike sorting results + +# %% Preprocessing parameters +# Get defaults and modify as needed +preproc_params = sc.get_default_preprocessing_params() + +# Example parameters you may want to modify: + +# --- Filtering --- +preproc_params["freq_min"] = 300.0 # highpass cutoff in Hz - removes LFP/low-frequency noise + +# --- Bad channel detection --- +preproc_params["detect_bad_channels"] = True # auto-detect and remove bad/dead channels +preproc_params["bad_channel_method"] = "coherence+psd" # "coherence+psd" (recommended), "mad", or "std" + +# --- Phase shift (essential for Neuropixels) --- +preproc_params["apply_phase_shift"] = True # correct small timing offsets from multiplexed ADC sampling + # critical for common reference to work properly + +# --- Common reference --- +preproc_params["apply_cmr"] = True # subtract common signal to remove shared noise +preproc_params["cmr_reference"] = "global" # "global" (all chans), "local" (nearby), "single" (one chan) +preproc_params["cmr_operator"] = "median" # "median" (robust to outliers) or "average" + +# --- Waveform extraction --- +preproc_params["max_spikes_per_unit"] = 500 # spikes to extract per unit (more = accurate but slower) +preproc_params["ms_before"] = 3.0 # ms before spike peak to include in waveform +preproc_params["ms_after"] = 3.0 # ms after spike peak to include in waveform + +# %% QC parameters +qc_params = sc.get_default_qc_params() + +# Example parameters you may want to modify: + +# --- Metrics to compute --- +qc_params["compute_drift"] = True # compute drift metrics (position changes over time) +qc_params["compute_distance_metrics"] = False # isolation distance & L-ratio - slow, not drift-robust + # recommend True for stable/chronic recordings +qc_params["rp_method"] = "sliding_rp" # refractory period 
method: "sliding_rp" or "llobet" + +# --- BombCell classification options --- +qc_params["label_non_somatic"] = True # detect axonal/dendritic units via waveform shape +qc_params["split_non_somatic_good_mua"] = False # if True, non-somatic split into good/mua subcategories + +# --- Presence ratio --- +qc_params["presence_ratio_bin_duration_s"] = 60 # bin size (s) for checking if unit fires throughout recording + +# --- Refractory period violations --- +qc_params["refractory_period_ms"] = 2.0 # expected refractory period - use 1.0-1.5 for fast-spiking +qc_params["censored_period_ms"] = 0.1 # ignore ISIs shorter than this (spike sorting artifact) + +# --- Sliding RP method parameters --- +qc_params["sliding_rp_exclude_below_ms"] = 0.5 # exclude ISIs below this when fitting contamination +qc_params["sliding_rp_max_ms"] = 10.0 # max ISI to consider for refractory period analysis +qc_params["sliding_rp_confidence"] = 0.9 # confidence level for contamination estimate (0-1) + +# --- Drift parameters --- +qc_params["drift_interval_s"] = 60 # time bin (s) for computing position over time +qc_params["drift_min_spikes"] = 100 # min spikes in bin to estimate position (skip if fewer) + +# --- Plotting --- +qc_params["plot_histograms"] = True # save histogram plots of all metrics +qc_params["plot_waveforms"] = True # save waveform plots for each unit +qc_params["plot_upset"] = True # save UpSet plot showing threshold failure combinations + +# %% Classification thresholds +# Format: {"greater": min_value, "less": max_value} - unit passes if min < value < max +# Use None to disable a bound. Add "abs": True to use absolute value. 
+thresholds = sc.bombcell_get_default_thresholds() + +# --- Noise thresholds (waveform quality) --- +# Units failing ANY of these are labeled "noise" (not neural signals) +thresholds["noise"]["num_positive_peaks"] = {"greater": None, "less": 2} # max positive peaks in waveform (>1 = multi-unit/noise) +thresholds["noise"]["num_negative_peaks"] = {"greater": None, "less": 1} # max negative peaks (>0 unusual for somatic spikes) +thresholds["noise"]["peak_to_trough_duration"] = {"greater": 0.0001, "less": 0.00115} # spike width in seconds (0.1-1.15ms is physiological) +thresholds["noise"]["waveform_baseline_flatness"] = {"greater": None, "less": 0.5} # baseline variation (high = noisy/unstable) +thresholds["noise"]["peak_after_to_trough_ratio"] = {"greater": None, "less": 0.8} # repolarization peak vs trough amplitude +thresholds["noise"]["exp_decay"] = {"greater": 0.01, "less": 0.1} # exponential decay constant of waveform tail + +# --- MUA thresholds (spike quality) --- +# Units failing ANY of these (that passed noise) are labeled "mua" (multi-unit activity) +thresholds["mua"]["amplitude_median"] = {"greater": 30, "less": None, "abs": True} # minimum amplitude in uV (abs=True uses |amplitude|) +thresholds["mua"]["snr"] = {"greater": 5, "less": None} # signal-to-noise ratio (higher = cleaner unit) +thresholds["mua"]["amplitude_cutoff"] = {"greater": None, "less": 0.2} # fraction of spikes below detection threshold (0 = none missing) +thresholds["mua"]["num_spikes"] = {"greater": 300, "less": None} # minimum spike count (too few = unreliable metrics) +thresholds["mua"]["rpv"] = {"greater": None, "less": 0.1} # refractory period violation rate (0 = perfect isolation) +thresholds["mua"]["presence_ratio"] = {"greater": 0.7, "less": None} # fraction of recording with spikes (1 = fires throughout) +thresholds["mua"]["drift_ptp"] = {"greater": None, "less": 100} # peak-to-peak position drift in um (lower = more stable) + +# Optional distance metrics (only used if 
compute_distance_metrics=True)
+thresholds["mua"]["isolation_distance"] = {"greater": 20, "less": None} # Mahalanobis-distance-based isolation measure (higher = better isolated)
+thresholds["mua"]["l_ratio"] = {"greater": None, "less": 0.3} # L-ratio contamination estimate (lower = better isolated)
+
+# --- Non-somatic thresholds (waveform shape) ---
+# Detects axonal/dendritic units based on waveform features (these have different shapes than somatic spikes)
+thresholds["non-somatic"]["peak_before_to_trough_ratio"] = {"greater": None, "less": 3} # ratio of pre-peak to trough amplitude
+thresholds["non-somatic"]["peak_before_width"] = {"greater": 0.00015, "less": None} # width of peak before trough in seconds
+thresholds["non-somatic"]["trough_width"] = {"greater": 0.0002, "less": None} # width of main trough in seconds
+thresholds["non-somatic"]["peak_before_to_peak_after_ratio"] = {"greater": None, "less": 3} # ratio of pre-peak to post-peak amplitude
+thresholds["non-somatic"]["main_peak_to_trough_ratio"] = {"greater": None, "less": 0.8} # ratio of main peak to trough amplitude
+
+# %% Adding custom quality metrics
+# You can add ANY metric from the SortingAnalyzer's quality_metrics or
+# template_metrics DataFrame to ANY threshold section (noise, mua, non-somatic).
+# +# How it works: +# - Metrics in "noise" section: unit fails if ANY threshold is violated → labeled "noise" +# - Metrics in "mua" section: unit fails if ANY threshold is violated → labeled "mua" +# - Metrics in "non-somatic" section: OR'd with built-in waveform shape checks +# - Metrics that haven't been computed are automatically skipped (with a warning) +# +# Threshold format: +# {"greater": min_value, "less": max_value} - unit passes if min < value < max +# {"greater": min_value, "less": max_value, "abs": True} - uses |value| for comparison +# Use None to disable one bound (e.g., {"greater": 0.1, "less": None} means value > 0.1) +# +# Examples of adding custom metrics: +# thresholds["mua"]["firing_rate"] = {"greater": 0.1, "less": None} # exclude units with firing rate < 0.1 Hz +# thresholds["mua"]["silhouette"] = {"greater": 0.4, "less": None} # silhouette score (requires PCA) +# thresholds["noise"]["half_width"] = {"greater": 0.05e-3, "less": 0.6e-3} # spike half-width bounds (template_metrics) +# thresholds["non-somatic"]["velocity_above"] = {"greater": 2.0, "less": None} # axonal propagation velocity +# +# To DISABLE an existing threshold (skip it entirely): +# thresholds["mua"]["drift_ptp"] = {"greater": None, "less": None} # both bounds None = threshold ignored +# +# Available metrics depend on what extensions are computed. Common ones include: +# Quality metrics: amplitude_median, snr, amplitude_cutoff, num_spikes, presence_ratio, +# firing_rate, isi_violation, sliding_rp_violation, drift, isolation_distance, l_ratio +# Template metrics: peak_to_valley, half_width, repolarization_slope, recovery_slope, +# num_positive_peaks, num_negative_peaks, velocity_above, velocity_below, exp_decay + +# %% Step 1: Preprocess and create analyzer +# This applies all preprocessing steps and extracts waveforms into a SortingAnalyzer. +# +# IMPORTANT: If analyzer_folder already exists, the existing analyzer is LOADED (not recreated). 
+# Extensions are only computed if they don't already exist - nothing is recomputed by default. +# To force recomputation, use rerun_extensions=True. +analyzer, rec_preprocessed, bad_channels = sc.preprocess_for_bombcell( + recording=recording, # raw recording object + sorting=sorting, # spike sorting results + analyzer_folder=analyzer_folder, # if exists: loads it; if not: creates it + params=preproc_params, # preprocessing parameters defined above + rerun_extensions=False, # False (default): skip existing extensions; True: recompute all + n_jobs=-1, # parallel jobs: -1 = all CPUs, 1 = single-threaded +) +# Returns: +# analyzer: SortingAnalyzer with waveforms, templates, and extensions computed +# rec_preprocessed: the preprocessed recording (filtered, referenced, etc.) +# bad_channels: list of channel IDs that were detected as bad and removed (None if loaded from disk) +print(f"Bad channels removed: {bad_channels}") + +# %% Step 2: Run BombCell QC +# This computes quality metrics and classifies units as good/mua/noise/non-somatic. +# +# IMPORTANT: Quality metrics are only computed if they don't already exist in the analyzer. +# To force recomputation (e.g., after changing qc_params), use rerun_quality_metrics=True. 
+labels, metrics, figures = sc.run_bombcell_qc( + sorting_analyzer=analyzer, # SortingAnalyzer from step 1 + output_folder=output_folder, # where to save results (CSVs, plots) + params=qc_params, # QC parameters defined above + thresholds=thresholds, # classification thresholds defined above + rerun_quality_metrics=False, # False (default): use existing metrics; True: recompute + n_jobs=-1, # parallel jobs: -1 = all CPUs, 1 = single-threaded +) +# Returns: +# labels: DataFrame with unit_id index and 'bombcell_label' column (good/mua/noise/non_soma) +# metrics: DataFrame with all computed quality metrics for each unit +# figures: dict of matplotlib figures (histograms, waveforms, upset plot) + +# %% Results +print(f"\nResults saved to: {output_folder}") +print(f"\nLabel distribution:\n{labels['bombcell_label'].value_counts()}") + +# Get units by label +good_units = labels[labels["bombcell_label"] == "good"].index.tolist() +mua_units = labels[labels["bombcell_label"] == "mua"].index.tolist() +noise_units = labels[labels["bombcell_label"] == "noise"].index.tolist() +non_soma_units = labels[labels["bombcell_label"] == "non_soma"].index.tolist() + +print(f"\nGood units ({len(good_units)}): {good_units[:10]}...") +print(f"MUA units ({len(mua_units)}): {mua_units[:10]}...") + +# %% Access metrics for specific units +print(f"\nMetrics for first good unit:") +if good_units: + print(metrics.loc[good_units[0]]) diff --git a/examples/how_to/full_pipeline_with_bombcell.py b/examples/how_to/full_detailed_pipeline_with_bombcell.py similarity index 100% rename from examples/how_to/full_pipeline_with_bombcell.py rename to examples/how_to/full_detailed_pipeline_with_bombcell.py diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 27c50e117a..cfe0e749d9 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -31,6 +31,8 @@ save_bombcell_results, ) from .bombcell_pipeline import ( + 
get_default_preprocessing_params, + get_default_qc_params, preprocess_for_bombcell, run_bombcell_qc, ) diff --git a/src/spikeinterface/curation/bombcell_pipeline.py b/src/spikeinterface/curation/bombcell_pipeline.py index ea8e048d82..f3e6b466dd 100644 --- a/src/spikeinterface/curation/bombcell_pipeline.py +++ b/src/spikeinterface/curation/bombcell_pipeline.py @@ -1,135 +1,415 @@ """ BombCell pipeline functions for preprocessing and quality control. +This module provides wrapper functions for running the full BombCell quality +control pipeline on spike-sorted data. + Functions --------- +get_default_preprocessing_params + Get default parameters for preprocessing and SortingAnalyzer creation. +get_default_qc_params + Get default parameters for quality metrics and BombCell labeling. preprocess_for_bombcell Preprocess recording and create SortingAnalyzer with required extensions. run_bombcell_qc Compute quality metrics, run BombCell labeling, and generate plots. + +See Also +-------- +bombcell_get_default_thresholds : Get default classification thresholds. +bombcell_label_units : Core labeling function. """ from __future__ import annotations - from pathlib import Path -import numpy as np + +def get_default_preprocessing_params(): + """ + Get default parameters for preprocessing and SortingAnalyzer creation. + + Returns a dictionary that can be modified and passed to preprocess_for_bombcell(). + + Returns + ------- + dict + Dictionary with the following keys: + + **Highpass Filtering** + + freq_min : float, default: 300.0 + Highpass filter cutoff frequency in Hz. Removes low-frequency noise + and LFP signals. Standard value for spike detection is 300 Hz. + Lower values (150-250 Hz) may be used if spikes have significant + low-frequency components. + + **Bad Channel Detection** + + detect_bad_channels : bool, default: True + Whether to automatically detect and remove bad channels before + further processing. Recommended to leave enabled. 
+ + bad_channel_method : str, default: "coherence+psd" + Method for detecting bad channels: + - "coherence+psd": Combines coherence with neighbors and power + spectral density analysis. Best for Neuropixels. (recommended) + - "mad": Median absolute deviation of signal amplitude. + - "std": Standard deviation based detection. + + **Phase Shift Correction** + + apply_phase_shift : bool, default: True + Whether to apply inter-sample phase shift correction. Essential for + Neuropixels probes where ADCs are multiplexed and channels are sampled + at slightly different times. Corrects for timing offsets that can + affect spike waveform shapes. Disable for non-Neuropixels probes. + + **Common Reference** + + apply_cmr : bool, default: True + Whether to apply common reference to remove correlated noise. + Highly recommended for Neuropixels recordings. + + cmr_reference : str, default: "global" + Type of common reference: + - "global": Use all channels (recommended for Neuropixels). + - "local": Use nearby channels only (for probes with distinct groups). + - "single": Reference to a single channel. + + cmr_operator : str, default: "median" + Operation for computing reference signal: + - "median": More robust to outliers (recommended). + - "average": Standard mean reference. + + **SortingAnalyzer Settings** + + sparse : bool, default: True + Use sparse waveform representation, storing only channels near each + unit. Significantly reduces memory usage for high-channel-count probes. + Recommended True for Neuropixels. + + return_in_uV : bool, default: True + Convert waveforms to microvolts using gain/offset from probe metadata. + Required for amplitude-based quality metrics to be meaningful. + + **Waveform Extraction** + + max_spikes_per_unit : int, default: 500 + Maximum number of spikes to extract per unit for waveform analysis. + Higher values give better templates but use more memory/time. + 500 is typically sufficient for stable template estimation. 
+ + ms_before : float, default: 3.0 + Milliseconds before spike peak to extract. 3.0 ms captures the + pre-spike baseline and any pre-depolarization. + + ms_after : float, default: 3.0 + Milliseconds after spike peak to extract. 3.0 ms captures the + repolarization and afterhyperpolarization. + + **Template Computation** + + template_operators : list, default: ["average", "median", "std"] + Statistics to compute for templates: + - "average": Mean waveform (standard template). + - "median": Median waveform (robust to outliers). + - "std": Standard deviation (waveform variability). + + include_multi_channel_metrics : bool, default: True + Compute template metrics across multiple channels, including: + - exp_decay: Exponential decay of amplitude across channels. + - velocity: Propagation velocity estimate. + Required for BombCell noise detection. Leave True. + + Examples + -------- + >>> params = get_default_preprocessing_params() + >>> params["freq_min"] = 250.0 # Lower cutoff for some cell types + >>> params["max_spikes_per_unit"] = 1000 # More spikes for better templates + >>> analyzer, rec, bad_chs = preprocess_for_bombcell(recording, sorting, "analyzer.zarr", params=params) + """ + return { + # Highpass filtering + "freq_min": 300.0, + # Bad channel detection + "detect_bad_channels": True, + "bad_channel_method": "coherence+psd", + # Phase shift (Neuropixels) + "apply_phase_shift": True, + # Common reference + "apply_cmr": True, + "cmr_reference": "global", + "cmr_operator": "median", + # SortingAnalyzer + "sparse": True, + "return_in_uV": True, + # Waveforms + "max_spikes_per_unit": 500, + "ms_before": 3.0, + "ms_after": 3.0, + # Templates + "template_operators": ["average", "median", "std"], + "include_multi_channel_metrics": True, + } + + +def get_default_qc_params(): + """ + Get default parameters for quality metrics and BombCell labeling. + + Returns a dictionary that can be modified and passed to run_bombcell_qc(). 
+ + Returns + ------- + dict + Dictionary with the following keys: + + **Quality Metrics Selection** + + compute_amplitude_cutoff : bool, default: False + Whether to compute amplitude_cutoff metric (estimated percentage of + missing spikes). Requires spike_amplitudes extension which is slow + to compute for large recordings. When enabled, spike_amplitudes will + be computed automatically if not already present. + + compute_distance_metrics : bool, default: False + Whether to compute isolation_distance and l_ratio metrics. + These require PCA computation and are slow for large datasets. + Useful for chronic recordings where cluster stability matters. + Not recommended for acute recordings with expected drift. + + compute_drift : bool, default: True + Whether to compute drift metrics (drift_ptp, drift_std, drift_mad). + Measures how much units move over the recording. Important for + acute recordings. drift_ptp (peak-to-peak drift in um) is used + by BombCell MUA thresholds. + + rp_method : str, default: "sliding_rp" + Method for computing refractory period violations: + - "sliding_rp": IBL/Steinmetz method that sweeps across RP values + and estimates contamination. More robust. (recommended) + - "llobet": Single RP value method from Llobet et al. + + **BombCell Labeling Options** + + label_non_somatic : bool, default: True + Whether to detect and label non-somatic (axonal/dendritic) units. + These have distinctive waveform shapes: narrow initial peak, + often triphasic. Set False to skip this classification. + + split_non_somatic_good_mua : bool, default: False + If True, split non-somatic units into "non_soma_good" and + "non_soma_mua" based on whether they pass MUA thresholds. + If False, all non-somatic units are labeled "non_soma". + + use_valid_periods : bool, default: False + If True, identify valid time periods per unit (where the unit + has stable amplitude and low refractory violations) and compute + quality metrics only on those periods. 
Useful for recordings + with unstable periods. Requires amplitude_scalings extension. + + **Presence Ratio Parameters** + + presence_ratio_bin_duration_s : float, default: 60 + Bin duration in seconds for computing presence ratio. + Presence ratio = fraction of bins containing at least one spike. + 60s bins are standard; shorter bins are stricter. + + **Refractory Period Violation Parameters** + + refractory_period_ms : float, default: 2.0 + Refractory period duration in milliseconds. Spikes closer than + this are considered violations. 2.0 ms is conservative; some + fast-spiking neurons may need 1.0-1.5 ms. + + censored_period_ms : float, default: 0.1 + Censored period in milliseconds. Spikes within this period of + each other are not counted (accounts for detection artifacts). + 0.1 ms is standard. + + **Sliding RP Method Parameters** (used if rp_method="sliding_rp") + + sliding_rp_exclude_below_ms : float, default: 0.5 + Exclude refractory periods below this value when sweeping. + Avoids artifacts from very short intervals. + + sliding_rp_max_ms : float, default: 10.0 + Maximum refractory period to test when sweeping. + + sliding_rp_confidence : float, default: 0.9 + Confidence level for contamination estimate. Higher values + give more conservative (higher) contamination estimates. + + **Drift Parameters** + + drift_interval_s : float, default: 60 + Interval in seconds for computing drift. Unit positions are + estimated in each interval and drift is the movement across intervals. + + drift_min_spikes : int, default: 100 + Minimum spikes required per interval to estimate position. + Intervals with fewer spikes are skipped. + + **Plotting Options** + + plot_histograms : bool, default: True + Generate histograms of all metrics with threshold lines. + Saved as "metric_histograms.png". + + plot_waveforms : bool, default: True + Generate waveform overlay plots grouped by label (good, mua, noise, etc.). + Saved as "waveforms_by_label.png". 
+ + plot_upset : bool, default: True + Generate UpSet plots showing which metrics fail together. + Useful for understanding why units are labeled noise/mua. + Requires 'upsetplot' package. Saved as "upset_plot_*.png". + + waveform_ylims : tuple or None, default: (-300, 100) + Y-axis limits for waveform plots in microvolts. + None for automatic scaling. + + figsize_histograms : tuple, default: (15, 10) + Figure size (width, height) in inches for histogram plot. + + Examples + -------- + >>> params = get_default_qc_params() + >>> # Stricter for chronic recordings + >>> params["compute_distance_metrics"] = True + >>> params["compute_drift"] = False # Less relevant for chronic + >>> # More lenient refractory period for fast-spiking neurons + >>> params["refractory_period_ms"] = 1.5 + >>> labels, metrics, figs = run_bombcell_qc(analyzer, params=params) + """ + return { + # Which metrics to compute + "compute_amplitude_cutoff": False, # slow - requires spike_amplitudes + "compute_distance_metrics": False, + "compute_drift": True, + "rp_method": "sliding_rp", + # BombCell labeling options + "label_non_somatic": True, + "split_non_somatic_good_mua": False, + "use_valid_periods": False, + # Presence ratio + "presence_ratio_bin_duration_s": 60, + # Refractory period violations + "refractory_period_ms": 2.0, + "censored_period_ms": 0.1, + # Sliding RP method + "sliding_rp_exclude_below_ms": 0.5, + "sliding_rp_max_ms": 10.0, + "sliding_rp_confidence": 0.9, + # Drift + "drift_interval_s": 60, + "drift_min_spikes": 100, + # Plotting + "plot_histograms": True, + "plot_waveforms": True, + "plot_upset": True, + "waveform_ylims": (-300, 100), + "figsize_histograms": (15, 10), + } def preprocess_for_bombcell( recording, sorting, analyzer_folder: str | Path, - # Preprocessing - freq_min: float = 300.0, - detect_bad_channels: bool = True, - bad_channel_method: str = "coherence+psd", - apply_phase_shift: bool = True, - apply_cmr: bool = True, - cmr_reference: str = "global", - 
cmr_operator: str = "median", - # Analyzer - sparse: bool = True, - return_in_uV: bool = True, - # Extensions - max_spikes_per_unit: int = 500, - ms_before: float = 3.0, - ms_after: float = 3.0, - template_operators: list[str] = ("average", "median", "std"), - include_multi_channel_metrics: bool = True, - # Rerun flags + params: dict | None = None, rerun_extensions: bool = False, - # Job n_jobs: int = -1, progress_bar: bool = True, ): """ Preprocess recording and create SortingAnalyzer with extensions for BombCell. + This function applies standard preprocessing steps (filtering, bad channel + removal, phase shift correction, common reference) and creates a SortingAnalyzer + with all extensions required for BombCell quality control. + Parameters ---------- recording : BaseRecording - Raw recording to preprocess. + Raw recording to preprocess. Typically loaded with si.read_spikeglx() + or si.read_openephys(). sorting : BaseSorting - Spike sorting result. + Spike sorting result. Can be loaded with si.read_sorter_folder() or + any other sorting loader. analyzer_folder : str or Path - Path to save the SortingAnalyzer (zarr format). - - Preprocessing Parameters - ------------------------ - freq_min : float, default: 300.0 - Highpass filter cutoff frequency in Hz. - detect_bad_channels : bool, default: True - Detect and remove bad channels. - bad_channel_method : str, default: "coherence+psd" - Method for bad channel detection. - apply_phase_shift : bool, default: True - Apply phase shift correction (for Neuropixels). - apply_cmr : bool, default: True - Apply common median reference. - cmr_reference : str, default: "global" - Reference type: "global", "local", or "single". - cmr_operator : str, default: "median" - Operator: "median" or "average". - - Analyzer Parameters - ------------------- - sparse : bool, default: True - Use sparse waveform representation. - return_in_uV : bool, default: True - Return waveforms in microvolts. 
- - Extension Parameters - -------------------- - max_spikes_per_unit : int, default: 500 - Number of spikes to extract per unit for waveforms. - ms_before : float, default: 3.0 - Milliseconds before spike peak for waveform extraction. - ms_after : float, default: 3.0 - Milliseconds after spike peak for waveform extraction. - template_operators : list, default: ("average", "median", "std") - Template statistics to compute. - include_multi_channel_metrics : bool, default: True - Include multi-channel template metrics (exp_decay, etc.). + Path to save the SortingAnalyzer. Will be created in zarr format. + If folder exists, loads existing analyzer instead of creating new one. + params : dict or None, default: None + Preprocessing parameters from get_default_preprocessing_params(). + If None, uses all default values. rerun_extensions : bool, default: False - Force recomputation of existing extensions. - - Job Parameters - -------------- + If True, recompute all extensions even if they already exist. + Useful after changing parameters. n_jobs : int, default: -1 - Number of parallel jobs (-1 for all CPUs). + Number of parallel jobs for computation. -1 uses all available CPUs. progress_bar : bool, default: True - Show progress bars. + Show progress bars during computation. Returns ------- analyzer : SortingAnalyzer - SortingAnalyzer with computed extensions. + SortingAnalyzer saved to analyzer_folder with computed extensions: + random_spikes, waveforms, templates, noise_levels, unit_locations, + spike_locations, template_metrics. Note: spike_amplitudes is computed + on-demand by run_bombcell_qc() if compute_amplitude_cutoff=True. rec_preprocessed : BaseRecording - Preprocessed recording (lazy, not saved). + Preprocessed recording (lazy chain, not saved to disk). + Can be used for further analysis or passed to other functions. bad_channel_ids : list or None - List of removed bad channel IDs, or None if detection disabled. 
+ List of channel IDs that were detected as bad and removed. + None if detect_bad_channels=False. + + Examples + -------- + Basic usage with defaults: + + >>> analyzer, rec, bad_chs = preprocess_for_bombcell(recording, sorting, "analyzer.zarr") + + With custom parameters: + + >>> params = get_default_preprocessing_params() + >>> params["freq_min"] = 250.0 + >>> params["detect_bad_channels"] = False # Already cleaned + >>> analyzer, rec, bad_chs = preprocess_for_bombcell( + ... recording, sorting, "analyzer.zarr", params=params + ... ) + + Rerun extensions after parameter change: + + >>> analyzer, rec, bad_chs = preprocess_for_bombcell( + ... recording, sorting, "analyzer.zarr", rerun_extensions=True + ... ) """ import spikeinterface.full as si + if params is None: + params = get_default_preprocessing_params() + analyzer_folder = Path(analyzer_folder) job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=progress_bar) # Preprocess - rec = si.highpass_filter(recording, freq_min=freq_min) + rec = si.highpass_filter(recording, freq_min=params["freq_min"]) bad_channel_ids = None - if detect_bad_channels: - bad_channel_ids, _ = si.detect_bad_channels(rec, method=bad_channel_method) + if params["detect_bad_channels"]: + bad_channel_ids, _ = si.detect_bad_channels(rec, method=params["bad_channel_method"]) bad_channel_ids = list(bad_channel_ids) if len(bad_channel_ids) > 0: rec = rec.remove_channels(bad_channel_ids) - if apply_phase_shift: + if params["apply_phase_shift"]: rec = si.phase_shift(rec) - if apply_cmr: - rec = si.common_reference(rec, reference=cmr_reference, operator=cmr_operator) + if params["apply_cmr"]: + rec = si.common_reference(rec, reference=params["cmr_reference"], operator=params["cmr_operator"]) rec_preprocessed = rec @@ -142,10 +422,10 @@ def preprocess_for_bombcell( analyzer = si.create_sorting_analyzer( sorting=sorting, recording=rec_preprocessed, - sparse=sparse, + sparse=params["sparse"], format="zarr", folder=analyzer_folder, - 
return_in_uV=return_in_uV, + return_in_uV=params["return_in_uV"], ) # Compute extensions @@ -156,151 +436,64 @@ def _compute(name, **kwargs): analyzer.delete_extension(name) analyzer.compute(name, **kwargs) - _compute("random_spikes", method="uniform", max_spikes_per_unit=max_spikes_per_unit) - _compute("waveforms", ms_before=ms_before, ms_after=ms_after, **job_kwargs) - _compute("templates", operators=list(template_operators)) + _compute("random_spikes", method="uniform", max_spikes_per_unit=params["max_spikes_per_unit"]) + _compute("waveforms", ms_before=params["ms_before"], ms_after=params["ms_after"], **job_kwargs) + _compute("templates", operators=list(params["template_operators"])) _compute("noise_levels") - _compute("spike_amplitudes", **job_kwargs) + # spike_amplitudes computed on-demand by run_bombcell_qc if compute_amplitude_cutoff=True _compute("unit_locations") _compute("spike_locations", **job_kwargs) - _compute("template_metrics", include_multi_channel_metrics=include_multi_channel_metrics) + _compute("template_metrics", include_multi_channel_metrics=params["include_multi_channel_metrics"]) return analyzer, rec_preprocessed, bad_channel_ids def run_bombcell_qc( sorting_analyzer, - output_folder: str | Path | None = None, - # Quality metric options - compute_distance_metrics: bool = False, - compute_drift: bool = True, - rp_method: str = "sliding_rp", - # Quality metric parameters - qm_params: dict | None = None, - # BombCell options - label_non_somatic: bool = True, - split_non_somatic_good_mua: bool = False, - use_valid_periods: bool = False, - valid_periods_params: dict | None = None, - # Thresholds (None = use defaults) + output_folder: str | Path = "bombcell", + params: dict | None = None, thresholds: dict | None = None, - # Plotting - plot_histograms: bool = True, - plot_waveforms: bool = True, - plot_upset: bool = True, - waveform_ylims: tuple | None = (-300, 100), - figsize_histograms: tuple = (15, 10), - # Rerun + valid_periods_params: dict | 
None = None, rerun_quality_metrics: bool = False, rerun_pca: bool = False, - # Job n_jobs: int = -1, progress_bar: bool = True, ): """ Compute quality metrics and run BombCell unit labeling. + This function computes quality metrics on the SortingAnalyzer, runs the + BombCell labeling algorithm to classify units as good/mua/noise/non_soma, + generates diagnostic plots, and saves results. + Parameters ---------- sorting_analyzer : SortingAnalyzer - Analyzer with template_metrics computed (from preprocess_for_bombcell). - output_folder : str, Path, or None, default: None - Folder to save results (CSV files and plots). If None, results not saved. - - Quality Metric Options - ---------------------- - compute_distance_metrics : bool, default: False - Compute isolation_distance and l_ratio (requires PCA, slow). - compute_drift : bool, default: True - Compute drift metrics. - rp_method : str, default: "sliding_rp" - Refractory period violation method: "sliding_rp" or "llobet". - - Quality Metric Parameters - ------------------------- - qm_params : dict or None, default: None - Override default quality metric parameters. Keys are metric names, - values are parameter dicts. Default parameters: - - - presence_ratio: {"bin_duration_s": 60} - - rp_violation: {"refractory_period_ms": 2.0, "censored_period_ms": 0.1} - - sliding_rp_violation: {"exclude_ref_period_below_ms": 0.5, - "max_ref_period_ms": 10.0, "confidence_threshold": 0.9} - - drift: {"interval_s": 60, "min_spikes_per_interval": 100} - - BombCell Options - ---------------- - label_non_somatic : bool, default: True - Detect non-somatic (axonal/dendritic) units. - split_non_somatic_good_mua : bool, default: False - Split non-somatic into "non_soma_good" and "non_soma_mua". - use_valid_periods : bool, default: False - Restrict metrics to valid time periods per unit. - valid_periods_params : dict or None, default: None - Parameters for valid_unit_periods extension. 
- - Classification Thresholds - ------------------------- + Analyzer with template_metrics extension computed (from preprocess_for_bombcell). + output_folder : str or Path, default: "bombcell" + Folder to save results (CSV files and plots). Set to None to skip saving. + Created if it doesn't exist. + params : dict or None, default: None + QC parameters from get_default_qc_params(). If None, uses defaults. thresholds : dict or None, default: None - BombCell thresholds dict with "noise", "mua", "non-somatic" sections. - If None, uses bombcell_get_default_thresholds(). Default values: - - noise (waveform quality - any failure -> "noise"): - - num_positive_peaks: < 2 - - num_negative_peaks: < 1 - - peak_to_trough_duration: 0.1-1.15 ms - - waveform_baseline_flatness: < 0.5 - - peak_after_to_trough_ratio: < 0.8 - - exp_decay: 0.01-0.1 - - mua (spike quality - any failure -> "mua"): - - amplitude_median: > 30 uV (absolute value) - - snr: > 5 - - amplitude_cutoff: < 0.2 - - num_spikes: > 300 - - rpv: < 0.1 (refractory period violations) - - presence_ratio: > 0.7 - - drift_ptp: < 100 um - - isolation_distance: > 20 (if computed) - - l_ratio: < 0.3 (if computed) - - non-somatic (waveform shape): - - peak_before_to_trough_ratio: < 3 - - peak_before_width: > 0.15 ms - - trough_width: > 0.2 ms - - peak_before_to_peak_after_ratio: < 3 - - main_peak_to_trough_ratio: < 0.8 - - To modify thresholds: - thresholds = bombcell_get_default_thresholds() - thresholds["mua"]["rpv"]["less"] = 0.05 # stricter RPV - thresholds["mua"]["num_spikes"]["greater"] = 100 # lower spike count - - To disable a threshold: - thresholds["mua"]["drift_ptp"] = {"greater": None, "less": None} - - Plotting Options - ---------------- - plot_histograms : bool, default: True - Plot metric histograms with threshold lines. - plot_waveforms : bool, default: True - Plot waveforms grouped by label. - plot_upset : bool, default: True - Plot UpSet plots showing metric failure combinations. 
- waveform_ylims : tuple or None, default: (-300, 100) - Y-axis limits for waveform plots. - figsize_histograms : tuple, default: (15, 10) - Figure size for histogram plot. - - Rerun Options - ------------- + BombCell classification thresholds from bombcell_get_default_thresholds(). + If None, uses defaults. Structure: + + - "noise": Thresholds for waveform quality. Failing ANY -> "noise". + - "mua": Thresholds for spike quality. Failing ANY -> "mua". + - "non-somatic": Thresholds for waveform shape. Determines non-somatic units. + + Each threshold is {"greater": value, "less": value}. Use None to disable. + See bombcell_get_default_thresholds() docstring for all thresholds. + + valid_periods_params : dict or None, default: None + Parameters for valid_unit_periods extension if params["use_valid_periods"]=True. + Keys: refractory_period_ms, censored_period_ms, period_mode, + period_duration_s_absolute, minimum_valid_period_duration. rerun_quality_metrics : bool, default: False - Force recomputation of quality metrics. + Force recomputation of quality metrics even if they exist. rerun_pca : bool, default: False - Force recomputation of PCA (only if compute_distance_metrics=True). - - Job Parameters - -------------- + Force recomputation of PCA (only relevant if compute_distance_metrics=True). n_jobs : int, default: -1 Number of parallel jobs. progress_bar : bool, default: True @@ -310,66 +503,112 @@ def run_bombcell_qc( ------- labels : pd.DataFrame DataFrame with unit_ids as index and "bombcell_label" column. - Labels: "good", "mua", "noise", "non_soma" (or "non_soma_good"/"non_soma_mua"). + Possible labels: "good", "mua", "noise", "non_soma" + (or "non_soma_good"/"non_soma_mua" if split_non_somatic_good_mua=True). metrics : pd.DataFrame - Combined quality metrics and template metrics. + Combined DataFrame of all quality metrics and template metrics. + Index is unit_ids, columns are metric names. 
figures : dict - Dictionary of matplotlib figures: {"histograms", "waveforms", "upset"}. + Dictionary of matplotlib figures: + - "histograms": Metric histograms with threshold lines. + - "waveforms": Waveform overlays grouped by label. + - "upset": List of UpSet plot figures (one per label type). + + Saved Files (in output_folder) + ------------------------------ + - labeling_results_wide.csv: One row per unit with all metrics and label. + - labeling_results_narrow.csv: One row per unit-metric with pass/fail status. + - metric_histograms.png: Histogram of each metric with threshold lines. + - waveforms_by_label.png: Waveform overlays for each label category. + - upset_plot_*.png: UpSet plots showing metric failure combinations. + + Examples + -------- + Basic usage with defaults: + + >>> labels, metrics, figs = run_bombcell_qc(analyzer) + + With custom parameters and thresholds: + + >>> params = get_default_qc_params() + >>> params["compute_distance_metrics"] = True # For chronic recordings + >>> params["refractory_period_ms"] = 1.5 # For fast-spiking neurons + >>> + >>> thresholds = bombcell_get_default_thresholds() + >>> thresholds["mua"]["rpv"]["less"] = 0.05 # Stricter RP violations + >>> thresholds["mua"]["num_spikes"]["greater"] = 100 # Lower spike threshold + >>> + >>> labels, metrics, figs = run_bombcell_qc( + ... analyzer, + ... output_folder="qc_results", + ... params=params, + ... thresholds=thresholds, + ... 
) + + Get good units for downstream analysis: + + >>> good_units = labels[labels["bombcell_label"] == "good"].index.tolist() + >>> mua_units = labels[labels["bombcell_label"] == "mua"].index.tolist() """ - import pandas as pd - from .bombcell_curation import bombcell_get_default_thresholds, bombcell_label_units, save_bombcell_results + if params is None: + params = get_default_qc_params() + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=progress_bar) if output_folder is not None: output_folder = Path(output_folder) output_folder.mkdir(parents=True, exist_ok=True) - # Default QM params - default_qm_params = { - "presence_ratio": {"bin_duration_s": 60}, - "rp_violation": {"refractory_period_ms": 2.0, "censored_period_ms": 0.1}, + # Build QM params + qm_params = { + "presence_ratio": {"bin_duration_s": params["presence_ratio_bin_duration_s"]}, + "rp_violation": { + "refractory_period_ms": params["refractory_period_ms"], + "censored_period_ms": params["censored_period_ms"], + }, "sliding_rp_violation": { - "exclude_ref_period_below_ms": 0.5, - "max_ref_period_ms": 10.0, - "confidence_threshold": 0.9, + "exclude_ref_period_below_ms": params["sliding_rp_exclude_below_ms"], + "max_ref_period_ms": params["sliding_rp_max_ms"], + "confidence_threshold": params["sliding_rp_confidence"], + }, + "drift": { + "interval_s": params["drift_interval_s"], + "min_spikes_per_interval": params["drift_min_spikes"], }, - "drift": {"interval_s": 60, "min_spikes_per_interval": 100}, } - if qm_params is not None: - for key, val in qm_params.items(): - default_qm_params[key] = val - qm_params = default_qm_params - - # Build metric names list - metric_names = [ - "amplitude_median", - "snr", - "amplitude_cutoff", - "num_spikes", - "presence_ratio", - "firing_rate", - ] - - if rp_method == "sliding_rp": + + # Build metric names + metric_names = ["amplitude_median", "snr", "num_spikes", 
"presence_ratio", "firing_rate"] + + if params["compute_amplitude_cutoff"]: + metric_names.append("amplitude_cutoff") + # amplitude_cutoff requires spike_amplitudes or amplitude_scalings + if not sorting_analyzer.has_extension("spike_amplitudes") and not sorting_analyzer.has_extension( + "amplitude_scalings" + ): + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + if params["rp_method"] == "sliding_rp": metric_names.append("sliding_rp_violation") else: metric_names.append("rp_violation") - if compute_drift: + if params["compute_drift"]: metric_names.append("drift") - if compute_distance_metrics: + if params["compute_distance_metrics"]: metric_names.append("mahalanobis") if not sorting_analyzer.has_extension("principal_components") or rerun_pca: if sorting_analyzer.has_extension("principal_components"): sorting_analyzer.delete_extension("principal_components") - sorting_analyzer.compute( - "principal_components", n_components=5, mode="by_channel_local", **job_kwargs - ) + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - if use_valid_periods and not sorting_analyzer.has_extension("amplitude_scalings"): + if params["use_valid_periods"] and not sorting_analyzer.has_extension("amplitude_scalings"): sorting_analyzer.compute("amplitude_scalings", **job_kwargs) # Compute quality metrics @@ -377,42 +616,35 @@ def run_bombcell_qc( sorting_analyzer.delete_extension("quality_metrics") if not sorting_analyzer.has_extension("quality_metrics"): - sorting_analyzer.compute( - "quality_metrics", metric_names=metric_names, metric_params=qm_params, **job_kwargs - ) + sorting_analyzer.compute("quality_metrics", metric_names=metric_names, metric_params=qm_params, **job_kwargs) - # Get thresholds - if thresholds is None: - thresholds = bombcell_get_default_thresholds() - - # Run BombCell labeling + # Run BombCell labels = bombcell_label_units( sorting_analyzer=sorting_analyzer, thresholds=thresholds, - 
label_non_somatic=label_non_somatic, - split_non_somatic_good_mua=split_non_somatic_good_mua, - use_valid_periods=use_valid_periods, + label_non_somatic=params["label_non_somatic"], + split_non_somatic_good_mua=params["split_non_somatic_good_mua"], + use_valid_periods=params["use_valid_periods"], valid_periods_params=valid_periods_params, **job_kwargs, ) metrics = sorting_analyzer.get_metrics_extension_data() - # Generate plots + # Plots figures = {} - - if plot_histograms or plot_waveforms or plot_upset: + if params["plot_histograms"] or params["plot_waveforms"] or params["plot_upset"]: import spikeinterface.widgets as sw - if plot_histograms: - w = sw.plot_metric_histograms(sorting_analyzer, thresholds, figsize=figsize_histograms) + if params["plot_histograms"]: + w = sw.plot_metric_histograms(sorting_analyzer, thresholds, figsize=params["figsize_histograms"]) figures["histograms"] = w.figure - if plot_waveforms: - w = sw.plot_unit_labels(sorting_analyzer, labels["bombcell_label"], ylims=waveform_ylims) + if params["plot_waveforms"]: + w = sw.plot_unit_labels(sorting_analyzer, labels["bombcell_label"], ylims=params["waveform_ylims"]) figures["waveforms"] = w.figure - if plot_upset: + if params["plot_upset"]: w = sw.plot_bombcell_labels_upset( sorting_analyzer, unit_labels=labels["bombcell_label"], @@ -421,7 +653,7 @@ def run_bombcell_qc( ) figures["upset"] = w.figures - # Save results + # Save if output_folder is not None: save_bombcell_results( metrics=metrics, @@ -429,8 +661,6 @@ def run_bombcell_qc( thresholds=thresholds, folder=output_folder, ) - - # Save figures if "histograms" in figures: figures["histograms"].savefig(output_folder / "metric_histograms.png", dpi=150, bbox_inches="tight") if "waveforms" in figures: From 16ff7123aff11c3edfe0f72fce345ae104337d1a Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 15:32:27 -0400 Subject: [PATCH 13/15] bombcell preprocessing (standard) wrapper script --- examples/how_to/bombcell_pipeline_example.py 
| 70 ++++++++++++++++ .../curation/bombcell_curation.py | 81 +++++++++++++++++++ .../curation/bombcell_pipeline.py | 11 ++- 3 files changed, 161 insertions(+), 1 deletion(-) diff --git a/examples/how_to/bombcell_pipeline_example.py b/examples/how_to/bombcell_pipeline_example.py index 871709a4d1..284d7b45ab 100644 --- a/examples/how_to/bombcell_pipeline_example.py +++ b/examples/how_to/bombcell_pipeline_example.py @@ -62,6 +62,8 @@ # --- BombCell classification options --- qc_params["label_non_somatic"] = True # detect axonal/dendritic units via waveform shape qc_params["split_non_somatic_good_mua"] = False # if True, non-somatic split into good/mua subcategories +qc_params["use_valid_periods"] = False # if True, identify valid time chunks per unit and recompute + # quality metrics only on those periods (see below for details) # --- Presence ratio --- qc_params["presence_ratio_bin_duration_s"] = 60 # bin size (s) for checking if unit fires throughout recording @@ -205,3 +207,71 @@ print(f"\nMetrics for first good unit:") if good_units: print(metrics.loc[good_units[0]]) + +# %% Output files +# BombCell saves the following files to output_folder: +# +# labeling_results_wide.csv +# - One row per unit, all metrics as columns, plus "label" column +# - Format: unit_id (index), label, metric1, metric2, ... 
+# - Use for quick overview of all units and their metrics +# +# labeling_results_narrow.csv +# - One row per unit-metric combination (tidy/long format) +# - Columns: unit_id, label, metric_name, value, threshold_min, threshold_max, passed +# - Use to see exactly which metrics failed for each unit +# +# valid_periods.tsv (only if use_valid_periods=True) +# - Valid time periods per unit for downstream analysis +# - Columns: unit_id, segment_index, start_time_s, end_time_s, duration_s +# - Use to filter spikes to stable periods in your analysis +# +# metric_histograms.png +# - Histogram of each metric with threshold lines marked +# - Useful for adjusting thresholds based on your data distribution +# +# waveforms_by_label.png +# - Waveform overlays grouped by label (good, mua, noise, non_soma) +# - Verify that labels match expected waveform shapes +# +# upset_plot_*.png +# - UpSet plots showing which metrics fail together +# - Understand why units are labeled noise/mua + +# %% Using valid time periods +# Valid periods identify chunks of time where each unit has stable amplitude +# and low refractory period violations. This is useful when recordings have +# unstable periods (e.g., drift, probe movement, or electrode noise). +# +# When use_valid_periods=True: +# 1. Recording is divided into chunks (default 30s or ~300 spikes per unit) +# 2. For each chunk, false positive rate (RP violations) and false negative +# rate (amplitude cutoff) are computed +# 3. Chunks where BOTH rates are below threshold are marked as "valid" +# 4. Overlapping valid chunks are merged; short periods (<180s) are removed +# 5. Quality metrics are recomputed using only spikes within valid periods +# 6. BombCell labeling is applied to these restricted metrics +# 7. 
valid_periods.tsv is saved with the valid time windows per unit +# +# Example: Enable valid periods +# qc_params["use_valid_periods"] = True +# +# Example: Customize valid period parameters +# valid_periods_params = { +# "period_duration_s_absolute": 30.0, # chunk size in seconds (if period_mode="absolute") +# "period_target_num_spikes": 300, # target spikes per chunk (if period_mode="relative") +# "period_mode": "absolute", # "absolute" (fixed duration) or "relative" (fixed spike count) +# "minimum_valid_period_duration": 180, # min duration to keep a valid period (seconds) +# "fp_threshold": 0.1, # max false positive rate (derived from rpv threshold if not set) +# "fn_threshold": 0.1, # max false negative rate (derived from amplitude_cutoff if not set) +# } +# labels, metrics, figures = sc.run_bombcell_qc( +# analyzer, params=qc_params, valid_periods_params=valid_periods_params +# ) +# +# Example: Load valid_periods.tsv for downstream analysis +# import pandas as pd +# valid_periods = pd.read_csv(output_folder / "valid_periods.tsv", sep="\t") +# # Filter to get valid periods for a specific unit +# unit_periods = valid_periods[valid_periods["unit_id"] == good_units[0]] +# print(f"Unit {good_units[0]} has {len(unit_periods)} valid period(s)") diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index a0aabe39df..bc212adda4 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -492,3 +492,84 @@ def save_bombcell_results( narrow_df = pd.DataFrame(rows) narrow_df.to_csv(folder / "labeling_results_narrow.csv", index=False) + + +def save_valid_periods( + sorting_analyzer, + folder, +) -> None: + """ + Save valid time periods per unit to a TSV file for downstream analysis. + + This function extracts the valid_unit_periods extension data and saves it + in a simple, human-readable TSV format with times in seconds. 
+ + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + Analyzer with the valid_unit_periods extension computed. + folder : str or Path + Folder to save the TSV file. + + Returns + ------- + None + + Notes + ----- + The output file `valid_periods.tsv` contains one row per valid period with columns: + + - unit_id: The unit identifier + - segment_index: Recording segment (0-indexed) + - start_time_s: Start of valid period in seconds + - end_time_s: End of valid period in seconds + - duration_s: Duration of valid period in seconds + + This file can be easily loaded with pandas or any TSV reader for downstream + analysis, filtering spikes to valid periods, or visualization. + """ + from pathlib import Path + import pandas as pd + + folder = Path(folder) + folder.mkdir(parents=True, exist_ok=True) + + if not sorting_analyzer.has_extension("valid_unit_periods"): + return + + vp_ext = sorting_analyzer.get_extension("valid_unit_periods") + valid_periods = vp_ext.get_data(outputs="numpy") + + if len(valid_periods) == 0: + # No valid periods found - save empty file with header + df = pd.DataFrame(columns=["unit_id", "segment_index", "start_time_s", "end_time_s", "duration_s"]) + df.to_csv(folder / "valid_periods.tsv", sep="\t", index=False) + return + + fs = sorting_analyzer.sampling_frequency + unit_ids = sorting_analyzer.unit_ids + + rows = [] + for period in valid_periods: + unit_index = period["unit_index"] + unit_id = unit_ids[unit_index] + segment_index = period["segment_index"] + start_sample = period["start_sample_index"] + end_sample = period["end_sample_index"] + + start_time_s = start_sample / fs + end_time_s = end_sample / fs + duration_s = end_time_s - start_time_s + + rows.append( + { + "unit_id": unit_id, + "segment_index": segment_index, + "start_time_s": round(start_time_s, 6), + "end_time_s": round(end_time_s, 6), + "duration_s": round(duration_s, 6), + } + ) + + df = pd.DataFrame(rows) + df.to_csv(folder / "valid_periods.tsv", sep="\t", 
index=False) diff --git a/src/spikeinterface/curation/bombcell_pipeline.py b/src/spikeinterface/curation/bombcell_pipeline.py index f3e6b466dd..6506704b5c 100644 --- a/src/spikeinterface/curation/bombcell_pipeline.py +++ b/src/spikeinterface/curation/bombcell_pipeline.py @@ -518,6 +518,8 @@ def run_bombcell_qc( ------------------------------ - labeling_results_wide.csv: One row per unit with all metrics and label. - labeling_results_narrow.csv: One row per unit-metric with pass/fail status. + - valid_periods.tsv: Valid time periods per unit (only if use_valid_periods=True). + Columns: unit_id, segment_index, start_time_s, end_time_s, duration_s. - metric_histograms.png: Histogram of each metric with threshold lines. - waveforms_by_label.png: Waveform overlays for each label category. - upset_plot_*.png: UpSet plots showing metric failure combinations. @@ -550,7 +552,12 @@ def run_bombcell_qc( >>> good_units = labels[labels["bombcell_label"] == "good"].index.tolist() >>> mua_units = labels[labels["bombcell_label"] == "mua"].index.tolist() """ - from .bombcell_curation import bombcell_get_default_thresholds, bombcell_label_units, save_bombcell_results + from .bombcell_curation import ( + bombcell_get_default_thresholds, + bombcell_label_units, + save_bombcell_results, + save_valid_periods, + ) if params is None: params = get_default_qc_params() @@ -661,6 +668,8 @@ def run_bombcell_qc( thresholds=thresholds, folder=output_folder, ) + if params["use_valid_periods"]: + save_valid_periods(sorting_analyzer, output_folder) if "histograms" in figures: figures["histograms"].savefig(output_folder / "metric_histograms.png", dpi=150, bbox_inches="tight") if "waveforms" in figures: From 05a6ec2278333aa5b1384f6ba3d34e21fb882084 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:34:06 +0000 Subject: [PATCH 14/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- examples/how_to/bombcell_pipeline_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/how_to/bombcell_pipeline_example.py b/examples/how_to/bombcell_pipeline_example.py index 284d7b45ab..2d69177810 100644 --- a/examples/how_to/bombcell_pipeline_example.py +++ b/examples/how_to/bombcell_pipeline_example.py @@ -46,7 +46,7 @@ # --- Waveform extraction --- preproc_params["max_spikes_per_unit"] = 500 # spikes to extract per unit (more = accurate but slower) preproc_params["ms_before"] = 3.0 # ms before spike peak to include in waveform -preproc_params["ms_after"] = 3.0 # ms after spike peak to include in waveform +preproc_params["ms_after"] = 3.0 # ms after spike peak to include in waveform # %% QC parameters qc_params = sc.get_default_qc_params() From edd1d30461f3b734e1fdcafd8880d386ddfec2e7 Mon Sep 17 00:00:00 2001 From: Julie-Fabre Date: Thu, 26 Mar 2026 15:39:20 -0400 Subject: [PATCH 15/15] bombcell preprocessing (standard) wrapper script --- src/spikeinterface/curation/bombcell_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py index bc212adda4..b8ba62c978 100644 --- a/src/spikeinterface/curation/bombcell_curation.py +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -104,7 +104,7 @@ def bombcell_label_units( thresholds: dict | str | Path | None = None, label_non_somatic: bool = True, split_non_somatic_good_mua: bool = False, - external_metrics: "pd.DataFrame | list[pd.DataFrame]" | None = None, + external_metrics: "pd.DataFrame | list[pd.DataFrame] | None" = None, use_valid_periods: bool = False, valid_periods_params: dict | None = None, recompute_quality_metrics: bool = True,