diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 96eda404dd..0b63288024 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -169,6 +169,24 @@ def _handle_backward_compatibility_on_load(self): if "waveform_ratios" not in self.params["metric_names"]: self.params["metric_names"].append("waveform_ratios") + if self.data.get("peaks_data") is None: + import pandas as pd + + self.data["peaks_data"] = pd.DataFrame( + columns=get_peaks_data_columns(), index=self.sorting_analyzer.unit_ids + ) + + if self.data.get("main_channel_templates") is None: + num_units = self.sorting_analyzer.get_num_units() + if self.sorting_analyzer.has_extension("templates"): + templates_ext = self.sorting_analyzer.get_extension("templates") + template_samples = templates_ext.nbefore + templates_ext.nafter + upsampling_factor = self.params["upsampling_factor"] + upsampled_template_samples = template_samples * upsampling_factor + self.data["main_channel_templates"] = np.zeros((num_units, upsampled_template_samples)) + else: + warnings.warn("Cannot set all `template_metrics` metadata as `templates` extension is not available.") + def _set_params( self, metric_names: list[str] | None = None, @@ -315,16 +333,8 @@ def _prepare_data(self, sorting_analyzer, unit_ids): tmp_data["depth_direction"] = self.params["depth_direction"] # Add peaks_info and preprocessed templates to self.data for storage in extension - columns = [] - for k in ("trough", "peak_before", "peak_after"): - for suffix in ( - "index", - "width_left", - "width_right", - "half_width_left", - "half_width_right", - ): - columns.append(f"{k}_{suffix}") + columns = get_peaks_data_columns() + tmp_data["peaks_data"] = pd.DataFrame( index=unit_ids, data=peaks_info, @@ -335,6 +345,21 @@ def _prepare_data(self, sorting_analyzer, unit_ids): return tmp_data +def get_peaks_data_columns(): + """Generates the column names of the `peaks_data` DataFrame.""" + columns = [] + for k in ("trough", "peak_before", "peak_after"): + for suffix in ( + "index", + "width_left", + "width_right", + "half_width_left", + "half_width_right", + ): + columns.append(f"{k}_{suffix}") + return columns + + register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory()