From adf8c9eeddfe3302595764eba2da038c44f2c126 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 11 Dec 2025 13:14:02 +0100 Subject: [PATCH 01/10] wip: support zarr v3 --- pyproject.toml | 3 +- src/spikeinterface/core/sortinganalyzer.py | 94 ++++++++++--------- src/spikeinterface/core/template.py | 12 ++- .../tests/test_analyzer_extension_core.py | 2 +- .../core/tests/test_baserecording.py | 4 +- .../core/tests/test_sortinganalyzer.py | 45 ++++----- .../core/tests/test_zarrextractors.py | 35 ++++--- src/spikeinterface/core/zarrextractors.py | 91 ++++++++++++------ 8 files changed, 169 insertions(+), 117 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c3a3cf3b1..a753911fa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,11 @@ dependencies = [ "numpy>=2.0.0;python_version>='3.13'", "threadpoolctl>=3.0.0", "tqdm", - "zarr>=2.18,<3", + "zarr>=3,<4", "neo>=0.14.3", "probeinterface>=0.3.1", "packaging", "pydantic", - "numcodecs<0.16.0", # For supporting zarr < 3 ] [build-system] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1870c24e7a..31cb317565 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Literal, Optional, Any +from typing import Literal, Optional, Any, Iterable from pathlib import Path from itertools import chain @@ -621,6 +621,7 @@ def _get_zarr_root(self, mode="r+"): assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'" storage_options = self._backend_options.get("storage_options", {}) + zarr_root = super_zarr_open(self.folder, mode=mode, storage_options=storage_options) return zarr_root @@ -644,7 +645,12 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att storage_options = backend_options.get("storage_options", {}) saving_options = backend_options.get("saving_options", {}) - zarr_root = 
zarr.open(folder, mode="w", storage_options=storage_options) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = storage_options + + zarr_root = zarr.open(folder, mode="w", **storage_options_kwargs) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -657,13 +663,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att if recording is not None: rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) if recording.check_serializability("json"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) - zarr_rec = np.array([check_json(rec_dict)], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) - elif recording.check_serializability("pickle"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) - zarr_rec = np.array([rec_dict], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["recording"] = check_json(rec_dict) else: warnings.warn("The Recording is not serializable! 
The recording link will be lost for future load") else: @@ -673,11 +674,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att # sorting provenance sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True) if sorting.check_serializability("json"): - zarr_sort = np.array([check_json(sort_dict)], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) - elif sorting.check_serializability("pickle"): - zarr_sort = np.array([sort_dict], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) else: warnings.warn( "The sorting provenance is not serializable! The sorting provenance link will be lost for future load" @@ -698,12 +696,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) + zarr_root.create_array("sparsity_mask", data=sparsity.mask, **saving_options) add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") + # consolidate metadata for zarr v3 zarr.consolidate_metadata(zarr_root.store) return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) @@ -715,6 +714,10 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = storage_options zarr_root = super_zarr_open(str(folder), mode="r", 
storage_options=storage_options) @@ -723,7 +726,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # v0.101.0 did not have a consolidate metadata step after computing extensions. # Here we try to consolidate the metadata and throw a warning if it fails. try: - zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr_root_a = zarr.open(str(folder), mode="a", **storage_options_kwargs) zarr.consolidate_metadata(zarr_root_a.store) except Exception as e: warnings.warn( @@ -741,9 +744,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # load recording if possible if recording is None: - rec_field = zarr_root.get("recording") - if rec_field is not None: - rec_dict = rec_field[0] + # In zarr v3, recording is stored in attributes + rec_dict = zarr_root.attrs.get("recording", None) + if rec_dict is not None: try: recording = load(rec_dict, base_folder=folder) except: @@ -859,7 +862,7 @@ def set_sorting_property( if key in zarr_root["sorting"]["properties"]: zarr_root["sorting"]["properties"][key][:] = prop_values else: - zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None) + zarr_root["sorting"]["properties"].create_array(name=key, data=prop_values, compressors=None) # IMPORTANT: we need to re-consolidate the zarr store! 
zarr.consolidate_metadata(zarr_root.store) @@ -1531,12 +1534,13 @@ def get_sorting_provenance(self): elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") sorting_provenance = None - if "sorting_provenance" in zarr_root.keys(): + # In zarr v3, sorting_provenance is stored in attributes + sort_dict = zarr_root.attrs.get("sorting_provenance", None) + if sort_dict is not None: # try-except here is because it's not required to be able # to load the sorting provenance, as the user might have deleted # the original sorting folder try: - sort_dict = zarr_root["sorting_provenance"][0] sorting_provenance = load(sort_dict, base_folder=self.folder) except: pass @@ -2479,8 +2483,9 @@ def load_data(self): extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): ext_data_ = extension_group[ext_data_name] - if "dict" in ext_data_.attrs: - ext_data = ext_data_[0] + # In zarr v3, check if it's a group with dict_data attribute + if "dict_data" in ext_data_.attrs: + ext_data = ext_data_.attrs["dict_data"] elif "dataframe" in ext_data_.attrs: import pandas as pd @@ -2565,9 +2570,10 @@ def run(self, save=True, **kwargs): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def save(self): + self._reset_extension_folder() self._save_params() self._save_importing_provenance() self._save_run_info() @@ -2576,7 +2582,7 @@ def save(self): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def _save_data(self): if self.format == "memory": @@ -2623,40 +2629,44 @@ def _save_data(self): extension_group = self._get_zarr_extension_group(mode="r+") # if compression is not externally given, we use the default - if "compressor" not in saving_options: - 
saving_options["compressor"] = get_default_zarr_compressor() + if "compressors" not in saving_options and "compressor" not in saving_options: + saving_options["compressors"] = get_default_zarr_compressor() + if "compressor" in saving_options: + saving_options["compressors"] = [saving_options["compressor"]] + del saving_options["compressor"] for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] if isinstance(ext_data, dict): - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() - ) + # In zarr v3, store dict in a subgroup with attributes + dict_group = extension_group.create_group(ext_data_name) + dict_group.attrs["dict_data"] = check_json(ext_data) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) + extension_group.create_array(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index indices = ext_data.index.to_numpy() if indices.dtype.kind == "O": indices = indices.astype(str) - df_group.create_dataset(name="index", data=indices) + df_group.create_array(name="index", data=indices) for col in ext_data.columns: col_data = ext_data[col].to_numpy() if col_data.dtype.kind == "O": col_data = col_data.astype(str) - df_group.create_dataset(name=col, data=col_data) + df_group.create_array(name=col, data=col_data) df_group.attrs["dataframe"] = True else: # any object - try: - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - extension_group[ext_data_name].attrs["object"] = True + # try: + # extension_group.create_array( + # name=ext_data_name, data=np.array([ext_data], 
dtype=object), object_codec=numcodecs.Pickle() + # ) + # except: + # raise Exception(f"Could not save {ext_data_name} as extension data") + # extension_group[ext_data_name].attrs["object"] = True + warnings.warn(f"Data type of {ext_data_name} not supported for zarr saving, skipping.") def _reset_extension_folder(self): """ @@ -2734,8 +2744,6 @@ def set_params(self, save=True, **params): def _save_params(self): params_to_save = self.params.copy() - self._reset_extension_folder() - # TODO make sparsity local Result specific # if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: # assert isinstance( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 91d25bece6..50ae6cbfdf 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -321,17 +321,19 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: """ # Saves one chunk per unit - arrays_chunk = (1, None, None) - zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) - zarr_group.create_dataset("channel_ids", data=self.channel_ids) - zarr_group.create_dataset("unit_ids", data=self.unit_ids) + # In zarr v3, chunks must be a full tuple with actual dimensions + num_units, num_samples, num_channels = self.templates_array.shape + arrays_chunk = (1, num_samples, num_channels) + zarr_group.create_array("templates_array", data=self.templates_array, chunks=arrays_chunk) + zarr_group.create_array("channel_ids", data=self.channel_ids) + zarr_group.create_array("unit_ids", data=self.unit_ids) zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore zarr_group.attrs["is_in_uV"] = self.is_in_uV if self.sparsity_mask is not None: - zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask) + zarr_group.create_array("sparsity_mask", data=self.sparsity_mask) if self.probe is not None: probe_group = 
zarr_group.create_group("probe") diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index 6f5bef3c6c..574cc89e10 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -91,7 +91,7 @@ def test_ComputeRandomSpikes(format, sparse, create_cache_folder): print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) - print("Delering extension") + print("Deleting extension") sorting_analyzer.delete_extension("random_spikes") print("Re-computing random spikes") diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9de800b33d..186767c026 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -324,7 +324,7 @@ def test_BaseRecording(create_cache_folder): # test save to zarr compressor = get_default_zarr_compressor() - rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressor=compressor) + rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressors=compressor) rec_zarr_loaded = load(cache_folder / "recording.zarr") # annotations is False because Zarr adds compression ratios check_recordings_equal(rec2, rec_zarr, return_in_uV=False, check_annotations=False, check_properties=True) @@ -336,7 +336,7 @@ def test_BaseRecording(create_cache_folder): assert rec2.get_annotation(annotation_name) == rec_zarr_loaded.get_annotation(annotation_name) rec_zarr2 = rec2.save( - format="zarr", folder=cache_folder / "recording_channel_chunk", compressor=compressor, channel_chunk_size=2 + format="zarr", folder=cache_folder / "recording_channel_chunk", compressors=compressor, channel_chunk_size=2 ) rec_zarr2_loaded = load(cache_folder / "recording_channel_chunk.zarr") diff --git 
a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index ab0b071df4..3a1c73d746 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -17,6 +17,7 @@ AnalyzerExtension, _sort_extensions_by_dependency, ) +from spikeinterface.core.zarrextractors import check_compressors_match import numpy as np @@ -38,6 +39,8 @@ def get_dataset(): integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + # make sure the recording is serializable + recording = recording.save() sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting @@ -133,13 +136,12 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) # check that compression is applied - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 @@ -160,35 +162,34 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): sparsity=None, return_in_uV=False, overwrite=True, - backend_options={"saving_options": {"compressor": None}}, + backend_options={"saving_options": {"compressors": None}}, ) print(sorting_analyzer_no_compression._backend_options) sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) 
assert ( - sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor - is None + len( + sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressors + ) + == 0 ) - assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + assert len(sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressors) == 0 # test a different compressor - from numcodecs import LZMA + from zarr.codecs.numcodecs import LZMA lzma_compressor = LZMA() folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( - format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} + format="zarr", folder=folder, backend_options={"saving_options": {"compressors": lzma_compressor}} ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test set_sorting_property diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index cc0c60721e..a52d456594 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -10,50 +10,55 @@ generate_sorting, load, ) -from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, 
get_default_zarr_compressor +from spikeinterface.core.zarrextractors import ( + add_sorting_to_zarr_group, + get_default_zarr_compressor, + check_compressors_match, +) def test_zarr_compression_options(tmp_path): - from numcodecs import Blosc, Delta, FixedScaleOffset + from zarr.codecs.numcodecs import Delta, FixedScaleOffset + from zarr.codecs import BloscCodec, BloscShuffle recording = generate_recording(durations=[2]) recording.set_times(recording.get_times() + 100) # store in root standard normal way # default compressor - defaut_compressor = get_default_zarr_compressor() + default_compressor = get_default_zarr_compressor() # other compressor - other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE) - other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE) + other_compressor1 = BloscCodec(cname="zlib", clevel=3, shuffle=BloscShuffle.noshuffle) + other_compressor2 = BloscCodec(cname="blosclz", clevel=8, shuffle=BloscShuffle.shuffle) # timestamps compressors / filters default_filters = None - other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())] + other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype().str)] other_filters2 = [Delta(dtype="float64")] # default ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr") rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr") - assert rec_default._root["traces_seg0"].compressor == defaut_compressor - assert rec_default._root["traces_seg0"].filters == default_filters - assert rec_default._root["times_seg0"].compressor == defaut_compressor - assert rec_default._root["times_seg0"].filters == default_filters + check_compressors_match(rec_default._root["traces_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["times_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["traces_seg0"].filters, default_filters) + 
check_compressors_match(rec_default._root["times_seg0"].filters, default_filters) # now with other compressor ZarrRecordingExtractor.write_recording( recording, tmp_path / "rec_other.zarr", - compressor=defaut_compressor, + compressors=default_compressor, filters=default_filters, compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2}, filters_by_dataset={"traces": other_filters1, "times": other_filters2}, ) rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr") - assert rec_other._root["traces_seg0"].compressor == other_compressor1 - assert rec_other._root["traces_seg0"].filters == other_filters1 - assert rec_other._root["times_seg0"].compressor == other_compressor2 - assert rec_other._root["times_seg0"].filters == other_filters2 + check_compressors_match(rec_other._root["traces_seg0"].compressors[0], other_compressor1) + check_compressors_match(rec_other._root["traces_seg0"].filters, other_filters1) + check_compressors_match(rec_other._root["times_seg0"].compressors[0], other_compressor2) + check_compressors_match(rec_other._root["times_seg0"].filters, other_filters2) def test_ZarrSortingExtractor(tmp_path): diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 162d67a458..6b20c6bbaa 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -48,11 +48,13 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d import zarr # if mode is append or read/write, we try to open the folder with zarr.open - # since zarr.open_consolidated does not support creating new groups/datasets + # In zarr v3, we use use_consolidated parameter instead of open_consolidated if mode in ("a", "r+"): open_funcs = (zarr.open,) + use_consolidated_options = (False,) else: - open_funcs = (zarr.open_consolidated, zarr.open) + open_funcs = (zarr.open,) + use_consolidated_options = (True, False) # if storage_options is None, we try to open 
the folder with and without anonymous access # if storage_options is not None, we try to open the folder with the given storage options @@ -64,12 +66,14 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d root = None exception = None if is_path_remote(str(folder_path)): - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: if root is not None: break for storage_options in storage_options_to_test: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open( + str(folder_path), mode=mode, storage_options=storage_options, use_consolidated=use_consolidated + ) break except Exception as e: exception = e @@ -77,9 +81,9 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d else: if not Path(folder_path).is_dir(): raise ValueError(f"Folder {folder_path} does not exist") - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open(str(folder_path), mode=mode, use_consolidated=use_consolidated) break except Exception as e: exception = e @@ -91,6 +95,34 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d return root +def check_compressors_match(comp1, comp2, skip_typesize=True): + """ + Check if two compressor objects match. + + Parameters + ---------- + comp1 : zarr.Codec | Tuple[zarr.Codec] + The first compressor object to compare. + comp2 : zarr.Codec | Tuple[zarr.Codec] + The second compressor object to compare. 
+ skip_typesize : bool, optional + Whether to skip the typesize check, default: True + """ + if not isinstance(comp1, (list, tuple)): + assert not isinstance(comp2, list) + comp1 = [comp1] + comp2 = [comp2] + for i in range(len(comp1)): + comp1_dict = comp1[i].to_dict() + comp2_dict = comp2[i].to_dict() + if skip_typesize: + if "typesize" in comp1_dict["configuration"]: + comp1_dict["configuration"].pop("typesize", None) + if "typesize" in comp2_dict["configuration"]: + comp2_dict["configuration"].pop("typesize", None) + assert comp1_dict == comp2_dict + + class ZarrRecordingExtractor(BaseRecording): """ RecordingExtractor for a zarr format @@ -289,7 +321,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) - spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) + spikes = np.zeros(spikes_group["sample_index"].shape[0], dtype=minimum_spike_dtype) spikes["sample_index"] = spikes_group["sample_index"][:] spikes["unit_index"] = spikes_group["unit_index"][:] for i, (start, end) in enumerate(segment_slices_list): @@ -392,9 +424,9 @@ def get_default_zarr_compressor(clevel: int = 5): Blosc.compressor The compressor object that can be used with the save to zarr function """ - from numcodecs import Blosc + from zarr.codecs import BloscCodec, BloscShuffle - return Blosc(cname="zstd", clevel=clevel, shuffle=Blosc.BITSHUFFLE) + return BloscCodec(cname="zstd", clevel=clevel, shuffle=BloscShuffle.bitshuffle) def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_or_sorting: BaseRecording | BaseSorting): @@ -405,7 +437,7 @@ def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_o if values.dtype.kind == "O": warnings.warn(f"Property {key} not saved because it is a python Object type") continue - prop_group.create_dataset(name=key, data=values, compressor=None) + prop_group.create_array(name=key, data=values, 
compressors=None) # save annotations zarr_group.attrs["annotations"] = check_json(recording_or_sorting._annotations) @@ -424,12 +456,12 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G kwargs : dict Other arguments passed to the zarr compressor """ - from numcodecs import Delta + from zarr.codecs.numcodecs import Delta num_segments = sorting.get_num_segments() zarr_group.attrs["sampling_frequency"] = float(sorting.sampling_frequency) zarr_group.attrs["num_segments"] = int(num_segments) - zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) + zarr_group.create_array(name="unit_ids", data=sorting.unit_ids, compressors=None) compressor = kwargs.get("compressor", get_default_zarr_compressor()) @@ -438,18 +470,21 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G spikes = sorting.to_spike_vector() for field in spikes.dtype.fields: if field != "segment_index": - spikes_group.create_dataset( + dtype = spikes[field].dtype + spikes_data = spikes[field] + spikes_group.create_array( name=field, - data=spikes[field], - compressor=compressor, - filters=[Delta(dtype=spikes[field].dtype)], + data=spikes_data, + compressors=compressor, + filters=[Delta(dtype=spikes[field].dtype.str)], ) else: segment_slices = [] for segment_index in range(num_segments): i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append([i0, i1]) - spikes_group.create_dataset(name="segment_slices", data=segment_slices, compressor=None) + segment_slices = np.array(segment_slices, dtype="int64") + spikes_group.create_array(name="segment_slices", data=segment_slices, compressors=None) add_properties_and_annotations(zarr_group, sorting) @@ -468,7 +503,7 @@ def add_recording_to_zarr_group( # save data (done the subclass) zarr_group.attrs["sampling_frequency"] = float(recording.get_sampling_frequency()) zarr_group.attrs["num_segments"] = 
int(recording.get_num_segments()) - zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None) + zarr_group.create_array(name="channel_ids", data=recording.get_channel_ids(), compressors=None) dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] dtype = recording.get_dtype() if dtype is None else dtype @@ -484,7 +519,7 @@ def add_recording_to_zarr_group( recording=recording, zarr_group=zarr_group, dataset_paths=dataset_paths, - compressor=compressor_traces, + compressors=compressor_traces, filters=filters_traces, dtype=dtype, channel_chunk_size=channel_chunk_size, @@ -507,17 +542,17 @@ def add_recording_to_zarr_group( filters_times = filters_by_dataset.get("times", global_filters) if time_vector is not None: - _ = zarr_group.create_dataset( + _ = zarr_group.create_array( name=f"times_seg{segment_index}", data=time_vector, filters=filters_times, - compressor=compressor_times, + compressors=compressor_times, ) elif d["t_start"] is not None: t_starts[segment_index] = d["t_start"] if np.any(~np.isnan(t_starts)): - zarr_group.create_dataset(name="t_starts", data=t_starts, compressor=None) + zarr_group.create_array(name="t_starts", data=t_starts, compressors=None) add_properties_and_annotations(zarr_group, recording) @@ -528,7 +563,7 @@ def add_traces_to_zarr( dataset_paths, channel_chunk_size=None, dtype=None, - compressor=None, + compressors=None, filters=None, verbose=False, **job_kwargs, @@ -548,7 +583,7 @@ def add_traces_to_zarr( Channels per chunk dtype : dtype, default: None Type of the saved data - compressor : zarr compressor or None, default: None + compressors : zarr compressor or None, default: None Zarr compressor filters : list, default: None List of zarr filters @@ -581,13 +616,15 @@ def add_traces_to_zarr( num_channels = recording.get_num_channels() dset_name = dataset_paths[segment_index] shape = (num_frames, num_channels) - dset = zarr_group.create_dataset( + # In zarr v3, chunks must be a 
tuple of integers (no None allowed) + chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) + dset = zarr_group.create_array( name=dset_name, shape=shape, - chunks=(chunk_size, channel_chunk_size), + chunks=chunks, dtype=dtype, filters=filters, - compressor=compressor, + compressors=compressors, ) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer()) From 4dbb7b38fab00795aebb3644ba4bb325052607cf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Dec 2025 16:49:23 +0100 Subject: [PATCH 02/10] wip --- src/spikeinterface/core/zarrextractors.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 6b20c6bbaa..0ec28d544a 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -14,6 +14,9 @@ from .core_tools import is_path_remote +zarr.config.set({"default_zarr_version": 3}) + + def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None): """ Open a zarr folder with super powers. 
@@ -463,7 +466,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G zarr_group.attrs["num_segments"] = int(num_segments) zarr_group.create_array(name="unit_ids", data=sorting.unit_ids, compressors=None) - compressor = kwargs.get("compressor", get_default_zarr_compressor()) + compressor = kwargs.get("compressors") or kwargs.get("compressor") + if compressor is None: + compressor = get_default_zarr_compressor() # save sub fields spikes_group = zarr_group.create_group(name="spikes") @@ -508,7 +513,9 @@ def add_recording_to_zarr_group( dtype = recording.get_dtype() if dtype is None else dtype channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) - global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor()) + global_compressor = kwargs.get("compressors") or kwargs.get("compressor") + if global_compressor is None: + global_compressor = get_default_zarr_compressor() compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {}) global_filters = zarr_kwargs.pop("filters", None) filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {}) @@ -609,6 +616,9 @@ def add_traces_to_zarr( job_kwargs = fix_job_kwargs(job_kwargs) chunk_size = ensure_chunk_size(recording, **job_kwargs) + if not isinstance(compressors, (list, tuple)): + compressors = [compressors] + # create zarr datasets files zarr_datasets = [] for segment_index in range(recording.get_num_segments()): @@ -618,13 +628,8 @@ def add_traces_to_zarr( shape = (num_frames, num_channels) # In zarr v3, chunks must be a tuple of integers (no None allowed) chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) - dset = zarr_group.create_array( - name=dset_name, - shape=shape, - chunks=chunks, - dtype=dtype, - filters=filters, - compressors=compressors, + dset = zarr_group.create( + name=dset_name, shape=shape, chunks=chunks, dtype=dtype, filters=filters, codecs=compressors, zarr_format=3 ) zarr_datasets.append(dset) 
# synchronizer=zarr.ThreadSynchronizer()) From 4bc45daa956d1fb2f7cdb5f6b27147cc708ad2ff Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 12:43:42 +0100 Subject: [PATCH 03/10] fix: zarr.Group --- src/spikeinterface/core/zarrextractors.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 43a7f55f3d..4e467ebdbd 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -38,7 +38,7 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d Returns ------- - root: zarr.hierarchy.Group + root: zarr.Group The zarr root group object Raises @@ -496,7 +496,7 @@ def build_codec_pipeline(filters=None, compressors=None): return codecs if codecs else None -def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_or_sorting: BaseRecording | BaseSorting): +def add_properties_and_annotations(zarr_group: zarr.Group, recording_or_sorting: BaseRecording | BaseSorting): # save properties prop_group = zarr_group.create_group("properties") for key in recording_or_sorting.get_property_keys(): @@ -510,7 +510,7 @@ def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_o zarr_group.attrs["annotations"] = check_json(recording_or_sorting._annotations) -def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.Group, **kwargs): +def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.Group, **kwargs): """ Add a sorting extractor to a zarr group. 
@@ -518,7 +518,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G ---------- sorting : BaseSorting The sorting extractor object to be added to the zarr group - zarr_group : zarr.hierarchy.Group + zarr_group : zarr.Group The zarr group kwargs : dict Other arguments passed to the zarr compressor @@ -556,9 +556,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G # Recording -def add_recording_to_zarr_group( - recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, dtype=None, **kwargs -): +def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group, verbose=False, dtype=None, **kwargs): zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) if recording.check_if_json_serializable(): From e9a567f29a9574d8af8f2a95db2903ec62c0c427 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Mar 2026 13:12:22 +0100 Subject: [PATCH 04/10] wip: fix v3 --- src/spikeinterface/core/zarrextractors.py | 88 +++++++++++++++-------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4e467ebdbd..ba3d2f1318 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -163,7 +163,8 @@ def __init__( assert sampling_frequency is not None, "'sampling_frequency' attiribute not found!" assert num_segments is not None, "'num_segments' attiribute not found!" - channel_ids = np.array(channel_ids) + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. 
+ channel_ids = np.array(channel_ids.tolist()) dtype = self._root["traces_seg0"].dtype @@ -201,7 +202,7 @@ def __init__( if load_compression_ratio: nbytes_segment = self._root[trace_name].nbytes - nbytes_stored_segment = self._root[trace_name].nbytes_stored + nbytes_stored_segment = self._root[trace_name].nbytes_stored() if nbytes_stored_segment > 0: cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment else: @@ -220,7 +221,11 @@ def __init__( if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): - values = self._root["properties"][key] + values = self._root["properties"][key][:] + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. + if hasattr(values.dtype, "na_object") or values.dtype.kind == "O": + if values.size > 0 and isinstance(values.tolist()[0], str): + values = np.array(values.tolist()) self.set_property(key, values) # load annotations @@ -338,7 +343,11 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): - values = self._root["properties"][key] + values = self._root["properties"][key][:] + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. + if hasattr(values.dtype, "na_object") or values.dtype.kind == "O": + if values.size > 0 and isinstance(values.tolist()[0], str): + values = np.array(values.tolist()) self.set_property(key, values) # load annotations @@ -434,12 +443,12 @@ def get_default_zarr_compressor(clevel: int = 5): def build_codec_pipeline(filters=None, compressors=None): """ - Build a zarr v3 codecs list from filters and compressors. + Build zarr v3 codec kwargs from filters and compressors. - Assembles a valid zarr v3 codec pipeline in the required order: - 1. ArrayArrayCodec (filters, e.g. Delta) - 2. ArrayBytesCodec (serializer, e.g. WavPack, BytesCodec) - 3. 
BytesBytesCodec (compressors, e.g. BloscCodec, ZstdCodec) + Classifies codecs into the three slots accepted by ``zarr.Group.create_array()``: + 1. ``filters`` — ArrayArrayCodec (e.g. Delta) + 2. ``serializer`` — ArrayBytesCodec (e.g. WavPack, BytesCodec) + 3. ``compressors``— BytesBytesCodec (e.g. BloscCodec, ZstdCodec) This allows callers to pass an ArrayBytesCodec (e.g. WavPack) as a compressor and have it placed in the correct serializer slot automatically. @@ -454,10 +463,10 @@ def build_codec_pipeline(filters=None, compressors=None): Returns ------- - list of codecs or None - Full codec pipeline suitable for the ``codecs=`` parameter of - ``zarr.create()``. Returns None when both inputs are empty/None, - letting zarr use its defaults. + dict + Keyword arguments to unpack into ``zarr.Group.create_array()``. + Only keys with explicit values are included; omitted keys let zarr + use its defaults. Raises ------ @@ -492,8 +501,18 @@ def build_codec_pipeline(filters=None, compressors=None): if len(serializers) > 1: raise ValueError("Only one ArrayBytesCodec (serializer) is allowed in the codec pipeline.") - codecs = filters + serializers + byte_compressors - return codecs if codecs else None + codec_kwargs = {} + codec_kwargs["filters"] = filters + codec_kwargs["serializer"] = serializers[0] + codec_kwargs["compressors"] = byte_compressors + return codec_kwargs + + +def _has_string_fields(dtype: np.dtype) -> bool: + """Return True if dtype is or contains fixed-length unicode (U) sub-fields.""" + if dtype.names: + return any(_has_string_fields(dtype.fields[name][0]) for name in dtype.names) + return dtype.kind == "U" def add_properties_and_annotations(zarr_group: zarr.Group, recording_or_sorting: BaseRecording | BaseSorting): @@ -504,7 +523,20 @@ def add_properties_and_annotations(zarr_group: zarr.Group, recording_or_sorting: if values.dtype.kind == "O": warnings.warn(f"Property {key} not saved because it is a python Object type") continue - 
prop_group.create_array(name=key, data=values, compressors=None) + if values.dtype.names and _has_string_fields(values.dtype): + # Structured arrays with unicode sub-fields have no stable zarr v3 spec; skip them. + # Probe geometry (contact_vector) is already persisted via zarr_group.attrs["probe"]. + warnings.warn( + f"Property '{key}' not saved because it is a structured array with unicode fields, " + "which do not have a stable zarr V3 specification." + ) + continue + # Use variable-length UTF-8 (stable zarr v3 spec) for unicode arrays. + if values.dtype.kind == "U": + arr = prop_group.create_array(name=key, shape=values.shape, dtype=str, compressors=None) + arr[:] = values + else: + prop_group.create_array(name=key, data=values, compressors=None) # save annotations zarr_group.attrs["annotations"] = check_json(recording_or_sorting._annotations) @@ -541,9 +573,8 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.Group, **kw if field != "segment_index": dtype = spikes[field].dtype spikes_data = spikes[field] - codecs = build_codec_pipeline(filters=[Delta(dtype=spikes[field].dtype.str)], compressors=compressor) - arr = spikes_group.create(name=field, shape=spikes_data.shape, dtype=spikes_data.dtype, codecs=codecs) - arr[:] = spikes_data + codec_kwargs = build_codec_pipeline(filters=[Delta(dtype=spikes[field].dtype.str)], compressors=compressor) + spikes_group.create_array(name=field, data=spikes_data, **codec_kwargs) else: segment_slices = [] for segment_index in range(num_segments): @@ -567,7 +598,10 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group # save data (done the subclass) zarr_group.attrs["sampling_frequency"] = float(recording.get_sampling_frequency()) zarr_group.attrs["num_segments"] = int(recording.get_num_segments()) - zarr_group.create_array(name="channel_ids", data=recording.get_channel_ids(), compressors=None) + # Use variable-length UTF-8 (stable zarr v3 spec) instead of fixed-length unicode. 
+ channel_ids = recording.get_channel_ids() + arr = zarr_group.create_array(name="channel_ids", shape=channel_ids.shape, dtype=str, compressors=None) + arr[:] = channel_ids dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] dtype = recording.get_dtype() if dtype is None else dtype @@ -608,14 +642,8 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group filters_times = filters_by_dataset.get("times", global_filters) if time_vector is not None: - codecs = build_codec_pipeline(filters=filters_times, compressors=compressor_times) - arr = zarr_group.create( - name=f"times_seg{segment_index}", - shape=time_vector.shape, - dtype=time_vector.dtype, - codecs=codecs, - ) - arr[:] = time_vector + codec_kwargs = build_codec_pipeline(filters=filters_times, compressors=compressor_times) + zarr_group.create_array(name=f"times_seg{segment_index}", data=time_vector, **codec_kwargs) elif d["t_start"] is not None: t_starts[segment_index] = d["t_start"] @@ -677,7 +705,7 @@ def add_traces_to_zarr( job_kwargs = fix_job_kwargs(job_kwargs) chunk_size = ensure_chunk_size(recording, **job_kwargs) - codecs = build_codec_pipeline(filters=filters, compressors=compressors) + codec_kwargs = build_codec_pipeline(filters=filters, compressors=compressors) # create zarr datasets files zarr_datasets = [] @@ -688,7 +716,7 @@ def add_traces_to_zarr( shape = (num_frames, num_channels) # In zarr v3, chunks must be a tuple of integers (no None allowed) chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) - dset = zarr_group.create(name=dset_name, shape=shape, chunks=chunks, dtype=dtype, codecs=codecs, zarr_format=3) + dset = zarr_group.create_array(name=dset_name, shape=shape, chunks=chunks, dtype=dtype, **codec_kwargs) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer()) From 39a75883f71dadeaa294ceb2db8a59a94be33b71 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 
10:47:08 +0100 Subject: [PATCH 05/10] Fix v2/v3 test, add zarr_class_info, pin to dev probeinterface for testing --- .github/scripts/generate_zarr_v2_fixtures.py | 37 +++++++++---------- pyproject.toml | 16 ++++++-- src/spikeinterface/core/sortinganalyzer.py | 8 +++- .../core/tests/test_zarr_backwards_compat.py | 16 +++----- src/spikeinterface/core/zarrextractors.py | 26 +++++++++---- 5 files changed, 60 insertions(+), 43 deletions(-) diff --git a/.github/scripts/generate_zarr_v2_fixtures.py b/.github/scripts/generate_zarr_v2_fixtures.py index 90d0515da8..e555873458 100644 --- a/.github/scripts/generate_zarr_v2_fixtures.py +++ b/.github/scripts/generate_zarr_v2_fixtures.py @@ -12,64 +12,61 @@ - expected_values.json : key values used to verify correct loading """ import argparse +import shutil import json from pathlib import Path import numpy as np +import zarr +import spikeinterface as si -def main(output_dir: Path) -> None: - import spikeinterface - - print(f"spikeinterface version : {spikeinterface.__version__}") - - import zarr +def main(output_dir: Path) -> None: + print(f"spikeinterface version : {si.__version__}") print(f"zarr version : {zarr.__version__}") - from spikeinterface.core import generate_recording, generate_sorting - from spikeinterface.core import ZarrRecordingExtractor, ZarrSortingExtractor - from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer output_dir.mkdir(parents=True, exist_ok=True) - recording = generate_recording(num_channels=4, num_segments=2, seed=0) - sorting = generate_sorting(num_units=3, num_segments=2, seed=0) - + recording, sorting = si.generate_ground_truth_recording(durations=[10, 5],num_channels=32, num_units=10, seed=0) + # save to binary to make them JSON serializable for later expected values extraction + recording = recording.save(folder=output_dir / "recording_binary", overwrite=True) + sorting = sorting.save(folder=output_dir / "sorting_binary", overwrite=True) # --- save recording --- 
recording_path = output_dir / "recording.zarr" - ZarrRecordingExtractor.write_recording(recording, recording_path) + recording_zarr = recording.save(format="zarr", folder=recording_path, overwrite=True) print(f"Saved recording -> {recording_path}") # --- save sorting --- sorting_path = output_dir / "sorting.zarr" - ZarrSortingExtractor.write_sorting(sorting, sorting_path) + sorting_zarr = sorting.save(format="zarr", folder=sorting_path, overwrite=True) print(f"Saved sorting -> {sorting_path}") # --- save SortingAnalyzer --- # Reload the recording from zarr so it is a serializable ZarrRecordingExtractor, # which the analyzer can store as provenance. - recording_zarr = ZarrRecordingExtractor(recording_path) analyzer_path = output_dir / "analyzer.zarr" - analyzer = create_sorting_analyzer( - sorting, recording_zarr, format="zarr", folder=analyzer_path, sparse=False, sparsity=None + if analyzer_path.is_dir(): + shutil.rmtree(analyzer_path) + analyzer = si.create_sorting_analyzer( + sorting_zarr, recording_zarr, format="zarr", folder=analyzer_path, overwrite=True ) analyzer.compute(["random_spikes", "templates"]) print(f"Saved analyzer -> {analyzer_path}") # Reload to verify templates are accessible before writing expected values - analyzer = load_sorting_analyzer(analyzer_path) templates_array = analyzer.get_extension("templates").get_data() # --- capture expected values for later assertion --- expected = { - "spikeinterface_version": spikeinterface.__version__, + "spikeinterface_version": si.__version__, "zarr_version": zarr.__version__, "recording": { "num_channels": int(recording.get_num_channels()), "num_segments": int(recording.get_num_segments()), "sampling_frequency": float(recording.get_sampling_frequency()), - "num_frames_per_segment": [int(recording.get_num_frames(seg)) for seg in range(recording.get_num_segments())], + "num_samples_per_segment": [int(recording.get_num_samples(seg)) for seg in range(recording.get_num_segments())], "channel_ids": 
recording.get_channel_ids().tolist(), "dtype": str(recording.get_dtype()), # first 10 frames of segment 0 for all channels diff --git a/pyproject.toml b/pyproject.toml index 7978862df4..0fa6d12fd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,9 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -138,7 +140,9 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -189,7 +193,9 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -218,7 +224,9 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # FOR TESTING: 
use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 6f47790b18..6e26508c81 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1887,8 +1887,14 @@ def get_saved_extension_names(self): elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") - if "extensions" in zarr_root.keys(): + # Avoid iterating zarr_root.keys() because legacy v2 stores may contain + # object-dtype arrays (e.g. "recording", "sorting_provenance") that zarr v3 + # cannot parse, causing ValueError on enumeration. + try: extension_group = zarr_root["extensions"] + except KeyError: + extension_group = None + if extension_group is not None: for extension_name in extension_group.keys(): if "params" in extension_group[extension_name].attrs.keys(): saved_extension_names.append(extension_name) diff --git a/src/spikeinterface/core/tests/test_zarr_backwards_compat.py b/src/spikeinterface/core/tests/test_zarr_backwards_compat.py index d4f49d6fd8..49d99ce8eb 100644 --- a/src/spikeinterface/core/tests/test_zarr_backwards_compat.py +++ b/src/spikeinterface/core/tests/test_zarr_backwards_compat.py @@ -18,6 +18,8 @@ import numpy as np import pytest +import spikeinterface as si + FIXTURES_PATH = os.environ.get("ZARR_V2_FIXTURES_PATH") pytestmark = pytest.mark.skipif( @@ -38,9 +40,7 @@ def expected(fixtures_dir: Path) -> dict: def test_load_recording(fixtures_dir, expected): - from spikeinterface.core import read_zarr_recording - - recording = read_zarr_recording(fixtures_dir / "recording.zarr") + recording = si.load(fixtures_dir / "recording.zarr") exp = expected["recording"] assert 
recording.get_num_channels() == exp["num_channels"] @@ -49,7 +49,7 @@ def test_load_recording(fixtures_dir, expected): assert str(recording.get_dtype()) == exp["dtype"] for seg in range(recording.get_num_segments()): - assert recording.get_num_frames(seg) == exp["num_frames_per_segment"][seg] + assert recording.get_num_samples(seg) == exp["num_samples_per_segment"][seg] assert list(recording.get_channel_ids()) == exp["channel_ids"] @@ -58,9 +58,7 @@ def test_load_recording(fixtures_dir, expected): def test_load_sorting(fixtures_dir, expected): - from spikeinterface.core import read_zarr_sorting - - sorting = read_zarr_sorting(fixtures_dir / "sorting.zarr") + sorting = si.load(fixtures_dir / "sorting.zarr") exp = expected["sorting"] assert sorting.get_num_segments() == exp["num_segments"] @@ -73,9 +71,7 @@ def test_load_sorting(fixtures_dir, expected): def test_load_sorting_analyzer(fixtures_dir, expected): - from spikeinterface.core import load_sorting_analyzer - - analyzer = load_sorting_analyzer(fixtures_dir / "analyzer.zarr") + analyzer = si.load(fixtures_dir / "analyzer.zarr") exp = expected["analyzer"] assert analyzer.get_num_units() == exp["num_units"] diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index ba3d2f1318..4c3058cde0 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -8,7 +8,7 @@ from .base import minimum_spike_dtype from .baserecording import BaseRecording, BaseRecordingSegment from .basesorting import BaseSorting, SpikeVectorSortingSegment -from .core_tools import define_function_from_class, check_json +from .core_tools import define_function_from_class, check_json, retrieve_importing_provenance from .job_tools import split_job_kwargs from .core_tools import is_path_remote @@ -251,6 +251,7 @@ def write_recording( recording: BaseRecording, folder_path: str | Path, storage_options: dict | None = None, **kwargs ): zarr_root = 
zarr.open(str(folder_path), mode="w", storage_options=storage_options) + zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrRecordingExtractor) add_recording_to_zarr_group(recording, zarr_root, **kwargs) @@ -363,6 +364,7 @@ def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options Write a sorting extractor to zarr format. """ zarr_root = zarr.open(str(folder_path), mode="w", storage_options=storage_options) + zarr_root.attrs["zarr_class_info"] = retrieve_importing_provenance(ZarrSortingExtractor) add_sorting_to_zarr_group(sorting, zarr_root, **kwargs) @@ -388,15 +390,23 @@ def read_zarr( extractor : ZarrExtractor The loaded extractor """ - # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is. - # for the futur SortingAnalyzer we will have this 2 fields!!! root = super_zarr_open(folder_path, mode="r", storage_options=storage_options) - if "channel_ids" in root.keys(): - return read_zarr_recording(folder_path, storage_options=storage_options) - elif "unit_ids" in root.keys(): - return read_zarr_sorting(folder_path, storage_options=storage_options) + zarr_class_info = root.attrs.get("zarr_class_info", None) + if zarr_class_info is not None: + class_name = zarr_class_info["class"] + extractor_class = _get_class_from_string(class_name) + return extractor_class(folder_path, storage_options=storage_options) else: - raise ValueError("Cannot find 'channel_ids' or 'unit_ids' in zarr root. Not a valid SpikeInterface zarr format") + # For v<0.105.0 and old zarr files, revert to old way of loading based on the presence of "channel_ids" + # or "unit_ids" in the root + if "channel_ids" in root.keys(): + return read_zarr_recording(folder_path, storage_options=storage_options) + elif "unit_ids" in root.keys(): + return read_zarr_sorting(folder_path, storage_options=storage_options) + else: + raise ValueError( + "Cannot find 'channel_ids' or 'unit_ids' in zarr root. 
Not a valid SpikeInterface zarr format" + ) ### UTILITY FUNCTIONS ### From 7be375b0fb80bf8948a40ae739016e677e6aa50a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 12:21:56 +0100 Subject: [PATCH 06/10] Python>=3.11 and disable deepinterpolation action --- .github/workflows/all-tests.yml | 2 +- .github/workflows/deepinterpolation.yml | 6 ++---- .github/workflows/test_containers_docker.yml | 2 +- .github/workflows/test_containers_singularity.yml | 2 +- pyproject.toml | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 0d242b759a..41c3f81054 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.13"] # Lower and higher versions we support + python-version: ["3.11", "3.13"] # Lower and higher versions we support os: [macos-latest, windows-latest, ubuntu-latest] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/deepinterpolation.yml b/.github/workflows/deepinterpolation.yml index be003da742..2e7b8d03eb 100644 --- a/.github/workflows/deepinterpolation.yml +++ b/.github/workflows/deepinterpolation.yml @@ -1,10 +1,8 @@ name: Testing deepinterpolation +# Manual only — deepinterpolation requires Python 3.10, incompatible with 3.11+ required by Zarr 3.0.0+ on: - pull_request: - types: [synchronize, opened, reopened] - branches: - - main + workflow_dispatch: concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/test_containers_docker.yml b/.github/workflows/test_containers_docker.yml index 211db5f775..73a194efb3 100644 --- a/.github/workflows/test_containers_docker.yml +++ b/.github/workflows/test_containers_docker.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.10' + 
python-version: '3.11' - name: Python version run: python --version diff --git a/.github/workflows/test_containers_singularity.yml b/.github/workflows/test_containers_singularity.yml index 00941215b1..0554a0060c 100644 --- a/.github/workflows/test_containers_singularity.yml +++ b/.github/workflows/test_containers_singularity.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - uses: eWaterCycle/setup-singularity@v7 with: singularity-version: 3.8.7 diff --git a/pyproject.toml b/pyproject.toml index ebe6f3ba1a..2edee0194c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", From 38987afec555dcfc020c69ab92be3440ffe5933b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 14:45:09 +0100 Subject: [PATCH 07/10] Add sharding option --- src/spikeinterface/core/base.py | 42 +++---- src/spikeinterface/core/baserecording.py | 47 +++++++- src/spikeinterface/core/testing_tools.py | 2 +- .../core/tests/test_sortinganalyzer.py | 2 +- .../core/tests/test_zarrextractors.py | 43 ++++++- src/spikeinterface/core/zarr_tools.py | 26 ++++ src/spikeinterface/core/zarrextractors.py | 114 ++++++++++-------- .../preprocessing/tests/test_scaling.py | 2 +- 8 files changed, 197 insertions(+), 81 deletions(-) create mode 100644 src/spikeinterface/core/zarr_tools.py diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9dc270d38d..c16cfe80e3 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -873,6 +873,8 @@ def save(self, **kwargs) -> BaseExtractor: * dump_ext: "json" or "pkl", default "json" (if format is "folder") * 
verbose: if True output is verbose * **save_kwargs: additional kwargs format-dependent and job kwargs for recording + (check `save_to_memory()`, `save_to_folder()`, `save_to_zarr()` for more details on format-dependent + kwargs) {} Returns @@ -892,13 +894,27 @@ def save(self, **kwargs) -> BaseExtractor: save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) def save_to_memory(self, sharedmem=True, **save_kwargs) -> BaseExtractor: + """ + Save the object to memory. + + Parameters + ---------- + sharedmem : bool, default: True + If True, the object is saved to shared memory, allowing it to be accessed by multiple processes without + copying. If False, the object is saved to regular memory, which may involve copying when accessed by + multiple processes. + + Returns + ------- + BaseExtractor + A saved copy of the extractor in memory. + """ save_kwargs.pop("format", None) cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) self.copy_metadata(cached) return cached - # TODO rename to saveto_binary_folder def save_to_folder( self, name: str | None = None, @@ -944,8 +960,7 @@ def save_to_folder( If True, an existing folder at the specified path will be deleted before saving. verbose : bool, default: True If True, print information about the cache folder being used. - **save_kwargs - Additional keyword arguments to be passed to the underlying save method. + {} Returns ------- @@ -1010,7 +1025,6 @@ def save_to_zarr( folder=None, overwrite=False, storage_options=None, - channel_chunk_size=None, verbose=True, **save_kwargs, ): @@ -1030,26 +1044,9 @@ def save_to_zarr( storage_options: dict or None, default: None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. 
For cloud storage locations, this should not be None (in case of default values, use an empty dict) - channel_chunk_size: int or None, default: None - Channels per chunk (only for BaseRecording) - compressor: numcodecs.Codec or None, default: None - Global compressor. If None, Blosc-zstd, level 5, with bit shuffle is used - filters: list[numcodecs.Codec] or None, default: None - Global filters for zarr (global) - compressor_by_dataset: dict or None, default: None - Optional compressor per dataset: - - traces - - times - If None, the global compressor is used - filters_by_dataset: dict or None, default: None - Optional filters per dataset: - - traces - - times - If None, the global filters are used verbose: bool, default: True If True, the output is verbose - auto_cast_uint: bool, default: True - If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording) + {} Returns ------- @@ -1085,7 +1082,6 @@ def save_to_zarr( assert not zarr_path.exists(), f"Path {zarr_path} already exists, choose another name" save_kwargs["zarr_path"] = zarr_path save_kwargs["storage_options"] = storage_options - save_kwargs["channel_chunk_size"] = channel_chunk_size cached = self._save(format="zarr", verbose=verbose, **save_kwargs) cached = read_zarr(zarr_path) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 75bd47597b..a2ccb62937 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -5,10 +5,10 @@ import numpy as np from probeinterface import read_probeinterface, write_probeinterface -from .base import BaseSegment +from .base import BaseSegment, BaseExtractor from .baserecordingsnippets import BaseRecordingSnippets from .core_tools import convert_bytes_to_str, convert_seconds_to_str -from .job_tools import split_job_kwargs +from .job_tools import split_job_kwargs, _shared_job_kwargs_doc from .recording_tools import write_binary_recording @@ 
-39,6 +39,41 @@ class BaseRecording(BaseRecordingSnippets): "noise_level_rms_scaled", ] + _save_to_folder_docs_params = """dtype: np.dtype | None, default: None + The dtype to use for saving the binary file. If None, the dtype of the recording is used. +""" + _shared_job_kwargs_doc + + _save_to_zarr_docs_params = """ +channel_chunk_size: int | None, default: None + Chunk size for the channel dimension. If None, no chunking is done on the channel dimension. +chunks: tuple | None, default: None + Chunks for the traces dataset. If None, no chunking is done. Note that sharding requires chunking to be specified + and that chunk dimensions need to be larger than shard dimensions (if shards is not None). + If `chunks` is not None, it needs to be a tuple of length 2 with the chunk size for the time and channel + dimensions respectively and `channel_chunk_size` should not be specified. +shard_factor: int | None, default: None + If specified, the shard size will be set to chunk_size * shard_factor in the first dimension (time), + and to be the at most the total number of channels in the second dimension. Note that `shard_factor` cannot + be specified together with `shards`. +shards: tuple | None, default: None + Number of shard size. If None, no sharding is done. Note that shards dimensions need to be larger than + chunk dimensions (if chunks is not None) and that sharding is only done on the first dimension. +compressors: list[numcodecs.Codec] | None, default: None + Global compressor. 
If None, Blosc-zstd, level 5, with bit shuffle is used +filters: list[numcodecs.Codec] | None, default: None + Global filters for zarr (global) +compressor_by_dataset: dict | None, default: None + Optional compressor per dataset: + - traces + - times + If None, the global compressor is used +filters_by_dataset: dict | None, default: None + Optional filters per dataset: + - traces + - times + If None, the global filters are used +""" + _shared_job_kwargs_doc + def __init__(self, sampling_frequency: float, channel_ids: list, dtype): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype @@ -592,8 +627,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if format == "binary": folder = kwargs["folder"] - file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] t_starts = self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -904,6 +939,12 @@ def astype(self, dtype, round: bool | None = None): return astype(self, dtype=dtype, round=round) +BaseRecording.save_to_folder.__doc__ = BaseExtractor.save_to_folder.__doc__.format( + BaseRecording._save_to_folder_docs_params +) +BaseRecording.save_to_zarr.__doc__ = BaseExtractor.save_to_zarr.__doc__.format(BaseRecording._save_to_zarr_docs_params) + + class BaseRecordingSegment(BaseSegment): """ Abstract class representing a multichannel timeseries, or block of raw ephys traces diff --git a/src/spikeinterface/core/testing_tools.py b/src/spikeinterface/core/testing_tools.py index 899aa3852f..0169d5e50e 100644 --- a/src/spikeinterface/core/testing_tools.py +++ b/src/spikeinterface/core/testing_tools.py @@ -1,7 +1,7 @@ import warnings warnings.warn( - "The 'testing_tools' submodule is deprecated. 
" "Use spikeinterface.core.generate instead", + "The 'testing_tools' submodule is deprecated. Use spikeinterface.core.testing instead", DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index c280fda672..7f4820d2e8 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -17,7 +17,7 @@ AnalyzerExtension, _sort_extensions_by_dependency, ) -from spikeinterface.core.zarrextractors import check_compressors_match +from spikeinterface.core.zarr_utils import check_compressors_match from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension # to test basespikevectorextension with node pipeline diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index a52d456594..0eedb645c8 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -10,10 +10,11 @@ generate_sorting, load, ) +from spikeinterface.core.testing import check_recordings_equal +from spikeinterface.core.zarr_tools import check_compressors_match from spikeinterface.core.zarrextractors import ( add_sorting_to_zarr_group, get_default_zarr_compressor, - check_compressors_match, ) @@ -80,6 +81,46 @@ def test_ZarrSortingExtractor(tmp_path): sorting = load(sorting.to_dict()) +def test_sharding_options(tmp_path): + recording = generate_recording(durations=[10], num_channels=20) + folder = tmp_path / "zarr_sharding" + + # explicitly specify chunks and shards + ZarrRecordingExtractor.write_recording(recording, folder, chunks=(1000, 5), shards=(5000, 10), n_jobs=2) + recording_zarr = ZarrRecordingExtractor(folder) + assert recording_zarr._root["traces_seg0"].chunks == (1000, 5) + assert recording_zarr._root["traces_seg0"].shards == (5000, 10) + check_recordings_equal(recording, 
recording_zarr) + + # specify shard_factor and chunk_size + folder = tmp_path / "zarr_sharding_factor" + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shard_factor=5, n_jobs=2 + ) + recording_zarr = ZarrRecordingExtractor(folder) + assert recording_zarr._root["traces_seg0"].chunks == (1000, 2) + assert recording_zarr._root["traces_seg0"].shards == (5000, 10) + check_recordings_equal(recording, recording_zarr) + + # raise error if both shards and shard_factor are provided + with pytest.raises(ValueError): + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shard_factor=5, shards=(5000, 10), n_jobs=2 + ) + + # raise error if shards is smaller than chunks + with pytest.raises(AssertionError): + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shards=(500, 10), n_jobs=2 + ) + + # raise error if shards is not a multiple of chunks + with pytest.raises(AssertionError): + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shards=(5500, 10), n_jobs=2 + ) + + if __name__ == "__main__": tmp_path = Path("tmp") test_zarr_compression_options(tmp_path) diff --git a/src/spikeinterface/core/zarr_tools.py b/src/spikeinterface/core/zarr_tools.py new file mode 100644 index 0000000000..d97630535e --- /dev/null +++ b/src/spikeinterface/core/zarr_tools.py @@ -0,0 +1,26 @@ +def check_compressors_match(comp1, comp2, skip_typesize=True): + """ + Check if two compressor objects match. + + Parameters + ---------- + comp1 : zarr.Codec | Tuple[zarr.Codec] + The first compressor object to compare. + comp2 : zarr.Codec | Tuple[zarr.Codec] + The second compressor object to compare. 
+ skip_typesize : bool, optional + Whether to skip the typesize check, default: True + """ + if not isinstance(comp1, (list, tuple)): + assert not isinstance(comp2, list) + comp1 = [comp1] + comp2 = [comp2] + for i in range(len(comp1)): + comp1_dict = comp1[i].to_dict() + comp2_dict = comp2[i].to_dict() + if skip_typesize: + if "typesize" in comp1_dict["configuration"]: + comp1_dict["configuration"].pop("typesize", None) + if "typesize" in comp2_dict["configuration"]: + comp2_dict["configuration"].pop("typesize", None) + assert comp1_dict == comp2_dict, f"Compressor {i} does not match: {comp1_dict} != {comp2_dict}" diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4c3058cde0..a7c6b8fc7d 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -5,12 +5,11 @@ from probeinterface import ProbeGroup -from .base import minimum_spike_dtype +from .base import minimum_spike_dtype, _get_class_from_string from .baserecording import BaseRecording, BaseRecordingSegment from .basesorting import BaseSorting, SpikeVectorSortingSegment -from .core_tools import define_function_from_class, check_json, retrieve_importing_provenance -from .job_tools import split_job_kwargs -from .core_tools import is_path_remote +from .core_tools import define_function_from_class, check_json, is_path_remote, retrieve_importing_provenance +from .job_tools import split_job_kwargs, fix_job_kwargs, ensure_chunk_size, ChunkRecordingExecutor zarr.config.set({"default_zarr_version": 3}) @@ -96,34 +95,6 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d return root -def check_compressors_match(comp1, comp2, skip_typesize=True): - """ - Check if two compressor objects match. - - Parameters - ---------- - comp1 : zarr.Codec | Tuple[zarr.Codec] - The first compressor object to compare. - comp2 : zarr.Codec | Tuple[zarr.Codec] - The second compressor object to compare. 
- skip_typesize : bool, optional - Whether to skip the typesize check, default: True - """ - if not isinstance(comp1, (list, tuple)): - assert not isinstance(comp2, list) - comp1 = [comp1] - comp2 = [comp2] - for i in range(len(comp1)): - comp1_dict = comp1[i].to_dict() - comp2_dict = comp2[i].to_dict() - if skip_typesize: - if "typesize" in comp1_dict["configuration"]: - comp1_dict["configuration"].pop("typesize", None) - if "typesize" in comp2_dict["configuration"]: - comp2_dict["configuration"].pop("typesize", None) - assert comp1_dict == comp2_dict - - class ZarrRecordingExtractor(BaseRecording): """ RecordingExtractor for a zarr format @@ -513,7 +484,7 @@ def build_codec_pipeline(filters=None, compressors=None): codec_kwargs = {} codec_kwargs["filters"] = filters - codec_kwargs["serializer"] = serializers[0] + codec_kwargs["serializer"] = serializers[0] if len(serializers) == 1 else "auto" codec_kwargs["compressors"] = byte_compressors return codec_kwargs @@ -576,14 +547,22 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.Group, **kw if compressor is None: compressor = get_default_zarr_compressor() - # save sub fields + # Save sub fields of spikes as separate arrays to allow for more efficient compression and to + # avoid issues with structured arrays with unicode fields in zarr v3. + # The "segment_index" field is saved as "segment_slices" which contains the start and end indices of spikes for + # each segment, to avoid having a large array of segment indices when there are many spikes. 
spikes_group = zarr_group.create_group(name="spikes") spikes = sorting.to_spike_vector() for field in spikes.dtype.fields: if field != "segment_index": dtype = spikes[field].dtype spikes_data = spikes[field] - codec_kwargs = build_codec_pipeline(filters=[Delta(dtype=spikes[field].dtype.str)], compressors=compressor) + if field == "sample_index": + # Delta filter is very effective for spike times (sample_index) + filters = [Delta(dtype=spikes[field].dtype.str)] + else: + filters = None + codec_kwargs = build_codec_pipeline(filters=filters, compressors=compressor) spikes_group.create_array(name=field, data=spikes_data, **codec_kwargs) else: segment_slices = [] @@ -599,6 +578,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.Group, **kw # Recording def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group, verbose=False, dtype=None, **kwargs): zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) if recording.check_if_json_serializable(): zarr_group.attrs["provenance"] = check_json(recording.to_dict(recursive=True)) @@ -614,17 +594,53 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group arr[:] = channel_ids dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] + num_channels = recording.get_num_channels() dtype = recording.get_dtype() if dtype is None else dtype - channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) + + # Compressors and filters global_compressor = kwargs.get("compressors") or kwargs.get("compressor") if global_compressor is None: global_compressor = get_default_zarr_compressor() compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {}) global_filters = zarr_kwargs.pop("filters", None) filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {}) - compressor_traces = compressor_by_dataset.get("traces", global_compressor) filters_traces = filters_by_dataset.get("traces", global_filters) 
+ + # Chunking and sharding + chunks = zarr_kwargs.get("chunks", None) + channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) + shards = zarr_kwargs.get("shards", None) + shard_factor = zarr_kwargs.get("shard_factor", None) + if shards is not None and shard_factor is not None: + raise ValueError("Cannot specify both 'shards' and 'shard_factor' in zarr_kwargs") + if chunks is not None and channel_chunk_size is not None: + raise ValueError("Cannot specify both 'chunks' and 'channel_chunk_size' in zarr_kwargs") + + # If not specified by chunk, we set the chunk size in the first dimension (time) to be the chunk size that we use + # for the job executor, and the chunk size in the second dimension (channels) to be either the provided + # channel_chunk_size or the total number of channels (no chunking in channels). + if chunks is not None: + job_kwargs["chunk_size"] = chunks[0] + else: + chunk_size = ensure_chunk_size(recording, **job_kwargs) + chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) + + if shards is not None: + assert len(shards) == len(chunks), "Shards and chunks must have the same number of dimensions" + for dim in range(len(chunks)): + assert ( + shards[dim] >= chunks[dim] and shards[dim] % chunks[dim] == 0 + ), "Shard size must be a multiple of chunk size" + # When sharding is used, chunk_size in job_kwargs is used to determine the number of samples per chunk to + # write in each job. Each process will write all chunks in a shard. + job_kwargs["chunk_size"] = shards[0] + elif shard_factor is not None: + # If shard_factor is provided, we set the shard size to be chunk_size * shard_factor in the first dimension (time), + # and to be the at most the total number of channels in the second dimension. 
+ shards = (chunks[0] * shard_factor, min(chunks[1] * shard_factor, num_channels)) + job_kwargs["chunk_size"] = shards[0] + add_traces_to_zarr( recording=recording, zarr_group=zarr_group, @@ -632,7 +648,8 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group compressors=compressor_traces, filters=filters_traces, dtype=dtype, - channel_chunk_size=channel_chunk_size, + chunks=chunks, + shards=shards, verbose=verbose, **job_kwargs, ) @@ -667,7 +684,8 @@ def add_traces_to_zarr( recording, zarr_group, dataset_paths, - channel_chunk_size=None, + chunks=None, + shards=None, dtype=None, compressors=None, filters=None, @@ -685,8 +703,10 @@ def add_traces_to_zarr( The zarr group to add traces to dataset_paths : list List of paths to traces datasets in the zarr group - channel_chunk_size : int or None, default: None (chunking in time only) + chunks : tuple or None, default: None (chunking in time only) Channels per chunk + shards : tuple or None, default: None + If not None, the shard shape as a (time, channels) tuple used for the traces datasets dtype : dtype, default: None Type of the saved data compressors : zarr compressor or None, default: None @@ -697,12 +717,6 @@ def add_traces_to_zarr( If True, output is verbose (when chunks are used) {} """ - from .job_tools import ( ensure_chunk_size, fix_job_kwargs, ChunkRecordingExecutor, ) - assert dataset_paths is not None, "Provide 'file_path'" if not isinstance(dataset_paths, list): @@ -712,9 +726,6 @@ def add_traces_to_zarr( if dtype is None: dtype = recording.get_dtype() - job_kwargs = fix_job_kwargs(job_kwargs) - chunk_size = ensure_chunk_size(recording, **job_kwargs) - codec_kwargs = build_codec_pipeline(filters=filters, compressors=compressors) # create zarr datasets files @@ -725,8 +736,9 @@ def add_traces_to_zarr( dset_name = dataset_paths[segment_index] shape = (num_frames, num_channels) # In zarr v3, chunks must be a tuple of integers (no None allowed) - chunks = (chunk_size, channel_chunk_size if
channel_chunk_size is not None else num_channels) - dset = zarr_group.create_array(name=dset_name, shape=shape, chunks=chunks, dtype=dtype, **codec_kwargs) + dset = zarr_group.create_array( + name=dset_name, shape=shape, chunks=chunks, shards=shards, dtype=dtype, **codec_kwargs + ) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer()) diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index a19d116b16..cc06d88960 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from spikeinterface.core.testing_tools import generate_recording +from spikeinterface.core.testing import generate_recording from spikeinterface.preprocessing.preprocessing_classes import scale_to_uV, CenterRecording, scale_to_physical_units From f1737863e694b327d1b2925d8d5f06ff2411bb67 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 25 Mar 2026 14:53:42 +0100 Subject: [PATCH 08/10] wrong imports --- src/spikeinterface/core/tests/test_sortinganalyzer.py | 2 +- src/spikeinterface/preprocessing/tests/test_scaling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 7f4820d2e8..2912d4f5a1 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -17,7 +17,7 @@ AnalyzerExtension, _sort_extensions_by_dependency, ) -from spikeinterface.core.zarr_utils import check_compressors_match +from spikeinterface.core.zarr_tools import check_compressors_match from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension # to test basespikevectorextension with node pipeline diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py 
b/src/spikeinterface/preprocessing/tests/test_scaling.py index cc06d88960..27f1de8542 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from spikeinterface.core.testing import generate_recording +from spikeinterface.core.generate import generate_recording from spikeinterface.preprocessing.preprocessing_classes import scale_to_uV, CenterRecording, scale_to_physical_units From 42bb52157c5f675036b8634a30f97fb0c1399205 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Mar 2026 09:25:53 +0100 Subject: [PATCH 09/10] fix: channel_ids dtype --- src/spikeinterface/core/zarrextractors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index a7c6b8fc7d..b50ac66021 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -590,8 +590,7 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group zarr_group.attrs["num_segments"] = int(recording.get_num_segments()) # Use variable-length UTF-8 (stable zarr v3 spec) instead of fixed-length unicode. 
channel_ids = recording.get_channel_ids() - arr = zarr_group.create_array(name="channel_ids", shape=channel_ids.shape, dtype=str, compressors=None) - arr[:] = channel_ids + arr = zarr_group.create_array(name="channel_ids", data=channel_ids, compressors=None) dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] num_channels = recording.get_num_channels() From 0de855bc4a22f683531acecb963f96cf5da4448b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 26 Mar 2026 17:40:16 +0100 Subject: [PATCH 10/10] Fix NWB-zarr tests --- src/spikeinterface/extractors/nwbextractors.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b89999d088..c5c7639209 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -307,8 +307,8 @@ def _get_backend_from_local_file(file_path: str | Path) -> str: try: import zarr - with zarr.open(file_path, "r") as f: - backend = "zarr" + _ = zarr.open(file_path, mode="r") + backend = "zarr" except: raise RuntimeError(f"{file_path} is not a valid Zarr folder!") else: @@ -333,7 +333,8 @@ def _find_neurodata_type_from_backend(group, path="", result=None, neurodata_typ if result is None: result = [] - for neurodata_name, value in group.items(): + for neurodata_name in group.keys(): + value = group[neurodata_name] # Check if it's a group and if it has the neurodata_type if isinstance(value, group_class): current_path = f"{path}/{neurodata_name}" if path else neurodata_name @@ -1409,7 +1410,8 @@ def _find_timeseries_from_backend(group, path="", result=None, backend="hdf5"): if result is None: result = [] - for name, value in group.items(): + for name in group.keys(): + value = group[name] if isinstance(value, group_class): current_path = f"{path}/{name}" if path else name if value.attrs.get("neurodata_type") == "TimeSeries":