diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..1ae1162a31 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -2,7 +2,8 @@ from abc import abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable from typing_extensions import ReadOnly, TypedDict @@ -18,8 +19,9 @@ from zarr.abc.store import ByteGetter, ByteSetter, Store from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid + from zarr.core.codec_pipeline import ChunkTransform, ReadChunkRequest, WriteChunkRequest from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType - from zarr.core.indexing import SelectorTuple + from zarr.core.indexing import ChunkProjection, SelectorTuple from zarr.core.metadata import ArrayMetadata __all__ = [ @@ -32,6 +34,9 @@ "CodecInput", "CodecOutput", "CodecPipeline", + "PreparedWrite", + "SupportsChunkCodec", + "SupportsSyncCodec", ] CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer) @@ -59,6 +64,36 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]: """The widest type of JSON-like input that could specify a codec.""" +@runtime_checkable +class SupportsSyncCodec(Protocol): + """Protocol for codecs that support synchronous encode/decode. + + Codecs implementing this protocol provide ``_decode_sync`` and ``_encode_sync`` + methods that perform encoding/decoding without requiring an async event loop. + """ + + def _decode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer: ... + + def _encode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer | None: ... + + +class SupportsChunkCodec(Protocol): + """Protocol for objects that can decode/encode whole chunks synchronously. + + [`ChunkTransform`][zarr.core.codec_pipeline.ChunkTransform] satisfies this protocol. + """ + + array_spec: ArraySpec + + def decode_chunk(self, chunk_bytes: Buffer) -> NDBuffer: ... + + def encode_chunk(self, chunk_array: NDBuffer) -> Buffer | None: ... + + class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): """Generic base class for codecs. @@ -186,9 +221,314 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): """Base class for array-to-array codecs.""" +@dataclass +class PreparedWrite: + """Result of the prepare phase of a write operation. + + Carries deserialized chunk data and selection metadata between + [`prepare_write`][zarr.abc.codec.ArrayBytesCodec.prepare_write] (or + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]) + and [`finalize_write`][zarr.abc.codec.ArrayBytesCodec.finalize_write] (or + [`finalize_write_sync`][zarr.abc.codec.ArrayBytesCodec.finalize_write_sync]). + + Attributes + ---------- + chunk_dict : dict[tuple[int, ...], Buffer | None] + Per-inner-chunk buffers keyed by chunk coordinates. + indexer : list[ChunkProjection] + Mapping from inner-chunk coordinates to value/output selections. + """ + + chunk_dict: dict[tuple[int, ...], Buffer | None] + indexer: list[ChunkProjection] + + class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" + def deserialize( + self, raw: Buffer | None, chunk_spec: ArraySpec + ) -> dict[tuple[int, ...], Buffer | None]: + """Unpack stored bytes into per-inner-chunk buffers. + + The default implementation returns a single-entry dict keyed at + ``(0,)``. ``ShardingCodec`` overrides this to decode the shard index + and split the blob into per-chunk buffers. + + Parameters + ---------- + raw : Buffer or None + The raw bytes read from the store, or ``None`` if the key was + absent. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + + Returns + ------- + dict[tuple[int, ...], Buffer | None] + Mapping from inner-chunk coordinates to their encoded bytes. + """ + return {(0,): raw} + + def serialize( + self, + chunk_dict: dict[tuple[int, ...], Buffer | None], + chunk_spec: ArraySpec, + ) -> Buffer | None: + """Pack per-inner-chunk buffers into a storage blob. + + The default implementation returns the single entry at ``(0,)``. + ``ShardingCodec`` overrides this to concatenate chunks and build a + shard index. + + Parameters + ---------- + chunk_dict : dict[tuple[int, ...], Buffer | None] + Mapping from inner-chunk coordinates to their encoded bytes. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + + Returns + ------- + Buffer or None + The serialized blob, or ``None`` when all chunks are empty + (the caller should delete the key). + """ + return chunk_dict.get((0,)) + + # ------------------------------------------------------------------ + # prepare / finalize — sync + # ------------------------------------------------------------------ + + def prepare_read_sync( + self, + byte_getter: Any, + chunk_selection: SelectorTuple, + codec_chain: SupportsChunkCodec, + ) -> NDBuffer | None: + """Read a chunk from the store synchronously, decode it, and + return the selected region. + + Parameters + ---------- + byte_getter : Any + An object supporting ``get_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_selection : SelectorTuple + Selection within the decoded chunk array. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode the chunk. Must carry an ``array_spec`` attribute. + + Returns + ------- + NDBuffer or None + The decoded chunk data at *chunk_selection*, or ``None`` if the + chunk does not exist in the store. + """ + raw = byte_getter.get_sync(prototype=codec_chain.array_spec.prototype) + if raw is None: + return None + chunk_array = codec_chain.decode_chunk(raw) + return chunk_array[chunk_selection] + + def prepare_write_sync( + self, + byte_setter: Any, + codec_chain: SupportsChunkCodec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + replace: bool, + ) -> PreparedWrite: + """Prepare a synchronous write by optionally reading existing data. + + When *replace* is ``False``, the existing chunk bytes are fetched + from the store so they can be merged with the new data. When + *replace* is ``True``, the fetch is skipped. + + Parameters + ---------- + byte_setter : Any + An object supporting ``get_sync`` and ``set_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. + chunk_selection : SelectorTuple + Selection within the chunk being written. + out_selection : SelectorTuple + Corresponding selection within the source value array. + replace : bool + If ``True``, the write replaces all data in the chunk and no + read-modify-write is needed. If ``False``, existing data is + fetched first. + + Returns + ------- + PreparedWrite + A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the + deserialized chunk data and selection metadata. + """ + chunk_spec = codec_chain.array_spec + existing: Buffer | None = None + if not replace: + existing = byte_setter.get_sync(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + return PreparedWrite( + chunk_dict=chunk_dict, + indexer=[ + ( # type: ignore[list-item] + (0,), + chunk_selection, + out_selection, + replace, + ) + ], + ) + + def finalize_write_sync( + self, + prepared: PreparedWrite, + codec_chain: SupportsChunkCodec, + byte_setter: Any, + ) -> None: + """Serialize the prepared chunk data and write it to the store. + + If serialization produces ``None`` (all chunks empty), the key is + deleted instead. + + Parameters + ---------- + prepared : PreparedWrite + The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. + byte_setter : Any + An object supporting ``set_sync`` and ``delete_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + """ + blob = self.serialize(prepared.chunk_dict, codec_chain.array_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + + # ------------------------------------------------------------------ + # prepare / finalize — async + # ------------------------------------------------------------------ + + async def prepare_read( + self, + byte_getter: Any, + chunk_selection: SelectorTuple, + codec_chain: SupportsChunkCodec, + ) -> NDBuffer | None: + """Async variant of + [`prepare_read_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_read_sync]. + + Parameters + ---------- + byte_getter : Any + An object supporting ``get`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_selection : SelectorTuple + Selection within the decoded chunk array. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode the chunk. Must carry an ``array_spec`` attribute. + + Returns + ------- + NDBuffer or None + The decoded chunk data at *chunk_selection*, or ``None`` if the + chunk does not exist in the store. + """ + raw = await byte_getter.get(prototype=codec_chain.array_spec.prototype) + if raw is None: + return None + chunk_array = codec_chain.decode_chunk(raw) + return chunk_array[chunk_selection] + + async def prepare_write( + self, + byte_setter: Any, + codec_chain: SupportsChunkCodec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + replace: bool, + ) -> PreparedWrite: + """Async variant of + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. + + Parameters + ---------- + byte_setter : Any + An object supporting ``get`` and ``set`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. + chunk_selection : SelectorTuple + Selection within the chunk being written. + out_selection : SelectorTuple + Corresponding selection within the source value array. + replace : bool + If ``True``, the write replaces all data in the chunk and no + read-modify-write is needed. If ``False``, existing data is + fetched first. + + Returns + ------- + PreparedWrite + A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the + deserialized chunk data and selection metadata. + """ + chunk_spec = codec_chain.array_spec + existing: Buffer | None = None + if not replace: + existing = await byte_setter.get(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + return PreparedWrite( + chunk_dict=chunk_dict, + indexer=[ + ( # type: ignore[list-item] + (0,), + chunk_selection, + out_selection, + replace, + ) + ], + ) + + async def finalize_write( + self, + prepared: PreparedWrite, + codec_chain: SupportsChunkCodec, + byte_setter: Any, + ) -> None: + """Async variant of + [`finalize_write_sync`][zarr.abc.codec.ArrayBytesCodec.finalize_write_sync]. + + Parameters + ---------- + prepared : PreparedWrite + The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by + [`prepare_write`][zarr.abc.codec.ArrayBytesCodec.prepare_write]. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. + byte_setter : Any + An object supporting ``set`` and ``delete`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + """ + blob = self.serialize(prepared.chunk_dict, codec_chain.array_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + class BytesBytesCodec(BaseCodec[Buffer, Buffer]): """Base class for bytes-to-bytes codecs.""" @@ -371,6 +711,20 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: """ ... + @abstractmethod + def get_chunk_transform(self, array_spec: ArraySpec) -> ChunkTransform: + """Creates a ChunkTransform for the given array spec. + + Parameters + ---------- + array_spec : ArraySpec + + Returns + ------- + ChunkTransform + """ + ... + @abstractmethod async def decode( self, @@ -412,7 +766,7 @@ async def encode( @abstractmethod async def read( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -421,12 +775,10 @@ async def read( Parameters ---------- - batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]] - Ordered set of information about the chunks. - The first slice selection determines which parts of the chunk will be fetched. - The second slice selection determines where in the output array the chunk data will be written. - The ByteGetter is used to fetch the necessary bytes. - The chunk spec contains information about the construction of an array from the bytes. + batch_info : Iterable[ReadChunkRequest] + Ordered set of read requests. Each carries a ``byte_getter``, + a ``ChunkTransform`` (codec chain + spec), and chunk/output + selections. If the Store returns ``None`` for a chunk, then the chunk was not written and the implementation must set the values of that chunk (or @@ -439,7 +791,7 @@ async def read( @abstractmethod async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -449,16 +801,37 @@ async def write( Parameters ---------- - batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]] - Ordered set of information about the chunks. - The first slice selection determines which parts of the chunk will be encoded. - The second slice selection determines where in the value array the chunk data is located. - The ByteSetter is used to fetch and write the necessary bytes. - The chunk spec contains information about the chunk. + batch_info : Iterable[WriteChunkRequest] + Ordered set of write requests. Each carries a ``byte_setter``, + a ``ChunkTransform`` (codec chain + spec), chunk/output + selections, and whether the chunk is complete. value : NDBuffer """ ... + @property + def supports_sync_io(self) -> bool: + """Whether this pipeline can run read/write entirely on the calling thread.""" + return False + + def read_sync( + self, + batch_info: Iterable[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous read: fetch bytes from store, decode, scatter into *out*.""" + raise NotImplementedError + + def write_sync( + self, + batch_info: Iterable[WriteChunkRequest], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous write: gather from *value*, encode, persist to store.""" + raise NotImplementedError + async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..1d321f5fc3 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -16,7 +16,17 @@ from zarr.core.buffer import Buffer, BufferPrototype -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +__all__ = [ + "ByteGetter", + "ByteSetter", + "Store", + "SupportsDeleteSync", + "SupportsGetSync", + "SupportsSetRangeSync", + "SupportsSetSync", + "SupportsSyncStore", + "set_or_delete", +] @dataclass @@ -700,6 +710,38 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +@runtime_checkable +class SupportsGetSync(Protocol): + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: ... + + +@runtime_checkable +class SupportsSetSync(Protocol): + def set_sync(self, key: str, value: Buffer) -> None: ... + + +@runtime_checkable +class SupportsSetRangeSync(Protocol): + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: ... + + +@runtime_checkable +class SupportsDeleteSync(Protocol): + def delete_sync(self, key: str) -> None: ... + + +@runtime_checkable +class SupportsSyncStore( + SupportsGetSync, SupportsSetSync, SupportsSetRangeSync, SupportsDeleteSync, Protocol +): ... + + async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: """Set or delete a value in a byte setter diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 5b91cfa005..d05731d640 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -299,13 +299,27 @@ def _blosc_codec(self) -> Blosc: config_dict["typesize"] = self.typesize return Blosc.from_config(config_dict) + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return chunk_spec.prototype.buffer.from_bytes( + self._blosc_codec.encode(chunk_bytes.as_numpy_array()) ) async def _encode_single( @@ -313,14 +327,7 @@ async def _encode_single( chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - # Since blosc only support host memory, we convert the input and output of the encoding - # between numpy array and buffer - return await asyncio.to_thread( - lambda chunk: chunk_spec.prototype.buffer.from_bytes( - self._blosc_codec.encode(chunk.as_numpy_array()) - ), - chunk_bytes, - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 1fbdeef497..86bb354fb5 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -65,7 +65,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: ) return self - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -88,7 +88,14 @@ async def _decode_single( ) return chunk_array - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -109,5 +116,12 @@ async def _encode_single( nd_array = nd_array.ravel().view(dtype="B") return chunk_spec.prototype.buffer.from_array_like(nd_array) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 9536d0d558..ebe2ac8f7a 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -31,7 +31,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "crc32c"} - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -51,7 +51,14 @@ async def _decode_single( ) return chunk_spec.prototype.buffer.from_array_like(inner_bytes) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -64,5 +71,12 @@ async def _encode_single( # Append the checksum (as bytes) to the data return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("B"))) + async def _encode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_bytes, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length + 4 diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 610ca9dadd..b8591748f7 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -2,6 +2,7 @@ import asyncio from dataclasses import dataclass +from functools import cached_property from typing import TYPE_CHECKING from numcodecs.gzip import GZip @@ -48,23 +49,37 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "gzip", "configuration": {"level": self.level}} + @cached_property + def _gzip_codec(self) -> GZip: + return GZip(self.level) + + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._gzip_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return as_numpy_array_wrapper(self._gzip_codec.encode, chunk_bytes, chunk_spec.prototype) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size( self, diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..e8537a187f 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Iterable, Mapping from dataclasses import dataclass, replace from enum import Enum from functools import lru_cache @@ -15,11 +15,9 @@ ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, Codec, - CodecPipeline, ) from zarr.abc.store import ( ByteGetter, - ByteRequest, ByteSetter, RangeByteRequest, SuffixByteRequest, @@ -35,6 +33,7 @@ numpy_buffer_prototype, ) from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid +from zarr.core.codec_pipeline import ChunkTransform, fill_value_or_default from zarr.core.common import ( ShapeLike, parse_enum, @@ -54,7 +53,6 @@ ) from zarr.core.metadata.v3 import parse_codecs from zarr.registry import get_ndbuffer_class, get_pipeline_class -from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import Iterator @@ -65,7 +63,6 @@ MAX_UINT_64 = 2**64 - 1 ShardMapping = Mapping[tuple[int, ...], Buffer | None] -ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None] class ShardingCodecIndexLocation(Enum): @@ -81,41 +78,6 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation: return parse_enum(data, ShardingCodecIndexLocation) -@dataclass(frozen=True) -class _ShardingByteGetter(ByteGetter): - shard_dict: ShardMapping - chunk_coords: tuple[int, ...] - - async def get( - self, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> Buffer | None: - assert prototype == default_buffer_prototype(), ( - f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" - ) - value = self.shard_dict.get(self.chunk_coords) - if value is None: - return None - if byte_range is None: - return value - start, stop = _normalize_byte_range_index(value, byte_range) - return value[start:stop] - - -@dataclass(frozen=True) -class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): - shard_dict: ShardMutableMapping - - async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: - assert byte_range is None, "byte_range is not supported within shards" - self.shard_dict[self.chunk_coords] = value - - async def delete(self) -> None: - del self.shard_dict[self.chunk_coords] - - async def set_if_not_exists(self, default: Buffer) -> None: - self.shard_dict.setdefault(self.chunk_coords, default) - - class _ShardIndex(NamedTuple): # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2) offsets_and_lengths: npt.NDArray[np.uint64] @@ -354,9 +316,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration(data, "sharding_indexed") return cls(**configuration_parsed) # type: ignore[arg-type] - @property - def codec_pipeline(self) -> CodecPipeline: - return get_pipeline_class().from_codecs(self.codecs) + def _get_chunk_transform(self, chunk_spec: ArraySpec) -> ChunkTransform: + return ChunkTransform(codecs=self.codecs, array_spec=chunk_spec) def to_dict(self) -> dict[str, JSON]: return { @@ -430,20 +391,15 @@ async def _decode_single( out.fill(shard_spec.fill_value) return out - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, - ) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + for chunk_coords, chunk_selection, out_selection, _is_complete in indexer: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = await transform.decode_chunk_async(chunk_bytes) + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value return out @@ -502,20 +458,16 @@ async def _decode_partial_single( if chunk_bytes: shard_dict[chunk_coords] = chunk_bytes - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, - ) + # decode chunks and write them into the output buffer + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + for chunk_coords, chunk_selection, out_selection, _is_complete in indexed_chunks: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = await transform.decode_chunk_async(chunk_bytes) + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value if hasattr(indexer, "sel_shape"): return out.reshape(indexer.sel_shape) @@ -532,29 +484,23 @@ async def _encode_single( chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - indexer = list( - BasicIndexer( - tuple(slice(0, s) for s in shard_shape), - shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), - ) + indexer = BasicIndexer( + tuple(slice(0, s) for s in shard_shape), + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), ) - shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard)) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + shard_builder: dict[tuple[int, ...], Buffer | None] = {} - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_builder, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, - ) + for chunk_coords, _, out_selection, _is_complete in indexer: + chunk_array = shard_array[out_selection] + if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(fill_value): + continue + encoded = await transform.encode_chunk_async(chunk_array) + if encoded is not None: + shard_builder[chunk_coords] = encoded return await self._encode_shard_dict( shard_builder, @@ -581,7 +527,9 @@ async def _encode_partial_single( ) shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard))) + shard_dict: dict[tuple[int, ...], Buffer | None] = shard_reader.to_dict_vectorized( + np.asarray(_morton_order(chunks_per_shard)) + ) indexer = list( get_indexer( @@ -589,19 +537,37 @@ async def _encode_partial_single( ) ) - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, - ) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + + is_scalar = len(shard_array.shape) == 0 + for chunk_coords, chunk_selection, out_selection, is_complete in indexer: + value = shard_array if is_scalar else shard_array[out_selection] + if is_complete and not is_scalar and value.shape == chunk_spec.shape: + # Complete overwrite with matching shape — use value directly + chunk_data = value + else: + # Read-modify-write: decode existing or create new, merge data + if is_complete: + existing_bytes = None + else: + existing_bytes = shard_dict.get(chunk_coords) + if existing_bytes is not None: + chunk_data = (await transform.decode_chunk_async(existing_bytes)).copy() + else: + chunk_data = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=fill_value, + ) + chunk_data[chunk_selection] = value + + if not chunk_spec.config.write_empty_chunks and chunk_data.all_equal(fill_value): + shard_dict[chunk_coords] = None + else: + shard_dict[chunk_coords] = await transform.encode_chunk_async(chunk_data) + buf = await self._encode_shard_dict( shard_dict, chunks_per_shard=chunks_per_shard, diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index a8570b6e8f..609448a59c 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -95,20 +95,34 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: prototype=chunk_spec.prototype, ) - async def _decode_single( + def _decode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, ) -> NDBuffer: - inverse_order = np.argsort(self.order) + inverse_order = tuple(int(i) for i in np.argsort(self.order)) return chunk_array.transpose(inverse_order) - async def _encode_single( + async def _decode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_array, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, _chunk_spec: ArraySpec, ) -> NDBuffer | None: return chunk_array.transpose(self.order) + async def _encode_single( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_array, _chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index fb1fb76126..a10cb7c335 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -40,8 +40,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - # TODO: expand the tests for this function - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -55,7 +54,14 @@ async def _decode_single( as_string_dtype = decoded.astype(chunk_spec.dtype.to_native_dtype(), copy=False) return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -65,6 +71,13 @@ async def _encode_single( _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) ) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") @@ -86,7 +99,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -99,7 +112,14 @@ async def _decode_single( decoded = _reshape_view(decoded, chunk_spec.shape) return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -109,6 +129,13 @@ async def _encode_single( _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) ) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 27cc9a7777..f93c25a3c7 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -38,7 +38,7 @@ def parse_checksum(data: JSON) -> bool: class ZstdCodec(BytesBytesCodec): """zstd codec""" - is_fixed_size = True + is_fixed_size = False level: int = 0 checksum: bool = False @@ -71,23 +71,33 @@ def _zstd_codec(self) -> Zstd: config_dict = {"level": self.level, "checksum": self.checksum} return Zstd.from_config(config_dict) + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return as_numpy_array_wrapper(self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 564d0e915a..479416277d 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -48,6 +48,7 @@ V2ChunkKeyEncoding, parse_chunk_key_encoding, ) +from zarr.core.codec_pipeline import ReadChunkRequest, WriteChunkRequest from zarr.core.common import ( JSON, ZARR_JSON, @@ -5602,14 +5603,15 @@ async def _get_selection( # reading chunks and decoding them await codec_pipeline.read( [ - ( - store_path / metadata.encode_chunk_key(chunk_coords), - metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), - chunk_selection, - out_selection, - is_complete_chunk, + ReadChunkRequest( + byte_getter=store_path / metadata.encode_chunk_key(chunk_coords), + transform=codec_pipeline.get_chunk_transform( + metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype) + ), + chunk_selection=chunk_selection, + out_selection=out_selection, ) - for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + for chunk_coords, chunk_selection, out_selection, _is_complete in indexer ], out_buffer, drop_axes=indexer.drop_axes, @@ -5912,12 +5914,14 @@ async def _set_selection( # merging with existing data and encoding chunks await codec_pipeline.write( [ - ( - store_path / metadata.encode_chunk_key(chunk_coords), - metadata.get_chunk_spec(chunk_coords, _config, prototype), - chunk_selection, - out_selection, - is_complete_chunk, + WriteChunkRequest( + byte_setter=store_path / metadata.encode_chunk_key(chunk_coords), + transform=codec_pipeline.get_chunk_transform( + metadata.get_chunk_spec(chunk_coords, _config, prototype) + ), + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_chunk, ) for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer ], diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..b2b39e2963 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,8 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass +import os +import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -13,15 +16,16 @@ BytesBytesCodec, Codec, CodecPipeline, + SupportsSyncCodec, ) -from zarr.core.common import concurrent_map +from zarr.core.common import concurrent_map, product from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning from zarr.registry import register_pipeline if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Callable, Iterable, Iterator from typing import Self from zarr.abc.store import ByteGetter, ByteSetter @@ -68,6 +72,249 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value +# --------------------------------------------------------------------------- +# Thread-pool infrastructure for synchronous codec paths +# --------------------------------------------------------------------------- + +_MIN_CHUNK_NBYTES_FOR_POOL = 100_000 # 100 KB + + +def _get_codec_worker_config() -> tuple[bool, int, int]: + """Read the ``threading.codec_workers`` config. + + Returns + ------- + tuple[bool, int, int] + ``(enabled, min_workers, max_workers)`` + """ + codec_workers = config.get("threading.codec_workers") + enabled: bool = codec_workers.get("enabled", True) + min_workers: int = codec_workers.get("min", 0) + max_workers: int = max(codec_workers.get("max") or os.cpu_count() or 4, min_workers) + return enabled, min_workers, max_workers + + +def _choose_workers(n_chunks: int, chunk_nbytes: int, codecs: Iterable[Codec]) -> int: + """Decide how many thread-pool workers to use (0 = don't use pool).""" + if getattr(_thread_local, "in_pool_worker", False): + return 0 + + enabled, min_workers, max_workers = _get_codec_worker_config() + if not enabled: + return 0 + + if n_chunks < 2: + return min_workers + + if not any(isinstance(c, BytesBytesCodec) for c in codecs) and min_workers == 0: + return 0 + if chunk_nbytes < _MIN_CHUNK_NBYTES_FOR_POOL and min_workers == 0: + return 0 + + return max(min_workers, min(n_chunks, max_workers)) + + +def _get_pool() -> ThreadPoolExecutor: + """Get the module-level thread pool, creating it lazily.""" + global _pool + if _pool is None: + _, _, max_workers = _get_codec_worker_config() + _pool = ThreadPoolExecutor(max_workers=max_workers) + return _pool + + +_pool: ThreadPoolExecutor | None = None +_thread_local = threading.local() + + +def _mark_pool_worker(fn: Callable[..., T]) -> Callable[..., T]: + """Wrap *fn* so that ``_thread_local.in_pool_worker`` is ``True`` while it runs.""" + + def wrapper(*args: Any, **kwargs: Any) -> T: + _thread_local.in_pool_worker = True + try: + return fn(*args, **kwargs) + finally: + _thread_local.in_pool_worker = False + + return wrapper + + +_DELETED = object() + + +@dataclass(slots=True, kw_only=True) +class ChunkTransform: + """A stored chunk, modeled as a layered array. + + Each layer corresponds to one ArrayArrayCodec and the ArraySpec + at its input boundary. ``layers[0]`` is the outermost (user-visible) + transform; after the last layer comes the ArrayBytesCodec. + + The chunk's ``shape`` and ``dtype`` reflect the representation + **after** all ArrayArrayCodec layers have been applied — i.e. the + spec that feeds the ArrayBytesCodec. + """ + + codecs: tuple[Codec, ...] + array_spec: ArraySpec + + # Each element is (ArrayArrayCodec, input_spec_for_that_codec). + layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field( + init=False, repr=False, compare=False + ) + _ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False) + _ab_spec: ArraySpec = field(init=False, repr=False, compare=False) + _bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False) + _all_sync: bool = field(init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + aa, ab, bb = codecs_from_list(list(self.codecs)) + + layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = () + spec = self.array_spec + for aa_codec in aa: + layers = (*layers, (aa_codec, spec)) + spec = aa_codec.resolve_metadata(spec) + + self.layers = layers + self._ab_codec = ab + self._ab_spec = spec + self._bb_codecs = bb + self._all_sync = all(isinstance(c, SupportsSyncCodec) for c in self.codecs) + + @property + def shape(self) -> tuple[int, ...]: + """Shape after all ArrayArrayCodec layers (input to the ArrayBytesCodec).""" + return self._ab_spec.shape + + @property + def dtype(self) -> ZDType[TBaseDType, TBaseScalar]: + """Dtype after all ArrayArrayCodec layers (input to the ArrayBytesCodec).""" + return self._ab_spec.dtype + + @property + def all_sync(self) -> bool: + return self._all_sync + + def decode_chunk( + self, + chunk_bytes: Buffer, + ) -> NDBuffer: + """Decode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + """ + bb_out: Any = chunk_bytes + for bb_codec in reversed(self._bb_codecs): + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._ab_spec) + + ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec) + + for aa_codec, spec in reversed(self.layers): + ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + def encode_chunk( + self, + chunk_array: NDBuffer, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + """ + aa_out: Any = chunk_array + + for aa_codec, spec in self.layers: + if aa_out is None: + return None + aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) + + if aa_out is None: + return None + bb_out: Any = cast("SupportsSyncCodec", self._ab_codec)._encode_sync(aa_out, self._ab_spec) + + for bb_codec in self._bb_codecs: + if bb_out is None: + return None + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._ab_spec) + + return bb_out # type: ignore[no-any-return] + + async def decode_chunk_async( + self, + chunk_bytes: Buffer, + ) -> NDBuffer: + """Decode a single chunk through the full codec chain, asynchronously. + + Needed when the codec chain contains async-only codecs (e.g. nested sharding). + """ + bb_out: Any = chunk_bytes + for bb_codec in reversed(self._bb_codecs): + bb_out = await bb_codec._decode_single(bb_out, self._ab_spec) + + ab_out: Any = await self._ab_codec._decode_single(bb_out, self._ab_spec) + + for aa_codec, spec in reversed(self.layers): + ab_out = await aa_codec._decode_single(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + async def encode_chunk_async( + self, + chunk_array: NDBuffer, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, asynchronously. + + Needed when the codec chain contains async-only codecs (e.g. nested sharding). + """ + aa_out: Any = chunk_array + + for aa_codec, spec in self.layers: + if aa_out is None: + return None + aa_out = await aa_codec._encode_single(aa_out, spec) + + if aa_out is None: + return None + bb_out: Any = await self._ab_codec._encode_single(aa_out, self._ab_spec) + + for bb_codec in self._bb_codecs: + if bb_out is None: + return None + bb_out = await bb_codec._encode_single(bb_out, self._ab_spec) + + return bb_out # type: ignore[no-any-return] + + def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: + for codec in self.codecs: + byte_length = codec.compute_encoded_size(byte_length, array_spec) + array_spec = codec.resolve_metadata(array_spec) + return byte_length + + +@dataclass(slots=True) +class ReadChunkRequest: + """A request to read and decode a single chunk.""" + + byte_getter: ByteGetter + transform: ChunkTransform + chunk_selection: SelectorTuple + out_selection: SelectorTuple + + +@dataclass(slots=True) +class WriteChunkRequest: + """A request to encode and write a single chunk.""" + + byte_setter: ByteSetter + transform: ChunkTransform + chunk_selection: SelectorTuple + out_selection: SelectorTuple + is_complete_chunk: bool + + @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): """Default codec pipeline. @@ -133,6 +380,9 @@ def __iter__(self) -> Iterator[Codec]: yield self.array_bytes_codec yield from self.bytes_bytes_codecs + def get_chunk_transform(self, array_spec: ArraySpec) -> ChunkTransform: + return ChunkTransform(codecs=tuple(self), array_spec=array_spec) + def validate( self, *, @@ -248,48 +498,50 @@ async def encode_partial_batch( async def read_batch( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info = list(batch_info) if self.supports_partial_decode: chunk_array_batch = await self.decode_partial_batch( [ - (byte_getter, chunk_selection, chunk_spec) - for byte_getter, chunk_spec, chunk_selection, *_ in batch_info + (req.byte_getter, req.chunk_selection, req.transform.array_spec) + for req in batch_info ] ) - for chunk_array, (_, chunk_spec, _, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: - out[out_selection] = chunk_array + out[req.out_selection] = chunk_array else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) else: chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], + [(req.byte_getter, req.transform.array_spec.prototype) for req in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), config.get("async.concurrency"), ) - chunk_array_batch = await self.decode_batch( - [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) - ], + + async def _decode_one( + chunk_bytes: Buffer | None, req: ReadChunkRequest + ) -> NDBuffer | None: + if chunk_bytes is None: + return None + return await req.transform.decode_chunk_async(chunk_bytes) + + chunk_array_batch = await concurrent_map( + list(zip(chunk_bytes_batch, batch_info, strict=False)), + _decode_one, + config.get("async.concurrency"), ) - for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: - tmp = chunk_array[chunk_selection] + tmp = chunk_array[req.chunk_selection] if drop_axes != (): tmp = tmp.squeeze(axis=drop_axes) - out[out_selection] = tmp + out[req.out_selection] = tmp else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) def _merge_chunk_array( self, @@ -298,13 +550,11 @@ def _merge_chunk_array( out_selection: SelectorTuple, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, - is_complete_chunk: bool, drop_axes: tuple[int, ...], ) -> NDBuffer: if ( - is_complete_chunk + existing_chunk_array is None and value.shape == chunk_spec.shape - # Guard that this is not a partial chunk at the end with is_complete_chunk=True and value[out_selection].shape == chunk_spec.shape ): return value @@ -337,24 +587,35 @@ def _merge_chunk_array( async def write_batch( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info = list(batch_info) if self.supports_partial_encode: # Pass scalar values as is if len(value.shape) == 0: await self.encode_partial_batch( [ - (byte_setter, value, chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + ( + req.byte_setter, + value, + req.chunk_selection, + req.transform.array_spec, + ) + for req in batch_info ], ) else: await self.encode_partial_batch( [ - (byte_setter, value[out_selection], chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + ( + req.byte_setter, + value[req.out_selection], + req.chunk_selection, + req.transform.array_spec, + ) + for req in batch_info ], ) @@ -371,45 +632,42 @@ async def _read_key( chunk_bytes_batch = await concurrent_map( [ ( - None if is_complete_chunk else byte_setter, - chunk_spec.prototype, + None if req.is_complete_chunk else req.byte_setter, + req.transform.array_spec.prototype, ) - for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info + for req in batch_info ], _read_key, config.get("async.concurrency"), ) - chunk_array_decoded = await self.decode_batch( - [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) - ], + + async def _decode_one( + chunk_bytes: Buffer | None, req: WriteChunkRequest + ) -> NDBuffer | None: + if chunk_bytes is None: + return None + return await req.transform.decode_chunk_async(chunk_bytes) + + chunk_array_decoded = await concurrent_map( + list(zip(chunk_bytes_batch, batch_info, strict=False)), + _decode_one, + config.get("async.concurrency"), ) chunk_array_merged = [ self._merge_chunk_array( chunk_array, value, - out_selection, - chunk_spec, - chunk_selection, - is_complete_chunk, + req.out_selection, + req.transform.array_spec, + req.chunk_selection, drop_axes, ) - for chunk_array, ( - _, - chunk_spec, - chunk_selection, - out_selection, - is_complete_chunk, - ) in zip(chunk_array_decoded, batch_info, strict=False) + for chunk_array, req in zip(chunk_array_decoded, batch_info, strict=False) ] chunk_array_batch: list[NDBuffer | None] = [] - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_merged, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_merged, batch_info, strict=False): + chunk_spec = req.transform.array_spec if chunk_array is None: chunk_array_batch.append(None) # type: ignore[unreachable] else: @@ -420,13 +678,17 @@ async def _read_key( else: chunk_array_batch.append(chunk_array) - chunk_bytes_batch = await self.encode_batch( - [ - (chunk_array, chunk_spec) - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_batch, batch_info, strict=False - ) - ], + async def _encode_one( + chunk_array: NDBuffer | None, req: WriteChunkRequest + ) -> Buffer | None: + if chunk_array is None: + return None + return await req.transform.encode_chunk_async(chunk_array) + + chunk_bytes_batch = await concurrent_map( + list(zip(chunk_array_batch, batch_info, strict=False)), + _encode_one, + config.get("async.concurrency"), ) async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None: @@ -437,10 +699,8 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non await concurrent_map( [ - (byte_setter, chunk_bytes) - for chunk_bytes, (byte_setter, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (req.byte_setter, chunk_bytes) + for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) ], _write_key, config.get("async.concurrency"), @@ -466,7 +726,7 @@ async def encode( async def read( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -481,7 +741,7 @@ async def read( async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -494,6 +754,208 @@ async def write( config.get("async.concurrency"), ) + # ------------------------------------------------------------------ + # Synchronous read / write + # ------------------------------------------------------------------ + + @property + def supports_sync_io(self) -> bool: + return all(isinstance(c, SupportsSyncCodec) for c in self) + + def _scatter( + self, + chunk_arrays: list[NDBuffer | None], + batch_info_list: list[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...], + ) -> None: + """Assign decoded chunk arrays into the output buffer.""" + for chunk_array, req in zip(chunk_arrays, batch_info_list, strict=False): + if chunk_array is not None: + tmp = chunk_array[req.chunk_selection] + if drop_axes != (): + tmp = tmp.squeeze(axis=drop_axes) + out[req.out_selection] = tmp + else: + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) + + def read_sync( + self, + batch_info: Iterable[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + batch_info_list = list(batch_info) + if not batch_info_list: + return + + first_spec = batch_info_list[0].transform.array_spec + chunk_nbytes = product(first_spec.shape) * first_spec.dtype.to_native_dtype().itemsize + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + + if n_workers > 0: + # Threaded: fetch all bytes, decode in parallel, scatter. + chunk_bytes_list: list[Buffer | None] = [ + req.byte_getter.get_sync(prototype=req.transform.array_spec.prototype) # type: ignore[attr-defined] + for req in batch_info_list + ] + pool = _get_pool() + chunk_arrays: list[NDBuffer | None] = list( + pool.map( + _mark_pool_worker( + lambda cb, req: req.transform.decode_chunk(cb) if cb is not None else None + ), + chunk_bytes_list, + batch_info_list, + ) + ) + self._scatter(chunk_arrays, batch_info_list, out, drop_axes) + else: + # Non-threaded: prepare_read_sync handles IO + decode per chunk. + ab_codec = self.array_bytes_codec + for req in batch_info_list: + result = ab_codec.prepare_read_sync( + req.byte_getter, + req.chunk_selection, + req.transform, + ) + if result is not None: + if drop_axes != (): + result = result.squeeze(axis=drop_axes) + out[req.out_selection] = result + else: + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) + + @staticmethod + def _write_chunk_compute( + existing_bytes: Buffer | None, + req: WriteChunkRequest, + value: NDBuffer, + drop_axes: tuple[int, ...], + ) -> Buffer | None | object: + """Per-chunk compute for the threaded write path. + + Returns encoded bytes, ``None`` for a no-op, or ``_DELETED`` + to signal that the key should be removed. + """ + chunk_spec = req.transform.array_spec + existing_chunk_array: NDBuffer | None = None + if existing_bytes is not None: + existing_chunk_array = req.transform.decode_chunk(existing_bytes) + + # Merge + if ( + existing_chunk_array is None + and value.shape == chunk_spec.shape + and value[req.out_selection].shape == chunk_spec.shape + ): + merged = value + else: + if existing_chunk_array is None: + chunk_array = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=fill_value_or_default(chunk_spec), + ) + else: + chunk_array = existing_chunk_array.copy() + + if req.chunk_selection == () or is_scalar( + value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype() + ): + chunk_value = value + else: + chunk_value = value[req.out_selection] + if drop_axes != (): + item = tuple( + None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim) + ) + chunk_value = chunk_value[item] + chunk_array[req.chunk_selection] = chunk_value + merged = chunk_array + + # Check write_empty_chunks + if not chunk_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(chunk_spec) + ): + return _DELETED + + return req.transform.encode_chunk(merged) + + def write_sync( + self, + batch_info: Iterable[WriteChunkRequest], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + batch_info_list = list(batch_info) + if not batch_info_list: + return + + first_spec = batch_info_list[0].transform.array_spec + chunk_nbytes = product(first_spec.shape) * first_spec.dtype.to_native_dtype().itemsize + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + + if n_workers > 0: + # Threaded: fetch existing, compute in parallel, write results. + existing_bytes_list: list[Buffer | None] = [ + req.byte_setter.get_sync(prototype=req.transform.array_spec.prototype) # type: ignore[attr-defined] + if not req.is_complete_chunk + else None + for req in batch_info_list + ] + pool = _get_pool() + n = len(batch_info_list) + encoded_list: list[Buffer | None | object] = list( + pool.map( + _mark_pool_worker(self._write_chunk_compute), + existing_bytes_list, + batch_info_list, + [value] * n, + [drop_axes] * n, + ) + ) + for encoded, req in zip(encoded_list, batch_info_list, strict=False): + if encoded is _DELETED: + req.byte_setter.delete_sync() # type: ignore[attr-defined] + elif encoded is not None: + req.byte_setter.set_sync(encoded) # type: ignore[attr-defined] + else: + # Non-threaded: prepare/compute/finalize per chunk. + ab_codec = self.array_bytes_codec + for req in batch_info_list: + prepared = ab_codec.prepare_write_sync( + req.byte_setter, + req.transform, + req.chunk_selection, + req.out_selection, + req.is_complete_chunk, + ) + for coords, chunk_sel, out_sel, _is_complete in prepared.indexer: + existing_inner = prepared.chunk_dict.get(coords) + if existing_inner is not None: + existing_array = req.transform.decode_chunk(existing_inner) + else: + existing_array = None + merged = self._merge_chunk_array( + existing_array, + value, + out_sel, + req.transform.array_spec, + chunk_sel, + drop_axes, + ) + inner_spec = req.transform.array_spec + if not inner_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(inner_spec) + ): + prepared.chunk_dict[coords] = None + else: + prepared.chunk_dict[coords] = req.transform.encode_chunk(merged) + + ab_codec.finalize_write_sync(prepared, req.transform, req.byte_setter) + def codecs_from_list( codecs: Iterable[Codec], diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..e7d9e7ec20 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -99,7 +99,10 @@ def enable_gpu(self) -> ConfigSet: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 454f7e2290..27e05740d9 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -373,7 +373,6 @@ class ChunkDimProjection(NamedTuple): dim_chunk_ix: int dim_chunk_sel: Selector dim_out_sel: Selector | None - is_complete_chunk: bool @dataclass(frozen=True) @@ -393,8 +392,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: dim_offset = dim_chunk_ix * self.dim_chunk_len dim_chunk_sel = self.dim_sel - dim_offset dim_out_sel = None - is_complete_chunk = self.dim_chunk_len == 1 - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) @dataclass(frozen=True) @@ -468,10 +466,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: dim_out_sel = slice(dim_out_offset, dim_out_offset + dim_chunk_nitems) - is_complete_chunk = ( - dim_chunk_sel_start == 0 and (self.stop >= dim_limit) and self.step in [1, None] - ) - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) def check_selection_length(selection: SelectionNormalized, shape: tuple[int, ...]) -> None: @@ -531,6 +526,35 @@ def ensure_tuple(v: Any) -> SelectionNormalized: return cast("SelectionNormalized", v) +def _is_complete_chunk( + dim_indexers: Sequence[ + IntDimIndexer | SliceDimIndexer | BoolArrayDimIndexer | IntArrayDimIndexer + ], + dim_projections: tuple[ChunkDimProjection, ...], +) -> bool: + """Return True if the combined chunk selection covers the full actual extent. + + The actual extent of each chunk dimension accounts for edge chunks + (where the array length is not a multiple of the chunk length). + """ + for dim_indexer, p in zip(dim_indexers, dim_projections, strict=True): + chunk_ix = p.dim_chunk_ix + actual_extent = ( + min(dim_indexer.dim_len, (chunk_ix + 1) * dim_indexer.dim_chunk_len) + - chunk_ix * dim_indexer.dim_chunk_len + ) + s = p.dim_chunk_sel + if isinstance(s, slice): + if not (s.start in (0, None) and s.stop >= actual_extent and s.step in (1, None)): + return False + elif isinstance(dim_indexer, IntDimIndexer): + if actual_extent != 1: + return False + else: + return False + return True + + class ChunkProjection(NamedTuple): """A mapping of items from chunk to output array. Can be used to extract items from the chunk array for loading into an output array. Can also be used to extract items from a @@ -544,8 +568,8 @@ class ChunkProjection(NamedTuple): Selection of items from chunk array. out_selection Selection of items in target (output) array. - is_complete_chunk: - True if a complete chunk is indexed + is_complete_chunk + True if the selection covers the full actual extent of the chunk. """ chunk_coords: tuple[int, ...] @@ -627,8 +651,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: out_selection = tuple( p.dim_out_sel for p in dim_projections if p.dim_out_sel is not None ) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -696,9 +720,8 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: start = self.chunk_nitems_cumsum[dim_chunk_ix - 1] stop = self.chunk_nitems_cumsum[dim_chunk_ix] dim_out_sel = slice(start, stop) - is_complete_chunk = False # TODO - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) class Order(Enum): @@ -838,8 +861,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: # find region in chunk dim_offset = dim_chunk_ix * self.dim_chunk_len dim_chunk_sel = self.dim_sel[start:stop] - dim_offset - is_complete_chunk = False # TODO - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) def slice_to_range(s: slice, length: int) -> range: @@ -976,8 +998,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: if not is_basic_selection(out_selection): out_selection = ix_(out_selection, self.shape) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -1106,8 +1128,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: out_selection = tuple( p.dim_out_sel for p in dim_projections if p.dim_out_sel is not None ) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -1274,8 +1296,9 @@ def __iter__(self) -> Iterator[ChunkProjection]: for (dim_sel, dim_chunk_offset) in zip(self.selection, chunk_offsets, strict=True) ) - is_complete_chunk = False # TODO - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + yield ChunkProjection( + chunk_coords, chunk_selection, out_selection, is_complete_chunk=False + ) @dataclass(frozen=True) @@ -1512,54 +1535,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: out.flags.writeable = False return out - # Optimization: Remove singleton dimensions to enable magic number usage - # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand. - singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1) - if singleton_dims: - squeezed_shape = tuple(s for s in chunk_shape if s != 1) - if squeezed_shape: - # Compute Morton order on squeezed shape, then expand singleton dims (always 0) - squeezed_order = np.asarray(_morton_order(squeezed_shape)) - out = np.zeros((n_total, n_dims), dtype=np.intp) - squeezed_col = 0 - for full_col in range(n_dims): - if chunk_shape[full_col] != 1: - out[:, full_col] = squeezed_order[:, squeezed_col] - squeezed_col += 1 - else: - # All dimensions are singletons, just return the single point - out = np.zeros((1, n_dims), dtype=np.intp) - out.flags.writeable = False - return out - - # Find the largest power-of-2 hypercube that fits within chunk_shape. - # Within this hypercube, Morton codes are guaranteed to be in bounds. - min_dim = min(chunk_shape) - if min_dim >= 1: - power = min_dim.bit_length() - 1 # floor(log2(min_dim)) - hypercube_size = 1 << power # 2^power - n_hypercube = hypercube_size**n_dims + # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span + # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number + # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits + # is the size of this hypercube. + total_bits = sum((c - 1).bit_length() for c in chunk_shape) + n_z = 1 << total_bits if total_bits > 0 else 1 + + # Decode all Morton codes in the ceiling hypercube, then filter to valid coords. + # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33): + # n_z=262144, n_total=35937), consider the argsort strategy below. + order: npt.NDArray[np.intp] + if n_z <= 4 * n_total: + # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds. + # Works well when the overgeneration ratio n_z/n_total is small (≤4). + z_values = np.arange(n_z, dtype=np.intp) + all_coords = decode_morton_vectorized(z_values, chunk_shape) + shape_arr = np.array(chunk_shape, dtype=np.intp) + valid_mask = np.all(all_coords < shape_arr, axis=1) + order = all_coords[valid_mask] else: - n_hypercube = 0 + # Argsort strategy: enumerate all n_total valid coordinates directly, + # encode each to a Morton code, then sort by code. Avoids the 8x or + # larger overgeneration penalty for near-miss shapes like (33,33,33). + # Cost: O(n_total * bits) encode + O(n_total log n_total) sort, + # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling. + grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij") + all_coords = np.stack([g.ravel() for g in grids], axis=1) + + # Encode all coordinates to Morton codes (vectorized). + bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape) + max_coord_bits = max(bits_per_dim) + z_codes = np.zeros(n_total, dtype=np.intp) + output_bit = 0 + for coord_bit in range(max_coord_bits): + for dim in range(n_dims): + if coord_bit < bits_per_dim[dim]: + z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit + output_bit += 1 + + sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable") + order = np.asarray(all_coords[sort_idx], dtype=np.intp) - # Within the hypercube, no bounds checking needed - use vectorized decoding - if n_hypercube > 0: - z_values = np.arange(n_hypercube, dtype=np.intp) - order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape) - else: - order = np.empty((0, n_dims), dtype=np.intp) - - # For remaining elements outside the hypercube, bounds checking is needed - remaining: list[tuple[int, ...]] = [] - i = n_hypercube - while len(order) + len(remaining) < n_total: - m = decode_morton(i, chunk_shape) - if all(x < y for x, y in zip(m, chunk_shape, strict=False)): - remaining.append(m) - i += 1 - - if remaining: - order = np.vstack([order, np.array(remaining, dtype=np.intp)]) order.flags.writeable = False return order diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..c14aa1a37d 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,6 +228,33 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + # ------------------------------------------------------------------- + # Synchronous IO delegation + # ------------------------------------------------------------------- + + def get_sync( + self, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + """Synchronous read — delegates to ``self.store.get_sync(self.path, ...)``.""" + if prototype is None: + prototype = default_buffer_prototype() + return self.store.get_sync(self.path, prototype=prototype, byte_range=byte_range) # type: ignore[attr-defined, no-any-return] + + def set_sync(self, value: Buffer) -> None: + """Synchronous write — delegates to ``self.store.set_sync(self.path, value)``.""" + self.store.set_sync(self.path, value) # type: ignore[attr-defined] + + def set_range_sync(self, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self.store.set_range_sync(self.path, value, start) # type: ignore[attr-defined] + + def delete_sync(self) -> None: + """Synchronous delete — delegates to ``self.store.delete_sync(self.path)``.""" + self.store.delete_sync(self.path) # type: ignore[attr-defined] + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..dc05ff67a7 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -85,6 +85,19 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) +def _put_range(path: Path, value: Buffer, start: int) -> None: + view = value.as_buffer_like() + file_size = path.stat().st_size + if start + len(view) > file_size: + raise ValueError( + f"set_range would write beyond the end of the stored value: " + f"start={start}, len(value)={len(view)}, stored size={file_size}" + ) + with path.open("r+b") as f: + f.seek(start) + f.write(view) + + class LocalStore(Store): """ Store for the local file system. @@ -187,6 +200,62 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def _ensure_open_sync(self) -> None: + if not self._is_open: + if not self.read_only: + self.root.mkdir(parents=True, exist_ok=True) + if not self.root.exists(): + raise FileNotFoundError(f"{self.root} does not exist") + self._is_open = True + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + self._ensure_open_sync() + assert isinstance(key, str) + path = self.root / key + try: + return _get(path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._ensure_open_sync() + self._check_writable() + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"LocalStore.set(): `value` must be a Buffer instance. " + f"Got an instance of {type(value)} instead." + ) + path = self.root / key + _put(path, value) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + _put_range(path, value, start) + + def delete_sync(self, key: str) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + async def get( self, key: str, diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..168f3e8890 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -77,6 +77,49 @@ def __eq__(self, other: object) -> bool: and self.read_only == other.read_only ) + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + try: + value = self._store_dict[key] + start, stop = _normalize_byte_range_index(value, byte_range) + return prototype.buffer.from_buffer(value[start:stop]) + except KeyError: + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + self._store_dict[key] = value + + def delete_sync(self, key: str) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + try: + del self._store_dict[key] + except KeyError: + logger.debug("Key %s does not exist.", key) + async def get( self, key: str, @@ -122,7 +165,6 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None raise TypeError( f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) - if byte_range is not None: buf = self._store_dict[key] buf[byte_range[0] : byte_range[1]] = value @@ -130,6 +172,26 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None else: self._store_dict[key] = value + def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: + buf = self._store_dict[key] + target = buf.as_numpy_array() + if start + len(value) > len(target): + raise ValueError( + f"set_range would write beyond the end of the stored value: " + f"start={start}, len(value)={len(value)}, stored size={len(target)}" + ) + if not target.flags.writeable: + target = target.copy() + self._store_dict[key] = buf.__class__(target) + target[start : start + len(value)] = value.as_numpy_array() + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self._check_writable() + if not self._is_open: + self._is_open = True + self._set_range_impl(key, value, start) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1b8e85ed98..bb60cc371f 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -6,12 +6,13 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Generic, Self, TypeVar +import numpy as np + from zarr.storage import WrapperStore if TYPE_CHECKING: from typing import Any - from zarr.abc.store import ByteRequest from zarr.core.buffer.core import BufferPrototype import pytest @@ -22,6 +23,10 @@ RangeByteRequest, Store, SuffixByteRequest, + SupportsDeleteSync, + SupportsGetSync, + SupportsSetRangeSync, + SupportsSetSync, ) from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync @@ -39,6 +44,34 @@ class StoreTests(Generic[S, B]): store_cls: type[S] buffer_cls: type[B] + @staticmethod + def _require_get_sync(store: S) -> SupportsGetSync: + """Skip unless *store* implements :class:`SupportsGetSync`.""" + if not isinstance(store, SupportsGetSync): + pytest.skip("store does not implement SupportsGetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_set_sync(store: S) -> SupportsSetSync: + """Skip unless *store* implements :class:`SupportsSetSync`.""" + if not isinstance(store, SupportsSetSync): + pytest.skip("store does not implement SupportsSetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_set_range_sync(store: S) -> SupportsSetRangeSync: + """Skip unless *store* implements :class:`SupportsSetRangeSync`.""" + if not isinstance(store, SupportsSetRangeSync): + pytest.skip("store does not implement SupportsSetRangeSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_delete_sync(store: S) -> SupportsDeleteSync: + """Skip unless *store* implements :class:`SupportsDeleteSync`.""" + if not isinstance(store, SupportsDeleteSync): + pytest.skip("store does not implement SupportsDeleteSync") + return store # type: ignore[unreachable] + @abstractmethod async def set(self, store: S, key: str, value: Buffer) -> None: """ @@ -579,6 +612,107 @@ def test_get_json_sync(self, store: S) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) assert store._get_json_sync(key, prototype=default_buffer_prototype()) == data + # ------------------------------------------------------------------- + # Synchronous store methods (SupportsSyncStore protocol) + # ------------------------------------------------------------------- + + def test_get_sync(self, store: S) -> None: + getter = self._require_get_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_get" + sync(self.set(store, key, data_buf)) + result = getter.get_sync(key) + assert result is not None + assert_bytes_equal(result, data_buf) + + def test_get_sync_missing(self, store: S) -> None: + getter = self._require_get_sync(store) + result = getter.get_sync("nonexistent") + assert result is None + + def test_set_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_set" + setter.set_sync(key, data_buf) + result = sync(self.get(store, key)) + assert_bytes_equal(result, data_buf) + + def test_delete_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + deleter = self._require_delete_sync(store) + getter = self._require_get_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_delete" + setter.set_sync(key, data_buf) + deleter.delete_sync(key) + result = getter.get_sync(key) + assert result is None + + def test_delete_sync_missing(self, store: S) -> None: + deleter = self._require_delete_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + # should not raise + deleter.delete_sync("nonexistent_sync") + + # ------------------------------------------------------------------- + # set_range (sync only — set_range is exclusively a sync-path API) + # ------------------------------------------------------------------- + + def test_set_range_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + getter = self._require_get_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello world") + key = "range_sync_key" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"WORLD") + ranger.set_range_sync(key, patch, 6) + result = getter.get_sync(key) + assert result is not None + assert result.to_bytes() == b"hello WORLD" + + def test_set_range_sync_preserves_other_bytes(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + getter = self._require_get_sync(store) + data = np.arange(100, dtype="uint8") + data_buf = default_buffer_prototype().buffer.from_array_like(data) + key = "range_preserve" + setter.set_sync(key, data_buf) + patch = np.full(10, 255, dtype="uint8") + patch_buf = default_buffer_prototype().buffer.from_array_like(patch) + ranger.set_range_sync(key, patch_buf, 50) + result = getter.get_sync(key) + assert result is not None + result_arr = np.frombuffer(result.to_bytes(), dtype="uint8") + expected = data.copy() + expected[50:60] = 255 + np.testing.assert_array_equal(result_arr, expected) + + def test_set_range_sync_beyond_end_raises(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello") + key = "range_oob" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"world!") + with pytest.raises(ValueError, match="set_range would write beyond"): + ranger.set_range_sync(key, patch, 0) + + def test_set_range_sync_start_beyond_end_raises(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello") + key = "range_oob2" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"x") + with pytest.raises(ValueError, match="set_range would write beyond"): + ranger.set_range_sync(key, patch, 10) + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_codecs/test_blosc.py b/tests/test_codecs/test_blosc.py index 6f4821f8b1..0201beb8de 100644 --- a/tests/test_codecs/test_blosc.py +++ b/tests/test_codecs/test_blosc.py @@ -6,11 +6,12 @@ from packaging.version import Version import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.codecs import BloscCodec from zarr.codecs.blosc import BloscShuffle, Shuffle -from zarr.core.array_spec import ArraySpec +from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import default_buffer_prototype -from zarr.core.dtype import UInt16 +from zarr.core.dtype import UInt16, get_data_type_from_native_dtype from zarr.storage import MemoryStore, StorePath @@ -110,3 +111,27 @@ async def test_typesize() -> None: else: expected_size = 10216 assert size == expected_size, msg + + +def test_blosc_codec_supports_sync() -> None: + assert isinstance(BloscCodec(), SupportsSyncCodec) + + +def test_blosc_codec_sync_roundtrip() -> None: + codec = BloscCodec(typesize=8) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_crc32c.py b/tests/test_codecs/test_crc32c.py new file mode 100644 index 0000000000..3ab1070f60 --- /dev/null +++ b/tests/test_codecs/test_crc32c.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import numpy as np + +from zarr.abc.codec import SupportsSyncCodec +from zarr.codecs.crc32c_ import Crc32cCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype + + +def test_crc32c_codec_supports_sync() -> None: + assert isinstance(Crc32cCodec(), SupportsSyncCodec) + + +def test_crc32c_codec_sync_roundtrip() -> None: + codec = Crc32cCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_endian.py b/tests/test_codecs/test_endian.py index ab64afb1b8..c505cee828 100644 --- a/tests/test_codecs/test_endian.py +++ b/tests/test_codecs/test_endian.py @@ -4,8 +4,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import BytesCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -33,6 +37,31 @@ async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: assert np.array_equal(data, readback_data) +def test_bytes_codec_supports_sync() -> None: + assert isinstance(BytesCodec(), SupportsSyncCodec) + + +def test_bytes_codec_sync_roundtrip() -> None: + codec = BytesCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + codec = codec.evolve_from_array_spec(spec) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + @pytest.mark.filterwarnings("ignore:The endianness of the requested serializer") @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize("dtype_input_endian", [">u2", " None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_gzip_codec_supports_sync() -> None: + assert isinstance(GzipCodec(), SupportsSyncCodec) + + +def test_gzip_codec_sync_roundtrip() -> None: + codec = GzipCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_transpose.py b/tests/test_codecs/test_transpose.py index 06ec668ad3..949bb72a62 100644 --- a/tests/test_codecs/test_transpose.py +++ b/tests/test_codecs/test_transpose.py @@ -3,9 +3,13 @@ import zarr from zarr import AsyncArray, config +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import TransposeCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype from zarr.core.common import MemoryOrder +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -93,3 +97,27 @@ def test_transpose_invalid( chunk_key_encoding={"name": "v2", "separator": "."}, filters=[TransposeCodec(order=order)], # type: ignore[arg-type] ) + + +def test_transpose_codec_supports_sync() -> None: + assert isinstance(TransposeCodec(order=(0, 1)), SupportsSyncCodec) + + +def test_transpose_codec_sync_roundtrip() -> None: + codec = TransposeCodec(order=(1, 0)) + arr = np.arange(12, dtype="float64").reshape(3, 4) + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + resolved_spec = codec.resolve_metadata(spec) + decoded = codec._decode_sync(encoded, resolved_spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) diff --git a/tests/test_codecs/test_vlen.py b/tests/test_codecs/test_vlen.py index cf0905daca..f3445824b3 100644 --- a/tests/test_codecs/test_vlen.py +++ b/tests/test_codecs/test_vlen.py @@ -5,9 +5,10 @@ import zarr from zarr import Array -from zarr.abc.codec import Codec +from zarr.abc.codec import Codec, SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.core.dtype import get_data_type_from_native_dtype from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING from zarr.core.metadata.v3 import ArrayV3Metadata @@ -62,3 +63,11 @@ def test_vlen_string( assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == get_data_type_from_native_dtype(data.dtype) assert a.dtype == data.dtype + + +def test_vlen_utf8_codec_supports_sync() -> None: + assert isinstance(VLenUTF8Codec(), SupportsSyncCodec) + + +def test_vlen_bytes_codec_supports_sync() -> None: + assert isinstance(VLenBytesCodec(), SupportsSyncCodec) diff --git a/tests/test_codecs/test_zstd.py b/tests/test_codecs/test_zstd.py index 6068f53443..68297e4d94 100644 --- a/tests/test_codecs/test_zstd.py +++ b/tests/test_codecs/test_zstd.py @@ -2,8 +2,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath @@ -23,3 +27,31 @@ def test_zstd(store: Store, checksum: bool) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_zstd_codec_supports_sync() -> None: + assert isinstance(ZstdCodec(), SupportsSyncCodec) + + +def test_zstd_is_not_fixed_size() -> None: + assert ZstdCodec.is_fixed_size is False + + +def test_zstd_codec_sync_roundtrip() -> None: + codec = ZstdCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..7a1c9f3f4e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,7 +10,7 @@ import zarr from zarr import zeros from zarr.abc.codec import CodecPipeline -from zarr.abc.store import ByteSetter, Store +from zarr.abc.store import Store from zarr.codecs import ( BloscCodec, BytesCodec, @@ -20,9 +20,8 @@ from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.buffer.core import Buffer -from zarr.core.codec_pipeline import BatchedCodecPipeline +from zarr.core.codec_pipeline import BatchedCodecPipeline, WriteChunkRequest from zarr.core.config import BadConfigError, config -from zarr.core.indexing import SelectorTuple from zarr.errors import ZarrUserWarning from zarr.registry import ( fully_qualified_name, @@ -56,7 +55,10 @@ def test_config_defaults_set() -> None: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", @@ -140,7 +142,7 @@ def test_config_codec_pipeline_class(store: Store) -> None: class MockCodecPipeline(BatchedCodecPipeline): async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..9c734fb0c3 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator + from zarr.abc.store import ByteRequest from zarr.core.buffer import BufferPrototype from zarr.core.buffer.core import Buffer @@ -83,6 +84,22 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None self.counter["__setitem__", key_suffix] += 1 return await super().set(key, value, byte_range) + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__getitem__", key_suffix] += 1 + return super().get_sync(key, prototype=prototype, byte_range=byte_range) + + def set_sync(self, key: str, value: Buffer) -> None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__setitem__", key_suffix] += 1 + return super().set_sync(key, value) + def test_normalize_integer_selection() -> None: assert 1 == normalize_integer_selection(1, 100) diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py new file mode 100644 index 0000000000..d241adee49 --- /dev/null +++ b/tests/test_sync_codec_pipeline.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.gzip import GzipCodec +from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.zstd import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.codec_pipeline import ( + BatchedCodecPipeline, + ChunkTransform, + ReadChunkRequest, + WriteChunkRequest, + _choose_workers, +) +from zarr.core.dtype import get_data_type_from_native_dtype +from zarr.storage._common import StorePath +from zarr.storage._memory import MemoryStore + + +def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec: + zdtype = get_data_type_from_native_dtype(dtype) + return ArraySpec( + shape=shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + + +def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer: + return default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + +class TestChunkTransform: + def test_all_sync(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) + assert chain.all_sync is True + + def test_all_sync_with_compression(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec) + assert chain.all_sync is True + + def test_all_sync_full_chain(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + chain = ChunkTransform( + codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), array_spec=spec + ) + assert chain.all_sync is True + + def test_encode_decode_roundtrip_bytes_only(self) -> None: + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_layers_no_aa_codecs(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chunk = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec) + assert chunk.layers == () + + def test_layers_with_transpose(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + transpose = TransposeCodec(order=(1, 0)) + chunk = ChunkTransform(codecs=(transpose, BytesCodec(), ZstdCodec()), array_spec=spec) + assert len(chunk.layers) == 1 + assert chunk.layers[0][0] is transpose + assert chunk.layers[0][1] is spec + + def test_shape_dtype_no_aa_codecs(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chunk = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) + assert chunk.shape == (100,) + assert chunk.dtype == spec.dtype + + def test_shape_dtype_with_transpose(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + chunk = ChunkTransform(codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=spec) + # After transpose (1,0), shape (3,4) becomes (4,3) + assert chunk.shape == (4, 3) + assert chunk.dtype == spec.dtype + + def test_encode_decode_roundtrip_with_compression(self) -> None: + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec(level=1)), array_spec=spec) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_encode_decode_roundtrip_with_transpose(self) -> None: + arr = np.arange(12, dtype="float64").reshape(3, 4) + spec = _make_array_spec(arr.shape, arr.dtype) + chain = ChunkTransform( + codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), + array_spec=spec, + ) + nd_buf = _make_nd_buffer(arr) + + encoded = chain.encode_chunk(nd_buf) + assert encoded is not None + decoded = chain.decode_chunk(encoded) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + +# --------------------------------------------------------------------------- +# Helpers for sync pipeline tests +# --------------------------------------------------------------------------- + + +def _make_pipeline( + codecs: tuple[Any, ...], +) -> BatchedCodecPipeline: + return BatchedCodecPipeline.from_codecs(codecs) + + +def _make_store_path(key: str = "chunk/0") -> StorePath: + store = MemoryStore() + return StorePath(store, key) + + +# --------------------------------------------------------------------------- +# Sync pipeline tests +# --------------------------------------------------------------------------- + + +class TestSyncPipeline: + def test_supports_sync_io(self) -> None: + pipeline = _make_pipeline((BytesCodec(),)) + assert pipeline.supports_sync_io is True + + def test_supports_sync_io_with_compression(self) -> None: + pipeline = _make_pipeline((BytesCodec(), ZstdCodec())) + assert pipeline.supports_sync_io is True + + def test_write_sync_read_sync_roundtrip(self) -> None: + """Write data via write_sync, read it back via read_sync.""" + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(arr) + + # Write + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + is_complete_chunk=True, + ) + ], + value, + ) + + # Read + out = default_buffer_prototype().nd_buffer.create( + shape=arr.shape, + dtype=arr.dtype, + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + np.testing.assert_array_equal(arr, out.as_numpy_array()) + + def test_write_sync_read_sync_with_compression(self) -> None: + """Round-trip with a compression codec.""" + arr = np.arange(200, dtype="float32").reshape(10, 20) + spec = _make_array_spec(arr.shape, arr.dtype) + pipeline = _make_pipeline((BytesCodec(), GzipCodec(level=1))) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(arr) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(None), slice(None)), + out_selection=(slice(None), slice(None)), + is_complete_chunk=True, + ) + ], + value, + ) + + out = default_buffer_prototype().nd_buffer.create( + shape=arr.shape, + dtype=arr.dtype, + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None), slice(None)), + out_selection=(slice(None), slice(None)), + ) + ], + out, + ) + + np.testing.assert_array_equal(arr, out.as_numpy_array()) + + def test_write_sync_partial_chunk(self) -> None: + """Write a partial chunk (is_complete_chunk=False), read back full chunk.""" + shape = (10,) + spec = _make_array_spec(shape, np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + # Write only first 5 elements + value = _make_nd_buffer(np.arange(5, dtype="float64")) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(0, 5),), + out_selection=(slice(None),), + is_complete_chunk=False, + ) + ], + value, + ) + + # Read back full chunk + out = default_buffer_prototype().nd_buffer.create( + shape=shape, + dtype=np.dtype("float64"), + order="C", + fill_value=-1, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + result = out.as_numpy_array() + np.testing.assert_array_equal(result[:5], np.arange(5, dtype="float64")) + # Remaining elements should be fill value (0) + np.testing.assert_array_equal(result[5:], 0) + + def test_read_sync_missing_chunk(self) -> None: + """Reading a non-existent chunk should fill with fill value.""" + spec = _make_array_spec((10,), np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path("nonexistent/chunk") + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + out = default_buffer_prototype().nd_buffer.create( + shape=(10,), + dtype=np.dtype("float64"), + order="C", + fill_value=-1, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + # Should be filled with the spec's fill value (0) + np.testing.assert_array_equal(out.as_numpy_array(), 0) + + def test_write_sync_multiple_chunks(self) -> None: + """Write and read multiple chunks in one batch.""" + spec = _make_array_spec((10,), np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store = MemoryStore() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(np.arange(20, dtype="float64")) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=StorePath(store, "c/0"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(0, 10),), + is_complete_chunk=True, + ), + WriteChunkRequest( + byte_setter=StorePath(store, "c/1"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(10, 20),), + is_complete_chunk=True, + ), + ], + value, + ) + + out = default_buffer_prototype().nd_buffer.create( + shape=(20,), + dtype=np.dtype("float64"), + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=StorePath(store, "c/0"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(0, 10),), + ), + ReadChunkRequest( + byte_getter=StorePath(store, "c/1"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(10, 20),), + ), + ], + out, + ) + + np.testing.assert_array_equal(out.as_numpy_array(), np.arange(20, dtype="float64")) + + +class TestChooseWorkers: + def test_returns_zero_for_single_chunk(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + assert _choose_workers(1, 1_000_000, codecs) == 0 + + def test_returns_nonzero_for_large_compressed_batch(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + n_workers = _choose_workers(10, 1_000_000, codecs) + assert n_workers > 0 + + def test_returns_zero_without_bb_codecs(self) -> None: + codecs: tuple[Any, ...] = (BytesCodec(),) + assert _choose_workers(10, 1_000_000, codecs) == 0 + + def test_returns_zero_for_small_chunks(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + assert _choose_workers(10, 100, codecs) == 0 + + @pytest.mark.parametrize("enabled", [True, False]) + def test_config_enabled(self, enabled: bool) -> None: + from zarr.core.config import config + + with config.set({"threading.codec_workers.enabled": enabled}): + codecs = (BytesCodec(), ZstdCodec()) + result = _choose_workers(10, 1_000_000, codecs) + if enabled: + assert result > 0 + else: + assert result == 0