From 2b64daa6aeadb73925aaea9d5c80ab28f23b9141 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 25 Feb 2026 18:33:03 -0500 Subject: [PATCH 01/14] add sync methods to codecs --- src/zarr/abc/codec.py | 20 ++++++++++++++++++- src/zarr/codecs/blosc.py | 27 ++++++++++++++++---------- src/zarr/codecs/bytes.py | 18 ++++++++++++++++-- src/zarr/codecs/crc32c_.py | 18 ++++++++++++++++-- src/zarr/codecs/gzip.py | 27 ++++++++++++++++++++------ src/zarr/codecs/transpose.py | 20 ++++++++++++++++--- src/zarr/codecs/vlen_utf8.py | 37 +++++++++++++++++++++++++++++++----- src/zarr/codecs/zstd.py | 24 ++++++++++++++++------- 8 files changed, 155 insertions(+), 36 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d41c457b4e..3ec5ec522b 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -2,7 +2,7 @@ from abc import abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar +from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable from typing_extensions import ReadOnly, TypedDict @@ -32,6 +32,7 @@ "CodecInput", "CodecOutput", "CodecPipeline", + "SupportsSyncCodec", ] CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer) @@ -59,6 +60,23 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]: """The widest type of JSON-like input that could specify a codec.""" +@runtime_checkable +class SupportsSyncCodec(Protocol): + """Protocol for codecs that support synchronous encode/decode. + + Codecs implementing this protocol provide ``_decode_sync`` and ``_encode_sync`` + methods that perform encoding/decoding without requiring an async event loop. + """ + + def _decode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer: ... + + def _encode_sync( + self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec + ) -> NDBuffer | Buffer | None: ... + + class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): """Generic base class for codecs. diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py index 5b91cfa005..d05731d640 100644 --- a/src/zarr/codecs/blosc.py +++ b/src/zarr/codecs/blosc.py @@ -299,13 +299,27 @@ def _blosc_codec(self) -> Blosc: config_dict["typesize"] = self.typesize return Blosc.from_config(config_dict) + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return chunk_spec.prototype.buffer.from_bytes( + self._blosc_codec.encode(chunk_bytes.as_numpy_array()) ) async def _encode_single( @@ -313,14 +327,7 @@ async def _encode_single( chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - # Since blosc only support host memory, we convert the input and output of the encoding - # between numpy array and buffer - return await asyncio.to_thread( - lambda chunk: chunk_spec.prototype.buffer.from_bytes( - self._blosc_codec.encode(chunk.as_numpy_array()) - ), - chunk_bytes, - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py index 1fbdeef497..86bb354fb5 100644 --- a/src/zarr/codecs/bytes.py +++ b/src/zarr/codecs/bytes.py @@ -65,7 +65,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: ) return self - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -88,7 +88,14 @@ async def _decode_single( ) return chunk_array - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -109,5 +116,12 @@ async def _encode_single( nd_array = nd_array.ravel().view(dtype="B") return chunk_spec.prototype.buffer.from_array_like(nd_array) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py index 9536d0d558..ebe2ac8f7a 100644 --- a/src/zarr/codecs/crc32c_.py +++ b/src/zarr/codecs/crc32c_.py @@ -31,7 +31,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "crc32c"} - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -51,7 +51,14 @@ async def _decode_single( ) return chunk_spec.prototype.buffer.from_array_like(inner_bytes) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -64,5 +71,12 @@ async def _encode_single( # Append the checksum (as bytes) to the data return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("B"))) + async def _encode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_bytes, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length + 4 diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py index 610ca9dadd..b8591748f7 100644 --- a/src/zarr/codecs/gzip.py +++ b/src/zarr/codecs/gzip.py @@ -2,6 +2,7 @@ import asyncio from dataclasses import dataclass +from functools import cached_property from typing import TYPE_CHECKING from numcodecs.gzip import GZip @@ -48,23 +49,37 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "gzip", "configuration": {"level": self.level}} + @cached_property + def _gzip_codec(self) -> GZip: + return GZip(self.level) + + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._gzip_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return as_numpy_array_wrapper(self._gzip_codec.encode, chunk_bytes, chunk_spec.prototype) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size( self, diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py index a8570b6e8f..609448a59c 100644 --- a/src/zarr/codecs/transpose.py +++ b/src/zarr/codecs/transpose.py @@ -95,20 +95,34 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: prototype=chunk_spec.prototype, ) - async def _decode_single( + def _decode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, ) -> NDBuffer: - inverse_order = np.argsort(self.order) + inverse_order = tuple(int(i) for i in np.argsort(self.order)) return chunk_array.transpose(inverse_order) - async def _encode_single( + async def _decode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_array, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, _chunk_spec: ArraySpec, ) -> NDBuffer | None: return chunk_array.transpose(self.order) + async def _encode_single( + self, + chunk_array: NDBuffer, + _chunk_spec: ArraySpec, + ) -> NDBuffer | None: + return self._encode_sync(chunk_array, _chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: return input_byte_length diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index fb1fb76126..a10cb7c335 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -40,8 +40,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - # TODO: expand the tests for this function - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -55,7 +54,14 @@ async def _decode_single( as_string_dtype = decoded.astype(chunk_spec.dtype.to_native_dtype(), copy=False) return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -65,6 +71,13 @@ async def _encode_single( _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) ) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") @@ -86,7 +99,7 @@ def to_dict(self) -> dict[str, JSON]: def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self - async def _decode_single( + def _decode_sync( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, @@ -99,7 +112,14 @@ async def _decode_single( decoded = _reshape_view(decoded, chunk_spec.shape) return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) - async def _encode_single( + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + return self._decode_sync(chunk_bytes, chunk_spec) + + def _encode_sync( self, chunk_array: NDBuffer, chunk_spec: ArraySpec, @@ -109,6 +129,13 @@ async def _encode_single( _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) ) + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return self._encode_sync(chunk_array, chunk_spec) + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: # what is input_byte_length for an object dtype? raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index 27cc9a7777..f93c25a3c7 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -38,7 +38,7 @@ def parse_checksum(data: JSON) -> bool: class ZstdCodec(BytesBytesCodec): """zstd codec""" - is_fixed_size = True + is_fixed_size = False level: int = 0 checksum: bool = False @@ -71,23 +71,33 @@ def _zstd_codec(self) -> Zstd: config_dict = {"level": self.level, "checksum": self.checksum} return Zstd.from_config(config_dict) + def _decode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer: + return as_numpy_array_wrapper(self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype) + async def _decode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._decode_sync, chunk_bytes, chunk_spec) + + def _encode_sync( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + return as_numpy_array_wrapper(self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype) async def _encode_single( self, chunk_bytes: Buffer, chunk_spec: ArraySpec, ) -> Buffer | None: - return await asyncio.to_thread( - as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype - ) + return await asyncio.to_thread(self._encode_sync, chunk_bytes, chunk_spec) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: raise NotImplementedError From cd4efb0a1ed611d083e650b87e3de44cc39f5a46 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 25 Feb 2026 19:09:51 -0500 Subject: [PATCH 02/14] add CodecChain dataclass and sync codec tests Introduces CodecChain, a frozen dataclass that chains array-array, array-bytes, and bytes-bytes codecs with synchronous encode/decode methods. Pure compute only -- no IO, no threading, no batching. Also adds sync roundtrip tests for individual codecs (blosc, gzip, zstd, crc32c, bytes, transpose, vlen) and CodecChain integration tests. Co-Authored-By: Claude Opus 4.6 --- src/zarr/core/codec_pipeline.py | 133 +++++++++++++++++++++++- tests/test_codecs/test_blosc.py | 29 +++++- tests/test_codecs/test_crc32c.py | 33 ++++++ tests/test_codecs/test_endian.py | 29 ++++++ tests/test_codecs/test_gzip.py | 28 +++++ tests/test_codecs/test_transpose.py | 28 +++++ tests/test_codecs/test_vlen.py | 11 +- tests/test_codecs/test_zstd.py | 32 ++++++ tests/test_sync_codec_pipeline.py | 156 ++++++++++++++++++++++++++++ 9 files changed, 474 insertions(+), 5 deletions(-) create mode 100644 tests/test_codecs/test_crc32c.py create mode 100644 tests/test_sync_codec_pipeline.py diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index fd557ac43e..7425ec30f6 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,8 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import islice, pairwise -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from warnings import warn from zarr.abc.codec import ( @@ -13,6 +13,7 @@ BytesBytesCodec, Codec, CodecPipeline, + SupportsSyncCodec, ) from zarr.core.common import concurrent_map from zarr.core.config import config @@ -68,6 +69,134 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value +@dataclass(frozen=True) +class CodecChain: + """Lightweight codec chain: array-array -> array-bytes -> bytes-bytes. + + Pure compute only -- no IO methods, no threading, no batching. + """ + + array_array_codecs: tuple[ArrayArrayCodec, ...] + array_bytes_codec: ArrayBytesCodec + bytes_bytes_codecs: tuple[BytesBytesCodec, ...] + + _all_sync: bool = field(default=False, init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + object.__setattr__( + self, + "_all_sync", + all(isinstance(c, SupportsSyncCodec) for c in self), + ) + + def __iter__(self) -> Iterator[Codec]: + yield from self.array_array_codecs + yield self.array_bytes_codec + yield from self.bytes_bytes_codecs + + @classmethod + def from_codecs(cls, codecs: Iterable[Codec]) -> CodecChain: + aa, ab, bb = codecs_from_list(list(codecs)) + return cls(array_array_codecs=aa, array_bytes_codec=ab, bytes_bytes_codecs=bb) + + def resolve_metadata_chain( + self, chunk_spec: ArraySpec + ) -> tuple[ + list[tuple[ArrayArrayCodec, ArraySpec]], + tuple[ArrayBytesCodec, ArraySpec], + list[tuple[BytesBytesCodec, ArraySpec]], + ]: + """Resolve metadata through the codec chain for a single chunk_spec.""" + aa_codecs_with_spec: list[tuple[ArrayArrayCodec, ArraySpec]] = [] + spec = chunk_spec + for aa_codec in self.array_array_codecs: + aa_codecs_with_spec.append((aa_codec, spec)) + spec = aa_codec.resolve_metadata(spec) + + ab_codec_with_spec = (self.array_bytes_codec, spec) + spec = self.array_bytes_codec.resolve_metadata(spec) + + bb_codecs_with_spec: list[tuple[BytesBytesCodec, ArraySpec]] = [] + for bb_codec in self.bytes_bytes_codecs: + bb_codecs_with_spec.append((bb_codec, spec)) + spec = bb_codec.resolve_metadata(spec) + + return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec) + + def decode_chunk( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + aa_chain: Iterable[tuple[ArrayArrayCodec, ArraySpec]] | None = None, + ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None, + bb_chain: Iterable[tuple[BytesBytesCodec, ArraySpec]] | None = None, + ) -> NDBuffer: + """Decode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + + The optional ``aa_chain``, ``ab_pair``, ``bb_chain`` parameters allow + pre-resolved metadata to be reused across many chunks with the same spec. + If not provided, ``resolve_metadata_chain`` is called internally. + """ + if aa_chain is None or ab_pair is None or bb_chain is None: + aa_chain, ab_pair, bb_chain = self.resolve_metadata_chain(chunk_spec) + + bb_out: Any = chunk_bytes + for bb_codec, spec in reversed(list(bb_chain)): + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, spec) + + ab_codec, ab_spec = ab_pair + ab_out: Any = cast("SupportsSyncCodec", ab_codec)._decode_sync(bb_out, ab_spec) + + for aa_codec, spec in reversed(list(aa_chain)): + ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + def encode_chunk( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, synchronously. + + Pure compute -- no IO. Only callable when all codecs support sync. + """ + spec = chunk_spec + aa_out: Any = chunk_array + + for aa_codec in self.array_array_codecs: + if aa_out is None: + return None + aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) + spec = aa_codec.resolve_metadata(spec) + + if aa_out is None: + return None + bb_out: Any = cast("SupportsSyncCodec", self.array_bytes_codec)._encode_sync(aa_out, spec) + spec = self.array_bytes_codec.resolve_metadata(spec) + + for bb_codec in self.bytes_bytes_codecs: + if bb_out is None: + return None + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, spec) + spec = bb_codec.resolve_metadata(spec) + + return bb_out # type: ignore[no-any-return] + + def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: + for codec in self: + byte_length = codec.compute_encoded_size(byte_length, array_spec) + array_spec = codec.resolve_metadata(array_spec) + return byte_length + + def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: + for codec in self: + chunk_spec = codec.resolve_metadata(chunk_spec) + return chunk_spec + + @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): """Default codec pipeline. diff --git a/tests/test_codecs/test_blosc.py b/tests/test_codecs/test_blosc.py index 6f4821f8b1..0201beb8de 100644 --- a/tests/test_codecs/test_blosc.py +++ b/tests/test_codecs/test_blosc.py @@ -6,11 +6,12 @@ from packaging.version import Version import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.codecs import BloscCodec from zarr.codecs.blosc import BloscShuffle, Shuffle -from zarr.core.array_spec import ArraySpec +from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import default_buffer_prototype -from zarr.core.dtype import UInt16 +from zarr.core.dtype import UInt16, get_data_type_from_native_dtype from zarr.storage import MemoryStore, StorePath @@ -110,3 +111,27 @@ async def test_typesize() -> None: else: expected_size = 10216 assert size == expected_size, msg + + +def test_blosc_codec_supports_sync() -> None: + assert isinstance(BloscCodec(), SupportsSyncCodec) + + +def test_blosc_codec_sync_roundtrip() -> None: + codec = BloscCodec(typesize=8) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_crc32c.py b/tests/test_codecs/test_crc32c.py new file mode 100644 index 0000000000..3ab1070f60 --- /dev/null +++ b/tests/test_codecs/test_crc32c.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import numpy as np + +from zarr.abc.codec import SupportsSyncCodec +from zarr.codecs.crc32c_ import Crc32cCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype + + +def test_crc32c_codec_supports_sync() -> None: + assert isinstance(Crc32cCodec(), SupportsSyncCodec) + + +def test_crc32c_codec_sync_roundtrip() -> None: + codec = Crc32cCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_endian.py b/tests/test_codecs/test_endian.py index ab64afb1b8..c505cee828 100644 --- a/tests/test_codecs/test_endian.py +++ b/tests/test_codecs/test_endian.py @@ -4,8 +4,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import BytesCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -33,6 +37,31 @@ async def test_endian(store: Store, endian: Literal["big", "little"]) -> None: assert np.array_equal(data, readback_data) +def test_bytes_codec_supports_sync() -> None: + assert isinstance(BytesCodec(), SupportsSyncCodec) + + +def test_bytes_codec_sync_roundtrip() -> None: + codec = BytesCodec() + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + codec = codec.evolve_from_array_spec(spec) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + @pytest.mark.filterwarnings("ignore:The endianness of the requested serializer") @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) @pytest.mark.parametrize("dtype_input_endian", [">u2", " None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_gzip_codec_supports_sync() -> None: + assert isinstance(GzipCodec(), SupportsSyncCodec) + + +def test_gzip_codec_sync_roundtrip() -> None: + codec = GzipCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_codecs/test_transpose.py b/tests/test_codecs/test_transpose.py index 06ec668ad3..949bb72a62 100644 --- a/tests/test_codecs/test_transpose.py +++ b/tests/test_codecs/test_transpose.py @@ -3,9 +3,13 @@ import zarr from zarr import AsyncArray, config +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import TransposeCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype from zarr.core.common import MemoryOrder +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath from .test_codecs import _AsyncArrayProxy @@ -93,3 +97,27 @@ def test_transpose_invalid( chunk_key_encoding={"name": "v2", "separator": "."}, filters=[TransposeCodec(order=order)], # type: ignore[arg-type] ) + + +def test_transpose_codec_supports_sync() -> None: + assert isinstance(TransposeCodec(order=(0, 1)), SupportsSyncCodec) + + +def test_transpose_codec_sync_roundtrip() -> None: + codec = TransposeCodec(order=(1, 0)) + arr = np.arange(12, dtype="float64").reshape(3, 4) + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + nd_buf: NDBuffer = default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + encoded = codec._encode_sync(nd_buf, spec) + assert encoded is not None + resolved_spec = codec.resolve_metadata(spec) + decoded = codec._decode_sync(encoded, resolved_spec) + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) diff --git a/tests/test_codecs/test_vlen.py b/tests/test_codecs/test_vlen.py index cf0905daca..f3445824b3 100644 --- a/tests/test_codecs/test_vlen.py +++ b/tests/test_codecs/test_vlen.py @@ -5,9 +5,10 @@ import zarr from zarr import Array -from zarr.abc.codec import Codec +from zarr.abc.codec import Codec, SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.core.dtype import get_data_type_from_native_dtype from zarr.core.dtype.npy.string import _NUMPY_SUPPORTS_VLEN_STRING from zarr.core.metadata.v3 import ArrayV3Metadata @@ -62,3 +63,11 @@ def test_vlen_string( assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == get_data_type_from_native_dtype(data.dtype) assert a.dtype == data.dtype + + +def test_vlen_utf8_codec_supports_sync() -> None: + assert isinstance(VLenUTF8Codec(), SupportsSyncCodec) + + +def test_vlen_bytes_codec_supports_sync() -> None: + assert isinstance(VLenBytesCodec(), SupportsSyncCodec) diff --git a/tests/test_codecs/test_zstd.py b/tests/test_codecs/test_zstd.py index 6068f53443..68297e4d94 100644 --- a/tests/test_codecs/test_zstd.py +++ b/tests/test_codecs/test_zstd.py @@ -2,8 +2,12 @@ import pytest import zarr +from zarr.abc.codec import SupportsSyncCodec from zarr.abc.store import Store from zarr.codecs import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype from zarr.storage import StorePath @@ -23,3 +27,31 @@ def test_zstd(store: Store, checksum: bool) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) + + +def test_zstd_codec_supports_sync() -> None: + assert isinstance(ZstdCodec(), SupportsSyncCodec) + + +def test_zstd_is_not_fixed_size() -> None: + assert ZstdCodec.is_fixed_size is False + + +def test_zstd_codec_sync_roundtrip() -> None: + codec = ZstdCodec(level=1) + arr = np.arange(100, dtype="float64") + zdtype = get_data_type_from_native_dtype(arr.dtype) + spec = ArraySpec( + shape=arr.shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + buf = default_buffer_prototype().buffer.from_array_like(arr.view("B")) + + encoded = codec._encode_sync(buf, spec) + assert encoded is not None + decoded = codec._decode_sync(encoded, spec) + result = np.frombuffer(decoded.as_numpy_array(), dtype="float64") + np.testing.assert_array_equal(arr, result) diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py new file mode 100644 index 0000000000..23fa28cb04 --- /dev/null +++ b/tests/test_sync_codec_pipeline.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import pytest + +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.gzip import GzipCodec +from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.zstd import ZstdCodec +from zarr.core.array_spec import ArrayConfig, ArraySpec +from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.dtype import get_data_type_from_native_dtype + +if TYPE_CHECKING: + from zarr.abc.codec import Codec + + +def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec: + zdtype = get_data_type_from_native_dtype(dtype) + return ArraySpec( + shape=shape, + dtype=zdtype, + fill_value=zdtype.cast_scalar(0), + config=ArrayConfig(order="C", write_empty_chunks=True), + prototype=default_buffer_prototype(), + ) + + +def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer: + return default_buffer_prototype().nd_buffer.from_numpy_array(arr) + + +class TestCodecChain: + def test_from_codecs_bytes_only(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([BytesCodec()]) + assert chain.array_array_codecs == () + assert isinstance(chain.array_bytes_codec, BytesCodec) + assert chain.bytes_bytes_codecs == () + assert chain._all_sync is True + + def test_from_codecs_with_compression(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([BytesCodec(), GzipCodec()]) + assert isinstance(chain.array_bytes_codec, BytesCodec) + assert len(chain.bytes_bytes_codecs) == 1 + assert isinstance(chain.bytes_bytes_codecs[0], GzipCodec) + assert chain._all_sync is True + + def test_from_codecs_with_transpose(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()]) + assert len(chain.array_array_codecs) == 1 + assert isinstance(chain.array_array_codecs[0], TransposeCodec) + assert isinstance(chain.array_bytes_codec, BytesCodec) + assert chain._all_sync is True + + def test_from_codecs_full_chain(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()]) + assert len(chain.array_array_codecs) == 1 + assert isinstance(chain.array_bytes_codec, BytesCodec) + assert len(chain.bytes_bytes_codecs) == 1 + assert chain._all_sync is True + + def test_iter(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + codecs: list[Codec] = [TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()] + chain = CodecChain.from_codecs(codecs) + assert list(chain) == codecs + + def test_frozen(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([BytesCodec()]) + with pytest.raises(AttributeError): + chain.array_bytes_codec = BytesCodec() # type: ignore[misc] + + def test_encode_decode_roundtrip_bytes_only(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([BytesCodec()]) + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + nd_buf = _make_nd_buffer(arr) + + encoded = chain_evolved.encode_chunk(nd_buf, spec) + assert encoded is not None + decoded = chain_evolved.decode_chunk(encoded, spec) + assert decoded is not None + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_encode_decode_roundtrip_with_compression(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([BytesCodec(), GzipCodec(level=1)]) + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + nd_buf = _make_nd_buffer(arr) + + encoded = chain_evolved.encode_chunk(nd_buf, spec) + assert encoded is not None + decoded = chain_evolved.decode_chunk(encoded, spec) + assert decoded is not None + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_encode_decode_roundtrip_with_transpose(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs( + [TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)] + ) + arr = np.arange(12, dtype="float64").reshape(3, 4) + spec = _make_array_spec(arr.shape, arr.dtype) + chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + nd_buf = _make_nd_buffer(arr) + + encoded = chain_evolved.encode_chunk(nd_buf, spec) + assert encoded is not None + decoded = chain_evolved.decode_chunk(encoded, spec) + assert decoded is not None + np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + def test_resolve_metadata_chain(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()]) + arr = np.zeros((3, 4), dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + + aa_chain, ab_pair, bb_chain = chain_evolved.resolve_metadata_chain(spec) + assert len(aa_chain) == 1 + assert aa_chain[0][1].shape == (3, 4) # spec before transpose + _ab_codec, ab_spec = ab_pair + assert ab_spec.shape == (4, 3) # spec after transpose + assert len(bb_chain) == 1 + + def test_resolve_metadata(self) -> None: + from zarr.core.codec_pipeline import CodecChain + + chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()]) + spec = _make_array_spec((3, 4), np.dtype("float64")) + chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + resolved = chain_evolved.resolve_metadata(spec) + # After transpose (1,0) + bytes, shape should reflect the transpose + assert resolved.shape == (4, 3) From 41b7a6ad78361de0018cc8031e0280810924d680 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 25 Feb 2026 20:34:34 -0500 Subject: [PATCH 03/14] refactor codecchain --- src/zarr/core/codec_pipeline.py | 114 ++++++++++----------------- tests/test_sync_codec_pipeline.py | 123 ++++++------------------------ 2 files changed, 65 insertions(+), 172 deletions(-) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 7425ec30f6..9c0dd292ed 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -69,87 +69,67 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class CodecChain: - """Lightweight codec chain: array-array -> array-bytes -> bytes-bytes. + """Codec chain with pre-resolved metadata specs. - Pure compute only -- no IO methods, no threading, no batching. + Constructed from an iterable of codecs and a chunk ArraySpec. + Resolves each codec against the spec so that encode/decode can + run without re-resolving. Pure compute only -- no IO, no threading, + no batching. """ - array_array_codecs: tuple[ArrayArrayCodec, ...] - array_bytes_codec: ArrayBytesCodec - bytes_bytes_codecs: tuple[BytesBytesCodec, ...] + codecs: tuple[Codec, ...] + chunk_spec: ArraySpec - _all_sync: bool = field(default=False, init=False, repr=False, compare=False) + _aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field( + init=False, repr=False, compare=False + ) + _ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False) + _ab_spec: ArraySpec = field(init=False, repr=False, compare=False) + _bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False) + _all_sync: bool = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: - object.__setattr__( - self, - "_all_sync", - all(isinstance(c, SupportsSyncCodec) for c in self), - ) - - def __iter__(self) -> Iterator[Codec]: - yield from self.array_array_codecs - yield self.array_bytes_codec - yield from self.bytes_bytes_codecs + aa, ab, bb = codecs_from_list(list(self.codecs)) - @classmethod - def from_codecs(cls, codecs: Iterable[Codec]) -> CodecChain: - aa, ab, bb = codecs_from_list(list(codecs)) - return cls(array_array_codecs=aa, array_bytes_codec=ab, bytes_bytes_codecs=bb) - - def resolve_metadata_chain( - self, chunk_spec: ArraySpec - ) -> tuple[ - list[tuple[ArrayArrayCodec, ArraySpec]], - tuple[ArrayBytesCodec, ArraySpec], - list[tuple[BytesBytesCodec, ArraySpec]], - ]: - """Resolve metadata through the codec chain for a single chunk_spec.""" - aa_codecs_with_spec: list[tuple[ArrayArrayCodec, ArraySpec]] = [] - spec = chunk_spec - for aa_codec in self.array_array_codecs: - aa_codecs_with_spec.append((aa_codec, spec)) + aa_pairs: list[tuple[ArrayArrayCodec, ArraySpec]] = [] + spec = self.chunk_spec + for aa_codec in aa: + aa_pairs.append((aa_codec, spec)) spec = aa_codec.resolve_metadata(spec) - ab_codec_with_spec = (self.array_bytes_codec, spec) - spec = self.array_bytes_codec.resolve_metadata(spec) + object.__setattr__(self, "_aa_codecs", tuple(aa_pairs)) + object.__setattr__(self, "_ab_codec", ab) + object.__setattr__(self, "_ab_spec", spec) - bb_codecs_with_spec: list[tuple[BytesBytesCodec, ArraySpec]] = [] - for bb_codec in self.bytes_bytes_codecs: - bb_codecs_with_spec.append((bb_codec, spec)) - spec = bb_codec.resolve_metadata(spec) + object.__setattr__(self, "_bb_codecs", bb) - return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec) + object.__setattr__( + self, + "_all_sync", + all(isinstance(c, SupportsSyncCodec) for c in self.codecs), + ) + + @property + def all_sync(self) -> bool: + return self._all_sync def decode_chunk( self, chunk_bytes: Buffer, - chunk_spec: ArraySpec, - aa_chain: Iterable[tuple[ArrayArrayCodec, ArraySpec]] | None = None, - ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None, - bb_chain: Iterable[tuple[BytesBytesCodec, ArraySpec]] | None = None, ) -> NDBuffer: """Decode a single chunk through the full codec chain, synchronously. Pure compute -- no IO. Only callable when all codecs support sync. - - The optional ``aa_chain``, ``ab_pair``, ``bb_chain`` parameters allow - pre-resolved metadata to be reused across many chunks with the same spec. - If not provided, ``resolve_metadata_chain`` is called internally. """ - if aa_chain is None or ab_pair is None or bb_chain is None: - aa_chain, ab_pair, bb_chain = self.resolve_metadata_chain(chunk_spec) - bb_out: Any = chunk_bytes - for bb_codec, spec in reversed(list(bb_chain)): - bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, spec) + for bb_codec in reversed(self._bb_codecs): + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self.chunk_spec) - ab_codec, ab_spec = ab_pair - ab_out: Any = cast("SupportsSyncCodec", ab_codec)._decode_sync(bb_out, ab_spec) + ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec) - for aa_codec, spec in reversed(list(aa_chain)): + for aa_codec, spec in reversed(self._aa_codecs): ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) return ab_out # type: ignore[no-any-return] @@ -157,45 +137,35 @@ def decode_chunk( def encode_chunk( self, chunk_array: NDBuffer, - chunk_spec: ArraySpec, ) -> Buffer | None: """Encode a single chunk through the full codec chain, synchronously. Pure compute -- no IO. Only callable when all codecs support sync. """ - spec = chunk_spec aa_out: Any = chunk_array - for aa_codec in self.array_array_codecs: + for aa_codec, spec in self._aa_codecs: if aa_out is None: return None aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) - spec = aa_codec.resolve_metadata(spec) if aa_out is None: return None - bb_out: Any = cast("SupportsSyncCodec", self.array_bytes_codec)._encode_sync(aa_out, spec) - spec = self.array_bytes_codec.resolve_metadata(spec) + bb_out: Any = cast("SupportsSyncCodec", self._ab_codec)._encode_sync(aa_out, self._ab_spec) - for bb_codec in self.bytes_bytes_codecs: + for bb_codec in self._bb_codecs: if bb_out is None: return None - bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, spec) - spec = bb_codec.resolve_metadata(spec) + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self.chunk_spec) return bb_out # type: ignore[no-any-return] def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: - for codec in self: + for codec in self.codecs: byte_length = codec.compute_encoded_size(byte_length, array_spec) array_spec = codec.resolve_metadata(array_spec) return byte_length - def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: - for codec in self: - chunk_spec = codec.resolve_metadata(chunk_spec) - return chunk_spec - @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py index 23fa28cb04..192479dc59 100644 --- a/tests/test_sync_codec_pipeline.py +++ b/tests/test_sync_codec_pipeline.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np -import pytest from zarr.codecs.bytes import BytesCodec from zarr.codecs.gzip import GzipCodec @@ -11,11 +10,9 @@ from zarr.codecs.zstd import ZstdCodec from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import NDBuffer, default_buffer_prototype +from zarr.core.codec_pipeline import CodecChain from zarr.core.dtype import get_data_type_from_native_dtype -if TYPE_CHECKING: - from zarr.abc.codec import Codec - def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec: zdtype = get_data_type_from_native_dtype(dtype) @@ -33,124 +30,50 @@ def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer: class TestCodecChain: - def test_from_codecs_bytes_only(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([BytesCodec()]) - assert chain.array_array_codecs == () - assert isinstance(chain.array_bytes_codec, BytesCodec) - assert chain.bytes_bytes_codecs == () - assert chain._all_sync is True - - def test_from_codecs_with_compression(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([BytesCodec(), GzipCodec()]) - assert isinstance(chain.array_bytes_codec, BytesCodec) - assert len(chain.bytes_bytes_codecs) == 1 - assert isinstance(chain.bytes_bytes_codecs[0], GzipCodec) - assert chain._all_sync is True - - def test_from_codecs_with_transpose(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()]) - assert len(chain.array_array_codecs) == 1 - assert isinstance(chain.array_array_codecs[0], TransposeCodec) - assert isinstance(chain.array_bytes_codec, BytesCodec) - assert chain._all_sync is True + def test_all_sync(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = CodecChain((BytesCodec(),), spec) + assert chain.all_sync is True - def test_from_codecs_full_chain(self) -> None: - from zarr.core.codec_pipeline import CodecChain + def test_all_sync_with_compression(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chain = CodecChain((BytesCodec(), GzipCodec()), spec) + assert chain.all_sync is True - chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()]) - assert len(chain.array_array_codecs) == 1 - assert isinstance(chain.array_bytes_codec, BytesCodec) - assert len(chain.bytes_bytes_codecs) == 1 - assert chain._all_sync is True - - def test_iter(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - codecs: list[Codec] = [TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()] - chain = CodecChain.from_codecs(codecs) - assert list(chain) == codecs - - def test_frozen(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([BytesCodec()]) - with pytest.raises(AttributeError): - chain.array_bytes_codec = BytesCodec() # type: ignore[misc] + def test_all_sync_full_chain(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), spec) + assert chain.all_sync is True def test_encode_decode_roundtrip_bytes_only(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([BytesCodec()]) arr = np.arange(100, dtype="float64") spec = _make_array_spec(arr.shape, arr.dtype) - chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + chain = CodecChain((BytesCodec(),), spec) nd_buf = _make_nd_buffer(arr) - encoded = chain_evolved.encode_chunk(nd_buf, spec) + encoded = chain.encode_chunk(nd_buf) assert encoded is not None - decoded = chain_evolved.decode_chunk(encoded, spec) - assert decoded is not None + decoded = chain.decode_chunk(encoded) np.testing.assert_array_equal(arr, decoded.as_numpy_array()) def test_encode_decode_roundtrip_with_compression(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([BytesCodec(), GzipCodec(level=1)]) arr = np.arange(100, dtype="float64") spec = _make_array_spec(arr.shape, arr.dtype) - chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + chain = CodecChain((BytesCodec(), GzipCodec(level=1)), spec) nd_buf = _make_nd_buffer(arr) - encoded = chain_evolved.encode_chunk(nd_buf, spec) + encoded = chain.encode_chunk(nd_buf) assert encoded is not None - decoded = chain_evolved.decode_chunk(encoded, spec) - assert decoded is not None + decoded = chain.decode_chunk(encoded) np.testing.assert_array_equal(arr, decoded.as_numpy_array()) def test_encode_decode_roundtrip_with_transpose(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs( - [TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)] - ) arr = np.arange(12, dtype="float64").reshape(3, 4) spec = _make_array_spec(arr.shape, arr.dtype) - chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) + chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), spec) nd_buf = _make_nd_buffer(arr) - encoded = chain_evolved.encode_chunk(nd_buf, spec) + encoded = chain.encode_chunk(nd_buf) assert encoded is not None - decoded = chain_evolved.decode_chunk(encoded, spec) - assert decoded is not None + decoded = chain.decode_chunk(encoded) np.testing.assert_array_equal(arr, decoded.as_numpy_array()) - - def test_resolve_metadata_chain(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()]) - arr = np.zeros((3, 4), dtype="float64") - spec = _make_array_spec(arr.shape, arr.dtype) - chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) - - aa_chain, ab_pair, bb_chain = chain_evolved.resolve_metadata_chain(spec) - assert len(aa_chain) == 1 - assert aa_chain[0][1].shape == (3, 4) # spec before transpose - _ab_codec, ab_spec = ab_pair - assert ab_spec.shape == (4, 3) # spec after transpose - assert len(bb_chain) == 1 - - def test_resolve_metadata(self) -> None: - from zarr.core.codec_pipeline import CodecChain - - chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()]) - spec = _make_array_spec((3, 4), np.dtype("float64")) - chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain]) - resolved = chain_evolved.resolve_metadata(spec) - # After transpose (1,0) + bytes, shape should reflect the transpose - assert resolved.shape == (4, 3) From 5a2a884e249277086fe6d706b25881b13b331f2d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 25 Feb 2026 20:50:19 -0500 Subject: [PATCH 04/14] separate codecs and specs --- src/zarr/core/codec_pipeline.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 9c0dd292ed..4412ffa705 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -75,35 +75,37 @@ class CodecChain: Constructed from an iterable of codecs and a chunk ArraySpec. Resolves each codec against the spec so that encode/decode can - run without re-resolving. Pure compute only -- no IO, no threading, - no batching. + run without re-resolving. """ codecs: tuple[Codec, ...] chunk_spec: ArraySpec - _aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field( - init=False, repr=False, compare=False - ) + _aa_codecs: tuple[ArrayArrayCodec, ...] = field(init=False, repr=False, compare=False) + _aa_specs: tuple[ArraySpec, ...] = field(init=False, repr=False, compare=False) _ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False) _ab_spec: ArraySpec = field(init=False, repr=False, compare=False) _bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False) + _bb_spec: ArraySpec = field(init=False, repr=False, compare=False) _all_sync: bool = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: aa, ab, bb = codecs_from_list(list(self.codecs)) - aa_pairs: list[tuple[ArrayArrayCodec, ArraySpec]] = [] + aa_specs: list[ArraySpec] = [] spec = self.chunk_spec for aa_codec in aa: - aa_pairs.append((aa_codec, spec)) + aa_specs.append(spec) spec = aa_codec.resolve_metadata(spec) - object.__setattr__(self, "_aa_codecs", tuple(aa_pairs)) + object.__setattr__(self, "_aa_codecs", aa) + object.__setattr__(self, "_aa_specs", tuple(aa_specs)) object.__setattr__(self, "_ab_codec", ab) object.__setattr__(self, "_ab_spec", spec) + spec = ab.resolve_metadata(spec) object.__setattr__(self, "_bb_codecs", bb) + object.__setattr__(self, "_bb_spec", spec) object.__setattr__( self, @@ -125,11 +127,11 @@ def decode_chunk( """ bb_out: Any = chunk_bytes for bb_codec in reversed(self._bb_codecs): - bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self.chunk_spec) + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._bb_spec) ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec) - for aa_codec, spec in reversed(self._aa_codecs): + for aa_codec, spec in zip(reversed(self._aa_codecs), reversed(self._aa_specs), strict=True): ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) return ab_out # type: ignore[no-any-return] @@ -144,7 +146,7 @@ def encode_chunk( """ aa_out: Any = chunk_array - for aa_codec, spec in self._aa_codecs: + for aa_codec, spec in zip(self._aa_codecs, self._aa_specs, strict=True): if aa_out is None: return None aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) @@ -156,7 +158,7 @@ def encode_chunk( for bb_codec in self._bb_codecs: if bb_out is None: return None - bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self.chunk_spec) + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._bb_spec) return bb_out # type: ignore[no-any-return] From 4e262b17d120dd0ac9ac2e83a0e576d3e4876038 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 26 Feb 2026 11:05:55 -0500 Subject: [PATCH 05/14] add synchronous methods to stores --- src/zarr/abc/store.py | 44 +++++++++++- src/zarr/storage/_common.py | 27 +++++++ src/zarr/storage/_local.py | 69 ++++++++++++++++++ src/zarr/storage/_memory.py | 64 ++++++++++++++++- src/zarr/testing/store.py | 136 +++++++++++++++++++++++++++++++++++- tests/test_indexing.py | 17 +++++ 6 files changed, 354 insertions(+), 3 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 87df89a683..1d321f5fc3 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -16,7 +16,17 @@ from zarr.core.buffer import Buffer, BufferPrototype -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +__all__ = [ + "ByteGetter", + "ByteSetter", + "Store", + "SupportsDeleteSync", + "SupportsGetSync", + "SupportsSetRangeSync", + "SupportsSetSync", + "SupportsSyncStore", + "set_or_delete", +] @dataclass @@ -700,6 +710,38 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +@runtime_checkable +class SupportsGetSync(Protocol): + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: ... + + +@runtime_checkable +class SupportsSetSync(Protocol): + def set_sync(self, key: str, value: Buffer) -> None: ... + + +@runtime_checkable +class SupportsSetRangeSync(Protocol): + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: ... + + +@runtime_checkable +class SupportsDeleteSync(Protocol): + def delete_sync(self, key: str) -> None: ... + + +@runtime_checkable +class SupportsSyncStore( + SupportsGetSync, SupportsSetSync, SupportsSetRangeSync, SupportsDeleteSync, Protocol +): ... + + async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: """Set or delete a value in a byte setter diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..c14aa1a37d 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -228,6 +228,33 @@ async def is_empty(self) -> bool: """ return await self.store.is_empty(self.path) + # ------------------------------------------------------------------- + # Synchronous IO delegation + # ------------------------------------------------------------------- + + def get_sync( + self, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + """Synchronous read — delegates to ``self.store.get_sync(self.path, ...)``.""" + if prototype is None: + prototype = default_buffer_prototype() + return self.store.get_sync(self.path, prototype=prototype, byte_range=byte_range) # type: ignore[attr-defined, no-any-return] + + def set_sync(self, value: Buffer) -> None: + """Synchronous write — delegates to ``self.store.set_sync(self.path, value)``.""" + self.store.set_sync(self.path, value) # type: ignore[attr-defined] + + def set_range_sync(self, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self.store.set_range_sync(self.path, value, start) # type: ignore[attr-defined] + + def delete_sync(self) -> None: + """Synchronous delete — delegates to ``self.store.delete_sync(self.path)``.""" + self.store.delete_sync(self.path) # type: ignore[attr-defined] + def __truediv__(self, other: str) -> StorePath: """Combine this store path with another path""" return self.__class__(self.store, _dereference_path(self.path, other)) diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 80233a112d..dc05ff67a7 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -85,6 +85,19 @@ def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: return f.write(view) +def _put_range(path: Path, value: Buffer, start: int) -> None: + view = value.as_buffer_like() + file_size = path.stat().st_size + if start + len(view) > file_size: + raise ValueError( + f"set_range would write beyond the end of the stored value: " + f"start={start}, len(value)={len(view)}, stored size={file_size}" + ) + with path.open("r+b") as f: + f.seek(start) + f.write(view) + + class LocalStore(Store): """ Store for the local file system. @@ -187,6 +200,62 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) and self.root == other.root + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def _ensure_open_sync(self) -> None: + if not self._is_open: + if not self.read_only: + self.root.mkdir(parents=True, exist_ok=True) + if not self.root.exists(): + raise FileNotFoundError(f"{self.root} does not exist") + self._is_open = True + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + self._ensure_open_sync() + assert isinstance(key, str) + path = self.root / key + try: + return _get(path, prototype, byte_range) + except (FileNotFoundError, IsADirectoryError, NotADirectoryError): + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._ensure_open_sync() + self._check_writable() + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"LocalStore.set(): `value` must be a Buffer instance. " + f"Got an instance of {type(value)} instead." + ) + path = self.root / key + _put(path, value) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + _put_range(path, value, start) + + def delete_sync(self, key: str) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink(missing_ok=True) + async def get( self, key: str, diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index e6f9b7a512..168f3e8890 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -77,6 +77,49 @@ def __eq__(self, other: object) -> bool: and self.read_only == other.read_only ) + # ------------------------------------------------------------------- + # Synchronous store methods + # ------------------------------------------------------------------- + + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + if prototype is None: + prototype = default_buffer_prototype() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + try: + value = self._store_dict[key] + start, stop = _normalize_byte_range_index(value, byte_range) + return prototype.buffer.from_buffer(value[start:stop]) + except KeyError: + return None + + def set_sync(self, key: str, value: Buffer) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + assert isinstance(key, str) + if not isinstance(value, Buffer): + raise TypeError( + f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." + ) + self._store_dict[key] = value + + def delete_sync(self, key: str) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + try: + del self._store_dict[key] + except KeyError: + logger.debug("Key %s does not exist.", key) + async def get( self, key: str, @@ -122,7 +165,6 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None raise TypeError( f"MemoryStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead." ) - if byte_range is not None: buf = self._store_dict[key] buf[byte_range[0] : byte_range[1]] = value @@ -130,6 +172,26 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None else: self._store_dict[key] = value + def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: + buf = self._store_dict[key] + target = buf.as_numpy_array() + if start + len(value) > len(target): + raise ValueError( + f"set_range would write beyond the end of the stored value: " + f"start={start}, len(value)={len(value)}, stored size={len(target)}" + ) + if not target.flags.writeable: + target = target.copy() + self._store_dict[key] = buf.__class__(target) + target[start : start + len(value)] = value.as_numpy_array() + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + """Synchronous byte-range write.""" + self._check_writable() + if not self._is_open: + self._is_open = True + self._set_range_impl(key, value, start) + async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited self._check_writable() diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 1b8e85ed98..bb60cc371f 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -6,12 +6,13 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Generic, Self, TypeVar +import numpy as np + from zarr.storage import WrapperStore if TYPE_CHECKING: from typing import Any - from zarr.abc.store import ByteRequest from zarr.core.buffer.core import BufferPrototype import pytest @@ -22,6 +23,10 @@ RangeByteRequest, Store, SuffixByteRequest, + SupportsDeleteSync, + SupportsGetSync, + SupportsSetRangeSync, + SupportsSetSync, ) from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.sync import _collect_aiterator, sync @@ -39,6 +44,34 @@ class StoreTests(Generic[S, B]): store_cls: type[S] buffer_cls: type[B] + @staticmethod + def _require_get_sync(store: S) -> SupportsGetSync: + """Skip unless *store* implements :class:`SupportsGetSync`.""" + if not isinstance(store, SupportsGetSync): + pytest.skip("store does not implement SupportsGetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_set_sync(store: S) -> SupportsSetSync: + """Skip unless *store* implements :class:`SupportsSetSync`.""" + if not isinstance(store, SupportsSetSync): + pytest.skip("store does not implement SupportsSetSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_set_range_sync(store: S) -> SupportsSetRangeSync: + """Skip unless *store* implements :class:`SupportsSetRangeSync`.""" + if not isinstance(store, SupportsSetRangeSync): + pytest.skip("store does not implement SupportsSetRangeSync") + return store # type: ignore[unreachable] + + @staticmethod + def _require_delete_sync(store: S) -> SupportsDeleteSync: + """Skip unless *store* implements :class:`SupportsDeleteSync`.""" + if not isinstance(store, SupportsDeleteSync): + pytest.skip("store does not implement SupportsDeleteSync") + return store # type: ignore[unreachable] + @abstractmethod async def set(self, store: S, key: str, value: Buffer) -> None: """ @@ -579,6 +612,107 @@ def test_get_json_sync(self, store: S) -> None: sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) assert store._get_json_sync(key, prototype=default_buffer_prototype()) == data + # ------------------------------------------------------------------- + # Synchronous store methods (SupportsSyncStore protocol) + # ------------------------------------------------------------------- + + def test_get_sync(self, store: S) -> None: + getter = self._require_get_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_get" + sync(self.set(store, key, data_buf)) + result = getter.get_sync(key) + assert result is not None + assert_bytes_equal(result, data_buf) + + def test_get_sync_missing(self, store: S) -> None: + getter = self._require_get_sync(store) + result = getter.get_sync("nonexistent") + assert result is None + + def test_set_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_set" + setter.set_sync(key, data_buf) + result = sync(self.get(store, key)) + assert_bytes_equal(result, data_buf) + + def test_delete_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + deleter = self._require_delete_sync(store) + getter = self._require_get_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04") + key = "sync_delete" + setter.set_sync(key, data_buf) + deleter.delete_sync(key) + result = getter.get_sync(key) + assert result is None + + def test_delete_sync_missing(self, store: S) -> None: + deleter = self._require_delete_sync(store) + if not store.supports_deletes: + pytest.skip("store does not support deletes") + # should not raise + deleter.delete_sync("nonexistent_sync") + + # ------------------------------------------------------------------- + # set_range (sync only — set_range is exclusively a sync-path API) + # ------------------------------------------------------------------- + + def test_set_range_sync(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + getter = self._require_get_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello world") + key = "range_sync_key" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"WORLD") + ranger.set_range_sync(key, patch, 6) + result = getter.get_sync(key) + assert result is not None + assert result.to_bytes() == b"hello WORLD" + + def test_set_range_sync_preserves_other_bytes(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + getter = self._require_get_sync(store) + data = np.arange(100, dtype="uint8") + data_buf = default_buffer_prototype().buffer.from_array_like(data) + key = "range_preserve" + setter.set_sync(key, data_buf) + patch = np.full(10, 255, dtype="uint8") + patch_buf = default_buffer_prototype().buffer.from_array_like(patch) + ranger.set_range_sync(key, patch_buf, 50) + result = getter.get_sync(key) + assert result is not None + result_arr = np.frombuffer(result.to_bytes(), dtype="uint8") + expected = data.copy() + expected[50:60] = 255 + np.testing.assert_array_equal(result_arr, expected) + + def test_set_range_sync_beyond_end_raises(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello") + key = "range_oob" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"world!") + with pytest.raises(ValueError, match="set_range would write beyond"): + ranger.set_range_sync(key, patch, 0) + + def test_set_range_sync_start_beyond_end_raises(self, store: S) -> None: + setter = self._require_set_sync(store) + ranger = self._require_set_range_sync(store) + data_buf = self.buffer_cls.from_bytes(b"hello") + key = "range_oob2" + setter.set_sync(key, data_buf) + patch = default_buffer_prototype().buffer.from_bytes(b"x") + with pytest.raises(ValueError, match="set_range would write beyond"): + ranger.set_range_sync(key, patch, 10) + class LatencyStore(WrapperStore[Store]): """ diff --git a/tests/test_indexing.py b/tests/test_indexing.py index c0bf7dd270..9c734fb0c3 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator + from zarr.abc.store import ByteRequest from zarr.core.buffer import BufferPrototype from zarr.core.buffer.core import Buffer @@ -83,6 +84,22 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None self.counter["__setitem__", key_suffix] += 1 return await super().set(key, value, byte_range) + def get_sync( + self, + key: str, + *, + prototype: BufferPrototype | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__getitem__", key_suffix] += 1 + return super().get_sync(key, prototype=prototype, byte_range=byte_range) + + def set_sync(self, key: str, value: Buffer) -> None: + key_suffix = "/".join(key.split("/")[1:]) + self.counter["__setitem__", key_suffix] += 1 + return super().set_sync(key, value) + def test_normalize_integer_selection() -> None: assert 1 == normalize_integer_selection(1, 100) From 71a780b157164667c7684a2c0485b77d2afc530f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 14:40:52 -0500 Subject: [PATCH 06/14] chunktransform --- src/zarr/core/codec_pipeline.py | 66 +++++++++++++++++-------------- tests/test_sync_codec_pipeline.py | 47 ++++++++++++++++++---- 2 files changed, 75 insertions(+), 38 deletions(-) diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 4412ffa705..0c3cccb1d9 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -69,49 +69,55 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value -@dataclass(frozen=True, slots=True) -class CodecChain: - """Codec chain with pre-resolved metadata specs. +@dataclass(slots=True, kw_only=True) +class ChunkTransform: + """A stored chunk, modeled as a layered array. - Constructed from an iterable of codecs and a chunk ArraySpec. - Resolves each codec against the spec so that encode/decode can - run without re-resolving. + Each layer corresponds to one ArrayArrayCodec and the ArraySpec + at its input boundary. ``layers[0]`` is the outermost (user-visible) + transform; after the last layer comes the ArrayBytesCodec. + + The chunk's ``shape`` and ``dtype`` reflect the representation + **after** all ArrayArrayCodec layers have been applied — i.e. the + spec that feeds the ArrayBytesCodec. """ codecs: tuple[Codec, ...] - chunk_spec: ArraySpec + array_spec: ArraySpec - _aa_codecs: tuple[ArrayArrayCodec, ...] = field(init=False, repr=False, compare=False) - _aa_specs: tuple[ArraySpec, ...] = field(init=False, repr=False, compare=False) + # Each element is (ArrayArrayCodec, input_spec_for_that_codec). + layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field( + init=False, repr=False, compare=False + ) _ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False) _ab_spec: ArraySpec = field(init=False, repr=False, compare=False) _bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False) - _bb_spec: ArraySpec = field(init=False, repr=False, compare=False) _all_sync: bool = field(init=False, repr=False, compare=False) def __post_init__(self) -> None: aa, ab, bb = codecs_from_list(list(self.codecs)) - aa_specs: list[ArraySpec] = [] - spec = self.chunk_spec + layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = () + spec = self.array_spec for aa_codec in aa: - aa_specs.append(spec) + layers = (*layers, (aa_codec, spec)) spec = aa_codec.resolve_metadata(spec) - object.__setattr__(self, "_aa_codecs", aa) - object.__setattr__(self, "_aa_specs", tuple(aa_specs)) - object.__setattr__(self, "_ab_codec", ab) - object.__setattr__(self, "_ab_spec", spec) + self.layers = layers + self._ab_codec = ab + self._ab_spec = spec + self._bb_codecs = bb + self._all_sync = all(isinstance(c, SupportsSyncCodec) for c in self.codecs) - spec = ab.resolve_metadata(spec) - object.__setattr__(self, "_bb_codecs", bb) - object.__setattr__(self, "_bb_spec", spec) + @property + def shape(self) -> tuple[int, ...]: + """Shape after all ArrayArrayCodec layers (input to the ArrayBytesCodec).""" + return self._ab_spec.shape - object.__setattr__( - self, - "_all_sync", - all(isinstance(c, SupportsSyncCodec) for c in self.codecs), - ) + @property + def dtype(self) -> ZDType[TBaseDType, TBaseScalar]: + """Dtype after all ArrayArrayCodec layers (input to the ArrayBytesCodec).""" + return self._ab_spec.dtype @property def all_sync(self) -> bool: @@ -127,11 +133,11 @@ def decode_chunk( """ bb_out: Any = chunk_bytes for bb_codec in reversed(self._bb_codecs): - bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._bb_spec) + bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._ab_spec) ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec) - for aa_codec, spec in zip(reversed(self._aa_codecs), reversed(self._aa_specs), strict=True): + for aa_codec, spec in reversed(self.layers): ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec) return ab_out # type: ignore[no-any-return] @@ -146,7 +152,7 @@ def encode_chunk( """ aa_out: Any = chunk_array - for aa_codec, spec in zip(self._aa_codecs, self._aa_specs, strict=True): + for aa_codec, spec in self.layers: if aa_out is None: return None aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec) @@ -158,7 +164,7 @@ def encode_chunk( for bb_codec in self._bb_codecs: if bb_out is None: return None - bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._bb_spec) + bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._ab_spec) return bb_out # type: ignore[no-any-return] @@ -369,7 +375,7 @@ async def read_batch( out[out_selection] = fill_value_or_default(chunk_spec) else: chunk_bytes_batch = await concurrent_map( - [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info], + [(byte_getter, chunk_spec.prototype) for byte_getter, chunk_spec, *_ in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), config.get("async.concurrency"), ) diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py index 192479dc59..e9d05dcec6 100644 --- a/tests/test_sync_codec_pipeline.py +++ b/tests/test_sync_codec_pipeline.py @@ -10,7 +10,7 @@ from zarr.codecs.zstd import ZstdCodec from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import NDBuffer, default_buffer_prototype -from zarr.core.codec_pipeline import CodecChain +from zarr.core.codec_pipeline import ChunkTransform from zarr.core.dtype import get_data_type_from_native_dtype @@ -29,26 +29,28 @@ def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer: return default_buffer_prototype().nd_buffer.from_numpy_array(arr) -class TestCodecChain: +class TestChunkTransform: def test_all_sync(self) -> None: spec = _make_array_spec((100,), np.dtype("float64")) - chain = CodecChain((BytesCodec(),), spec) + chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) assert chain.all_sync is True def test_all_sync_with_compression(self) -> None: spec = _make_array_spec((100,), np.dtype("float64")) - chain = CodecChain((BytesCodec(), GzipCodec()), spec) + chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec) assert chain.all_sync is True def test_all_sync_full_chain(self) -> None: spec = _make_array_spec((3, 4), np.dtype("float64")) - chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), spec) + chain = ChunkTransform( + codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), array_spec=spec + ) assert chain.all_sync is True def test_encode_decode_roundtrip_bytes_only(self) -> None: arr = np.arange(100, dtype="float64") spec = _make_array_spec(arr.shape, arr.dtype) - chain = CodecChain((BytesCodec(),), spec) + chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) nd_buf = _make_nd_buffer(arr) encoded = chain.encode_chunk(nd_buf) @@ -56,10 +58,36 @@ def test_encode_decode_roundtrip_bytes_only(self) -> None: decoded = chain.decode_chunk(encoded) np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + def test_layers_no_aa_codecs(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chunk = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec) + assert chunk.layers == () + + def test_layers_with_transpose(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + transpose = TransposeCodec(order=(1, 0)) + chunk = ChunkTransform(codecs=(transpose, BytesCodec(), ZstdCodec()), array_spec=spec) + assert len(chunk.layers) == 1 + assert chunk.layers[0][0] is transpose + assert chunk.layers[0][1] is spec + + def test_shape_dtype_no_aa_codecs(self) -> None: + spec = _make_array_spec((100,), np.dtype("float64")) + chunk = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec) + assert chunk.shape == (100,) + assert chunk.dtype == spec.dtype + + def test_shape_dtype_with_transpose(self) -> None: + spec = _make_array_spec((3, 4), np.dtype("float64")) + chunk = ChunkTransform(codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=spec) + # After transpose (1,0), shape (3,4) becomes (4,3) + assert chunk.shape == (4, 3) + assert chunk.dtype == spec.dtype + def test_encode_decode_roundtrip_with_compression(self) -> None: arr = np.arange(100, dtype="float64") spec = _make_array_spec(arr.shape, arr.dtype) - chain = CodecChain((BytesCodec(), GzipCodec(level=1)), spec) + chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec(level=1)), array_spec=spec) nd_buf = _make_nd_buffer(arr) encoded = chain.encode_chunk(nd_buf) @@ -70,7 +98,10 @@ def test_encode_decode_roundtrip_with_compression(self) -> None: def test_encode_decode_roundtrip_with_transpose(self) -> None: arr = np.arange(12, dtype="float64").reshape(3, 4) spec = _make_array_spec(arr.shape, arr.dtype) - chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), spec) + chain = ChunkTransform( + codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), + array_spec=spec, + ) nd_buf = _make_nd_buffer(arr) encoded = chain.encode_chunk(nd_buf) From 9b949849b6f7d23aef954fc21eb260540c8afb50 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 26 Feb 2026 17:51:33 -0500 Subject: [PATCH 07/14] add prepared write logic --- src/zarr/abc/codec.py | 183 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 181 insertions(+), 2 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 3ec5ec522b..30189e4cbb 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -2,7 +2,8 @@ from abc import abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable from typing_extensions import ReadOnly, TypedDict @@ -19,7 +20,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType - from zarr.core.indexing import SelectorTuple + from zarr.core.indexing import ChunkProjection, SelectorTuple from zarr.core.metadata import ArrayMetadata __all__ = [ @@ -32,6 +33,7 @@ "CodecInput", "CodecOutput", "CodecPipeline", + "PreparedWrite", "SupportsSyncCodec", ] @@ -204,9 +206,186 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): """Base class for array-to-array codecs.""" +@dataclass +class PreparedWrite: + """Result of ``prepare_write``: existing encoded chunk bytes + selection info.""" + + chunk_dict: dict[tuple[int, ...], Buffer | None] + inner_codec_chain: Any # CodecChain — typed as Any to avoid circular import + inner_chunk_spec: ArraySpec + indexer: list[ChunkProjection] + value_selection: SelectorTuple | None = None + write_full_shard: bool = True + is_complete_shard: bool = False + shard_data: NDBuffer | None = field(default=None) + + class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" + @property + def inner_codec_chain(self) -> Any: + """The codec chain for decoding inner chunks after deserialization. + + Returns ``None`` by default — the pipeline should use its own codec chain. + ``ShardingCodec`` overrides to return its inner codec chain. + """ + return None + + def deserialize( + self, raw: Buffer | None, chunk_spec: ArraySpec + ) -> dict[tuple[int, ...], Buffer | None]: + """Unpack stored bytes into per-inner-chunk buffers. + + Default: single chunk keyed at ``(0,)``. + ``ShardingCodec`` overrides to decode the shard index and slice the + blob into per-chunk buffers. + """ + return {(0,): raw} + + def serialize( + self, + chunk_dict: dict[tuple[int, ...], Buffer | None], + chunk_spec: ArraySpec, + ) -> Buffer | None: + """Pack per-inner-chunk buffers into a storage blob. + + Default: return the single chunk's bytes (or ``None`` if absent). + ``ShardingCodec`` overrides to concatenate chunks and build an index. + Returns ``None`` when all chunks are empty (caller should delete the key). + """ + return chunk_dict.get((0,)) + + # ------------------------------------------------------------------ + # prepare / finalize — sync + # ------------------------------------------------------------------ + + def prepare_read_sync( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + """Sync IO + full decode for the selected region.""" + raw = byte_getter.get_sync(prototype=chunk_spec.prototype) + chunk_array: NDBuffer | None = codec_chain.decode_chunk( + raw, chunk_spec, aa_chain, ab_pair, bb_chain + ) + if chunk_array is not None: + return chunk_array[chunk_selection] + return None + + def prepare_write_sync( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + replace: bool, + codec_chain: Any, + ) -> PreparedWrite: + """Sync IO + deserialize. Returns a :class:`PreparedWrite`.""" + existing: Buffer | None = None + if not replace: + existing = byte_setter.get_sync(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + inner_chain = self.inner_codec_chain or codec_chain + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=inner_chain, + inner_chunk_spec=chunk_spec, + indexer=[ + ( # type: ignore[list-item] + (0,), + chunk_selection, + out_selection, + replace, + ) + ], + ) + + def finalize_write_sync( + self, + prepared: PreparedWrite, + chunk_spec: ArraySpec, + byte_setter: Any, + ) -> None: + """Serialize the prepared *chunk_dict* and write to store.""" + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + byte_setter.delete_sync() + else: + byte_setter.set_sync(blob) + + # ------------------------------------------------------------------ + # prepare / finalize — async + # ------------------------------------------------------------------ + + async def prepare_read( + self, + byte_getter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + codec_chain: Any, + aa_chain: Any, + ab_pair: Any, + bb_chain: Any, + ) -> NDBuffer | None: + """Async IO + full decode for the selected region.""" + raw = await byte_getter.get(prototype=chunk_spec.prototype) + chunk_array: NDBuffer | None = codec_chain.decode_chunk( + raw, chunk_spec, aa_chain, ab_pair, bb_chain + ) + if chunk_array is not None: + return chunk_array[chunk_selection] + return None + + async def prepare_write( + self, + byte_setter: Any, + chunk_spec: ArraySpec, + chunk_selection: SelectorTuple, + out_selection: SelectorTuple, + replace: bool, + codec_chain: Any, + ) -> PreparedWrite: + """Async IO + deserialize. Returns a :class:`PreparedWrite`.""" + existing: Buffer | None = None + if not replace: + existing = await byte_setter.get(prototype=chunk_spec.prototype) + chunk_dict = self.deserialize(existing, chunk_spec) + inner_chain = self.inner_codec_chain or codec_chain + return PreparedWrite( + chunk_dict=chunk_dict, + inner_codec_chain=inner_chain, + inner_chunk_spec=chunk_spec, + indexer=[ + ( # type: ignore[list-item] + (0,), + chunk_selection, + out_selection, + replace, + ) + ], + ) + + async def finalize_write( + self, + prepared: PreparedWrite, + chunk_spec: ArraySpec, + byte_setter: Any, + ) -> None: + """Async version of :meth:`finalize_write_sync`.""" + blob = self.serialize(prepared.chunk_dict, chunk_spec) + if blob is None: + await byte_setter.delete() + else: + await byte_setter.set(blob) + class BytesBytesCodec(BaseCodec[Buffer, Buffer]): """Base class for bytes-to-bytes codecs.""" From 8c33b349ecf3cb592965d2a8389e4d06494e36e5 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 26 Feb 2026 23:42:49 -0500 Subject: [PATCH 08/14] add prepared write semantics --- src/zarr/abc/codec.py | 267 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 228 insertions(+), 39 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 30189e4cbb..6d7eabd6cf 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -34,6 +34,7 @@ "CodecOutput", "CodecPipeline", "PreparedWrite", + "SupportsChunkCodec", "SupportsSyncCodec", ] @@ -79,6 +80,17 @@ def _encode_sync( ) -> NDBuffer | Buffer | None: ... +class SupportsChunkCodec(Protocol): + """Protocol for objects that can decode/encode whole chunks synchronously. + + [`CodecChain`][zarr.core.codec_pipeline.CodecChain] satisfies this protocol. + """ + + def decode_chunk(self, chunk_bytes: Buffer) -> NDBuffer: ... + + def encode_chunk(self, chunk_array: NDBuffer) -> Buffer | None: ... + + class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]): """Generic base class for codecs. @@ -208,10 +220,37 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]): @dataclass class PreparedWrite: - """Result of ``prepare_write``: existing encoded chunk bytes + selection info.""" + """Result of the prepare phase of a write operation. + + Carries deserialized chunk data and selection metadata between + [`prepare_write`][zarr.abc.codec.ArrayBytesCodec.prepare_write] (or + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]) + and [`finalize_write`][zarr.abc.codec.ArrayBytesCodec.finalize_write] (or + [`finalize_write_sync`][zarr.abc.codec.ArrayBytesCodec.finalize_write_sync]). + + Attributes + ---------- + chunk_dict : dict[tuple[int, ...], Buffer | None] + Per-inner-chunk buffers keyed by chunk coordinates. + inner_codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] for + decoding/encoding inner chunks. + inner_chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for inner chunks. + indexer : list[ChunkProjection] + Mapping from inner-chunk coordinates to value/output selections. + value_selection : SelectorTuple | None + Outer ``out_selection`` for sharding. Unused by the base implementation. + write_full_shard : bool + Whether the full shard blob will be written. Unused by the base implementation. + is_complete_shard : bool + Fast-path flag for complete shard writes. Unused by the base implementation. + shard_data : NDBuffer | None + Full shard value for complete writes. Unused by the base implementation. + """ chunk_dict: dict[tuple[int, ...], Buffer | None] - inner_codec_chain: Any # CodecChain — typed as Any to avoid circular import + inner_codec_chain: SupportsChunkCodec inner_chunk_spec: ArraySpec indexer: list[ChunkProjection] value_selection: SelectorTuple | None = None @@ -224,11 +263,18 @@ class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" @property - def inner_codec_chain(self) -> Any: + def inner_codec_chain(self) -> SupportsChunkCodec | None: """The codec chain for decoding inner chunks after deserialization. - Returns ``None`` by default — the pipeline should use its own codec chain. - ``ShardingCodec`` overrides to return its inner codec chain. + Returns ``None`` by default, meaning the pipeline should use its own + codec chain. ``ShardingCodec`` overrides this to return its inner + codec chain. + + Returns + ------- + SupportsChunkCodec or None + A [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] instance, + or ``None``. """ return None @@ -237,9 +283,22 @@ def deserialize( ) -> dict[tuple[int, ...], Buffer | None]: """Unpack stored bytes into per-inner-chunk buffers. - Default: single chunk keyed at ``(0,)``. - ``ShardingCodec`` overrides to decode the shard index and slice the - blob into per-chunk buffers. + The default implementation returns a single-entry dict keyed at + ``(0,)``. ``ShardingCodec`` overrides this to decode the shard index + and split the blob into per-chunk buffers. + + Parameters + ---------- + raw : Buffer or None + The raw bytes read from the store, or ``None`` if the key was + absent. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + + Returns + ------- + dict[tuple[int, ...], Buffer | None] + Mapping from inner-chunk coordinates to their encoded bytes. """ return {(0,): raw} @@ -250,9 +309,22 @@ def serialize( ) -> Buffer | None: """Pack per-inner-chunk buffers into a storage blob. - Default: return the single chunk's bytes (or ``None`` if absent). - ``ShardingCodec`` overrides to concatenate chunks and build an index. - Returns ``None`` when all chunks are empty (caller should delete the key). + The default implementation returns the single entry at ``(0,)``. + ``ShardingCodec`` overrides this to concatenate chunks and build a + shard index. + + Parameters + ---------- + chunk_dict : dict[tuple[int, ...], Buffer | None] + Mapping from inner-chunk coordinates to their encoded bytes. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + + Returns + ------- + Buffer or None + The serialized blob, or ``None`` when all chunks are empty + (the caller should delete the key). """ return chunk_dict.get((0,)) @@ -265,19 +337,35 @@ def prepare_read_sync( byte_getter: Any, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, - codec_chain: Any, - aa_chain: Any, - ab_pair: Any, - bb_chain: Any, + codec_chain: SupportsChunkCodec, ) -> NDBuffer | None: - """Sync IO + full decode for the selected region.""" + """Read a chunk from the store synchronously, decode it, and + return the selected region. + + Parameters + ---------- + byte_getter : Any + An object supporting ``get_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + chunk_selection : SelectorTuple + Selection within the decoded chunk array. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode the chunk. + + Returns + ------- + NDBuffer or None + The decoded chunk data at *chunk_selection*, or ``None`` if the + chunk does not exist in the store. + """ raw = byte_getter.get_sync(prototype=chunk_spec.prototype) - chunk_array: NDBuffer | None = codec_chain.decode_chunk( - raw, chunk_spec, aa_chain, ab_pair, bb_chain - ) - if chunk_array is not None: - return chunk_array[chunk_selection] - return None + if raw is None: + return None + chunk_array = codec_chain.decode_chunk(raw) + return chunk_array[chunk_selection] def prepare_write_sync( self, @@ -286,9 +374,39 @@ def prepare_write_sync( chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, - codec_chain: Any, + codec_chain: SupportsChunkCodec, ) -> PreparedWrite: - """Sync IO + deserialize. Returns a :class:`PreparedWrite`.""" + """Prepare a synchronous write by optionally reading existing data. + + When *replace* is ``False``, the existing chunk bytes are fetched + from the store so they can be merged with the new data. When + *replace* is ``True``, the fetch is skipped. + + Parameters + ---------- + byte_setter : Any + An object supporting ``get_sync`` and ``set_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + chunk_selection : SelectorTuple + Selection within the chunk being written. + out_selection : SelectorTuple + Corresponding selection within the source value array. + replace : bool + If ``True``, the write replaces all data in the chunk and no + read-modify-write is needed. If ``False``, existing data is + fetched first. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode/encode the chunk. + + Returns + ------- + PreparedWrite + A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the + deserialized chunk data and selection metadata. + """ existing: Buffer | None = None if not replace: existing = byte_setter.get_sync(prototype=chunk_spec.prototype) @@ -314,7 +432,22 @@ def finalize_write_sync( chunk_spec: ArraySpec, byte_setter: Any, ) -> None: - """Serialize the prepared *chunk_dict* and write to store.""" + """Serialize the prepared chunk data and write it to the store. + + If serialization produces ``None`` (all chunks empty), the key is + deleted instead. + + Parameters + ---------- + prepared : PreparedWrite + The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + byte_setter : Any + An object supporting ``set_sync`` and ``delete_sync`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + """ blob = self.serialize(prepared.chunk_dict, chunk_spec) if blob is None: byte_setter.delete_sync() @@ -330,19 +463,35 @@ async def prepare_read( byte_getter: Any, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, - codec_chain: Any, - aa_chain: Any, - ab_pair: Any, - bb_chain: Any, + codec_chain: SupportsChunkCodec, ) -> NDBuffer | None: - """Async IO + full decode for the selected region.""" + """Async variant of + [`prepare_read_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_read_sync]. + + Parameters + ---------- + byte_getter : Any + An object supporting ``get`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + chunk_selection : SelectorTuple + Selection within the decoded chunk array. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode the chunk. + + Returns + ------- + NDBuffer or None + The decoded chunk data at *chunk_selection*, or ``None`` if the + chunk does not exist in the store. + """ raw = await byte_getter.get(prototype=chunk_spec.prototype) - chunk_array: NDBuffer | None = codec_chain.decode_chunk( - raw, chunk_spec, aa_chain, ab_pair, bb_chain - ) - if chunk_array is not None: - return chunk_array[chunk_selection] - return None + if raw is None: + return None + chunk_array = codec_chain.decode_chunk(raw) + return chunk_array[chunk_selection] async def prepare_write( self, @@ -351,9 +500,36 @@ async def prepare_write( chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, - codec_chain: Any, + codec_chain: SupportsChunkCodec, ) -> PreparedWrite: - """Async IO + deserialize. Returns a :class:`PreparedWrite`.""" + """Async variant of + [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. + + Parameters + ---------- + byte_setter : Any + An object supporting ``get`` and ``set`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + chunk_selection : SelectorTuple + Selection within the chunk being written. + out_selection : SelectorTuple + Corresponding selection within the source value array. + replace : bool + If ``True``, the write replaces all data in the chunk and no + read-modify-write is needed. If ``False``, existing data is + fetched first. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to + decode/encode the chunk. + + Returns + ------- + PreparedWrite + A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the + deserialized chunk data and selection metadata. + """ existing: Buffer | None = None if not replace: existing = await byte_setter.get(prototype=chunk_spec.prototype) @@ -379,7 +555,20 @@ async def finalize_write( chunk_spec: ArraySpec, byte_setter: Any, ) -> None: - """Async version of :meth:`finalize_write_sync`.""" + """Async variant of + [`finalize_write_sync`][zarr.abc.codec.ArrayBytesCodec.finalize_write_sync]. + + Parameters + ---------- + prepared : PreparedWrite + The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by + [`prepare_write`][zarr.abc.codec.ArrayBytesCodec.prepare_write]. + chunk_spec : ArraySpec + The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + byte_setter : Any + An object supporting ``set`` and ``delete`` (e.g. + [`StorePath`][zarr.storage._common.StorePath]). + """ blob = self.serialize(prepared.chunk_dict, chunk_spec) if blob is None: await byte_setter.delete() From 1193d9ca6c62291e08db1b24b7429b5de2ce7dcd Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 09:01:08 -0500 Subject: [PATCH 09/14] simplify preparedwrite --- src/zarr/abc/codec.py | 35 +---------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 6d7eabd6cf..fc86d0edc0 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -2,7 +2,7 @@ from abc import abstractmethod from collections.abc import Mapping -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable from typing_extensions import ReadOnly, TypedDict @@ -232,31 +232,12 @@ class PreparedWrite: ---------- chunk_dict : dict[tuple[int, ...], Buffer | None] Per-inner-chunk buffers keyed by chunk coordinates. - inner_codec_chain : SupportsChunkCodec - The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] for - decoding/encoding inner chunks. - inner_chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for inner chunks. indexer : list[ChunkProjection] Mapping from inner-chunk coordinates to value/output selections. - value_selection : SelectorTuple | None - Outer ``out_selection`` for sharding. Unused by the base implementation. - write_full_shard : bool - Whether the full shard blob will be written. Unused by the base implementation. - is_complete_shard : bool - Fast-path flag for complete shard writes. Unused by the base implementation. - shard_data : NDBuffer | None - Full shard value for complete writes. Unused by the base implementation. """ chunk_dict: dict[tuple[int, ...], Buffer | None] - inner_codec_chain: SupportsChunkCodec - inner_chunk_spec: ArraySpec indexer: list[ChunkProjection] - value_selection: SelectorTuple | None = None - write_full_shard: bool = True - is_complete_shard: bool = False - shard_data: NDBuffer | None = field(default=None) class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): @@ -374,7 +355,6 @@ def prepare_write_sync( chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, - codec_chain: SupportsChunkCodec, ) -> PreparedWrite: """Prepare a synchronous write by optionally reading existing data. @@ -397,9 +377,6 @@ def prepare_write_sync( If ``True``, the write replaces all data in the chunk and no read-modify-write is needed. If ``False``, existing data is fetched first. - codec_chain : SupportsChunkCodec - The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to - decode/encode the chunk. Returns ------- @@ -411,11 +388,8 @@ def prepare_write_sync( if not replace: existing = byte_setter.get_sync(prototype=chunk_spec.prototype) chunk_dict = self.deserialize(existing, chunk_spec) - inner_chain = self.inner_codec_chain or codec_chain return PreparedWrite( chunk_dict=chunk_dict, - inner_codec_chain=inner_chain, - inner_chunk_spec=chunk_spec, indexer=[ ( # type: ignore[list-item] (0,), @@ -500,7 +474,6 @@ async def prepare_write( chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, - codec_chain: SupportsChunkCodec, ) -> PreparedWrite: """Async variant of [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. @@ -520,9 +493,6 @@ async def prepare_write( If ``True``, the write replaces all data in the chunk and no read-modify-write is needed. If ``False``, existing data is fetched first. - codec_chain : SupportsChunkCodec - The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to - decode/encode the chunk. Returns ------- @@ -534,11 +504,8 @@ async def prepare_write( if not replace: existing = await byte_setter.get(prototype=chunk_spec.prototype) chunk_dict = self.deserialize(existing, chunk_spec) - inner_chain = self.inner_codec_chain or codec_chain return PreparedWrite( chunk_dict=chunk_dict, - inner_codec_chain=inner_chain, - inner_chunk_spec=chunk_spec, indexer=[ ( # type: ignore[list-item] (0,), From d2af4088c9818401d32d26b6c9cac5f8b192d35a Mon Sep 17 00:00:00 2001 From: Mark Kittisopikul Date: Wed, 25 Feb 2026 15:34:12 -0500 Subject: [PATCH 10/14] perf: Fix near-miss penalty in _morton_order with hybrid ceiling+argsort strategy (#3718) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * tests: Add non-power-of-2 shard shapes to benchmarks Add (30,30,30) to large_morton_shards and (10,10,10), (20,20,20), (30,30,30) to morton_iter_shapes to benchmark the scalar fallback path for non-power-of-2 shapes, which are not fully covered by the vectorized hypercube path. Co-Authored-By: Claude Sonnet 4.6 * tests: Add near-miss power-of-2 shape (33,33,33) to benchmarks Documents the performance penalty when a shard shape is just above a power-of-2 boundary, causing n_z to jump from 32,768 to 262,144. Co-Authored-By: Claude Sonnet 4.6 * style: Apply ruff format to benchmark file Co-Authored-By: Claude Sonnet 4.6 * changes: Add changelog entry for PR #3717 Co-Authored-By: Claude Sonnet 4.6 * perf: Fix near-miss penalty in _morton_order with hybrid ceiling+argsort strategy For shapes just above a power-of-2 (e.g. (33,33,33)), the ceiling-only approach generates n_z=262,144 Morton codes for only 35,937 valid coordinates (7.3× overgeneration). The floor+scalar approach is even worse since the scalar loop iterates n_z-n_floor times (229,376 for (33,33,33)), not n_total-n_floor. The fix: when n_z > 4*n_total, use an argsort strategy that enumerates all n_total valid coordinates via meshgrid, encodes each to a Morton code using vectorized bit manipulation, then sorts by Morton code. This avoids the large overgeneration while remaining fully vectorized. Result for test_morton_order_iter: (30,30,30): 24ms (ceiling, ratio=1.21) (32,32,32): 28ms (ceiling, ratio=1.00) (33,33,33): 32ms (argsort, ratio=7.3 → fixed from ~820ms with scalar) Co-Authored-By: Claude Sonnet 4.6 * fix: Address pre-commit CI failures in _morton_order - Replace Unicode multiplication sign × with ASCII x in comment (RUF003) - Add explicit type annotation for np.argsort result to satisfy mypy Co-Authored-By: Claude Sonnet 4.6 * fix: Cast argsort result via np.asarray to resolve mypy no-any-return np.stack returns Any in mypy's view, so indexing into it also returns Any. Using np.asarray(..., dtype=np.intp) makes the type explicit and avoids the no-any-return error at the return site. Co-Authored-By: Claude Sonnet 4.6 * fix: Pre-declare order type to resolve mypy no-any-return in _morton_order np.asarray and np.stack return Any with numpy 2.1 type stubs, causing mypy to infer the return type as Any. Pre-declaring order as npt.NDArray[np.intp] before the if/else makes the intended type explicit. Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Davis Bennett --- src/zarr/core/indexing.py | 86 ++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 454f7e2290..73fd53087d 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1512,54 +1512,48 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: out.flags.writeable = False return out - # Optimization: Remove singleton dimensions to enable magic number usage - # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand. - singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1) - if singleton_dims: - squeezed_shape = tuple(s for s in chunk_shape if s != 1) - if squeezed_shape: - # Compute Morton order on squeezed shape, then expand singleton dims (always 0) - squeezed_order = np.asarray(_morton_order(squeezed_shape)) - out = np.zeros((n_total, n_dims), dtype=np.intp) - squeezed_col = 0 - for full_col in range(n_dims): - if chunk_shape[full_col] != 1: - out[:, full_col] = squeezed_order[:, squeezed_col] - squeezed_col += 1 - else: - # All dimensions are singletons, just return the single point - out = np.zeros((1, n_dims), dtype=np.intp) - out.flags.writeable = False - return out - - # Find the largest power-of-2 hypercube that fits within chunk_shape. - # Within this hypercube, Morton codes are guaranteed to be in bounds. - min_dim = min(chunk_shape) - if min_dim >= 1: - power = min_dim.bit_length() - 1 # floor(log2(min_dim)) - hypercube_size = 1 << power # 2^power - n_hypercube = hypercube_size**n_dims + # Ceiling hypercube: smallest power-of-2 hypercube whose Morton codes span + # all valid coordinates in chunk_shape. (c-1).bit_length() gives the number + # of bits needed to index c values (0 for singleton dims). n_z = 2**total_bits + # is the size of this hypercube. + total_bits = sum((c - 1).bit_length() for c in chunk_shape) + n_z = 1 << total_bits if total_bits > 0 else 1 + + # Decode all Morton codes in the ceiling hypercube, then filter to valid coords. + # This is fully vectorized. For shapes with n_z >> n_total (e.g. (33,33,33): + # n_z=262144, n_total=35937), consider the argsort strategy below. + order: npt.NDArray[np.intp] + if n_z <= 4 * n_total: + # Ceiling strategy: decode all n_z codes vectorized, filter in-bounds. + # Works well when the overgeneration ratio n_z/n_total is small (≤4). + z_values = np.arange(n_z, dtype=np.intp) + all_coords = decode_morton_vectorized(z_values, chunk_shape) + shape_arr = np.array(chunk_shape, dtype=np.intp) + valid_mask = np.all(all_coords < shape_arr, axis=1) + order = all_coords[valid_mask] else: - n_hypercube = 0 + # Argsort strategy: enumerate all n_total valid coordinates directly, + # encode each to a Morton code, then sort by code. Avoids the 8x or + # larger overgeneration penalty for near-miss shapes like (33,33,33). + # Cost: O(n_total * bits) encode + O(n_total log n_total) sort, + # vs O(n_z * bits) = O(8 * n_total * bits) for ceiling. + grids = np.meshgrid(*[np.arange(c, dtype=np.intp) for c in chunk_shape], indexing="ij") + all_coords = np.stack([g.ravel() for g in grids], axis=1) + + # Encode all coordinates to Morton codes (vectorized). + bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape) + max_coord_bits = max(bits_per_dim) + z_codes = np.zeros(n_total, dtype=np.intp) + output_bit = 0 + for coord_bit in range(max_coord_bits): + for dim in range(n_dims): + if coord_bit < bits_per_dim[dim]: + z_codes |= ((all_coords[:, dim] >> coord_bit) & 1) << output_bit + output_bit += 1 + + sort_idx: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable") + order = np.asarray(all_coords[sort_idx], dtype=np.intp) - # Within the hypercube, no bounds checking needed - use vectorized decoding - if n_hypercube > 0: - z_values = np.arange(n_hypercube, dtype=np.intp) - order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape) - else: - order = np.empty((0, n_dims), dtype=np.intp) - - # For remaining elements outside the hypercube, bounds checking is needed - remaining: list[tuple[int, ...]] = [] - i = n_hypercube - while len(order) + len(remaining) < n_total: - m = decode_morton(i, chunk_shape) - if all(x < y for x, y in zip(m, chunk_shape, strict=False)): - remaining.append(m) - i += 1 - - if remaining: - order = np.vstack([order, np.array(remaining, dtype=np.intp)]) order.flags.writeable = False return order From 13b52fd564f9f8dd071439116747a1d54cb3776d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 15:53:09 -0500 Subject: [PATCH 11/14] use chunktransform --- src/zarr/abc/codec.py | 52 ++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index fc86d0edc0..d5d673e1c7 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -83,9 +83,11 @@ def _encode_sync( class SupportsChunkCodec(Protocol): """Protocol for objects that can decode/encode whole chunks synchronously. - [`CodecChain`][zarr.core.codec_pipeline.CodecChain] satisfies this protocol. + [`ChunkTransform`][zarr.core.codec_pipeline.ChunkTransform] satisfies this protocol. """ + array_spec: ArraySpec + def decode_chunk(self, chunk_bytes: Buffer) -> NDBuffer: ... def encode_chunk(self, chunk_array: NDBuffer) -> Buffer | None: ... @@ -316,7 +318,6 @@ def serialize( def prepare_read_sync( self, byte_getter: Any, - chunk_spec: ArraySpec, chunk_selection: SelectorTuple, codec_chain: SupportsChunkCodec, ) -> NDBuffer | None: @@ -328,13 +329,11 @@ def prepare_read_sync( byte_getter : Any An object supporting ``get_sync`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. chunk_selection : SelectorTuple Selection within the decoded chunk array. codec_chain : SupportsChunkCodec The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to - decode the chunk. + decode the chunk. Must carry an ``array_spec`` attribute. Returns ------- @@ -342,7 +341,7 @@ def prepare_read_sync( The decoded chunk data at *chunk_selection*, or ``None`` if the chunk does not exist in the store. """ - raw = byte_getter.get_sync(prototype=chunk_spec.prototype) + raw = byte_getter.get_sync(prototype=codec_chain.array_spec.prototype) if raw is None: return None chunk_array = codec_chain.decode_chunk(raw) @@ -351,7 +350,7 @@ def prepare_read_sync( def prepare_write_sync( self, byte_setter: Any, - chunk_spec: ArraySpec, + codec_chain: SupportsChunkCodec, chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, @@ -367,8 +366,9 @@ def prepare_write_sync( byte_setter : Any An object supporting ``get_sync`` and ``set_sync`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. chunk_selection : SelectorTuple Selection within the chunk being written. out_selection : SelectorTuple @@ -384,6 +384,7 @@ def prepare_write_sync( A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the deserialized chunk data and selection metadata. """ + chunk_spec = codec_chain.array_spec existing: Buffer | None = None if not replace: existing = byte_setter.get_sync(prototype=chunk_spec.prototype) @@ -403,7 +404,7 @@ def prepare_write_sync( def finalize_write_sync( self, prepared: PreparedWrite, - chunk_spec: ArraySpec, + codec_chain: SupportsChunkCodec, byte_setter: Any, ) -> None: """Serialize the prepared chunk data and write it to the store. @@ -416,13 +417,14 @@ def finalize_write_sync( prepared : PreparedWrite The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by [`prepare_write_sync`][zarr.abc.codec.ArrayBytesCodec.prepare_write_sync]. - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. byte_setter : Any An object supporting ``set_sync`` and ``delete_sync`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). """ - blob = self.serialize(prepared.chunk_dict, chunk_spec) + blob = self.serialize(prepared.chunk_dict, codec_chain.array_spec) if blob is None: byte_setter.delete_sync() else: @@ -435,7 +437,6 @@ def finalize_write_sync( async def prepare_read( self, byte_getter: Any, - chunk_spec: ArraySpec, chunk_selection: SelectorTuple, codec_chain: SupportsChunkCodec, ) -> NDBuffer | None: @@ -447,13 +448,11 @@ async def prepare_read( byte_getter : Any An object supporting ``get`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. chunk_selection : SelectorTuple Selection within the decoded chunk array. codec_chain : SupportsChunkCodec The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] used to - decode the chunk. + decode the chunk. Must carry an ``array_spec`` attribute. Returns ------- @@ -461,7 +460,7 @@ async def prepare_read( The decoded chunk data at *chunk_selection*, or ``None`` if the chunk does not exist in the store. """ - raw = await byte_getter.get(prototype=chunk_spec.prototype) + raw = await byte_getter.get(prototype=codec_chain.array_spec.prototype) if raw is None: return None chunk_array = codec_chain.decode_chunk(raw) @@ -470,7 +469,7 @@ async def prepare_read( async def prepare_write( self, byte_setter: Any, - chunk_spec: ArraySpec, + codec_chain: SupportsChunkCodec, chunk_selection: SelectorTuple, out_selection: SelectorTuple, replace: bool, @@ -483,8 +482,9 @@ async def prepare_write( byte_setter : Any An object supporting ``get`` and ``set`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. chunk_selection : SelectorTuple Selection within the chunk being written. out_selection : SelectorTuple @@ -500,6 +500,7 @@ async def prepare_write( A [`PreparedWrite`][zarr.abc.codec.PreparedWrite] carrying the deserialized chunk data and selection metadata. """ + chunk_spec = codec_chain.array_spec existing: Buffer | None = None if not replace: existing = await byte_setter.get(prototype=chunk_spec.prototype) @@ -519,7 +520,7 @@ async def prepare_write( async def finalize_write( self, prepared: PreparedWrite, - chunk_spec: ArraySpec, + codec_chain: SupportsChunkCodec, byte_setter: Any, ) -> None: """Async variant of @@ -530,13 +531,14 @@ async def finalize_write( prepared : PreparedWrite The [`PreparedWrite`][zarr.abc.codec.PreparedWrite] returned by [`prepare_write`][zarr.abc.codec.ArrayBytesCodec.prepare_write]. - chunk_spec : ArraySpec - The [`ArraySpec`][zarr.core.array_spec.ArraySpec] for the chunk. + codec_chain : SupportsChunkCodec + The [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] + carrying the ``array_spec`` for the chunk. byte_setter : Any An object supporting ``set`` and ``delete`` (e.g. [`StorePath`][zarr.storage._common.StorePath]). """ - blob = self.serialize(prepared.chunk_dict, chunk_spec) + blob = self.serialize(prepared.chunk_dict, codec_chain.array_spec) if blob is None: await byte_setter.delete() else: From b1c245e4ad85a0e027a59f3466fff29256051b0b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 16:37:30 -0500 Subject: [PATCH 12/14] wire chunktransform up to sharding --- src/zarr/abc/codec.py | 16 --- src/zarr/codecs/sharding.py | 176 +++++++++++++------------------- src/zarr/core/codec_pipeline.py | 45 ++++++++ 3 files changed, 116 insertions(+), 121 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index d5d673e1c7..23b9e3a851 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -245,22 +245,6 @@ class PreparedWrite: class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]): """Base class for array-to-bytes codecs.""" - @property - def inner_codec_chain(self) -> SupportsChunkCodec | None: - """The codec chain for decoding inner chunks after deserialization. - - Returns ``None`` by default, meaning the pipeline should use its own - codec chain. ``ShardingCodec`` overrides this to return its inner - codec chain. - - Returns - ------- - SupportsChunkCodec or None - A [`SupportsChunkCodec`][zarr.abc.codec.SupportsChunkCodec] instance, - or ``None``. - """ - return None - def deserialize( self, raw: Buffer | None, chunk_spec: ArraySpec ) -> dict[tuple[int, ...], Buffer | None]: diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 85162c2f74..4c6a19c504 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, MutableMapping +from collections.abc import Iterable, Mapping from dataclasses import dataclass, replace from enum import Enum from functools import lru_cache @@ -15,11 +15,9 @@ ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin, Codec, - CodecPipeline, ) from zarr.abc.store import ( ByteGetter, - ByteRequest, ByteSetter, RangeByteRequest, SuffixByteRequest, @@ -35,6 +33,7 @@ numpy_buffer_prototype, ) from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid +from zarr.core.codec_pipeline import ChunkTransform, fill_value_or_default from zarr.core.common import ( ShapeLike, parse_enum, @@ -54,7 +53,6 @@ ) from zarr.core.metadata.v3 import parse_codecs from zarr.registry import get_ndbuffer_class, get_pipeline_class -from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import Iterator @@ -65,7 +63,6 @@ MAX_UINT_64 = 2**64 - 1 ShardMapping = Mapping[tuple[int, ...], Buffer | None] -ShardMutableMapping = MutableMapping[tuple[int, ...], Buffer | None] class ShardingCodecIndexLocation(Enum): @@ -81,41 +78,6 @@ def parse_index_location(data: object) -> ShardingCodecIndexLocation: return parse_enum(data, ShardingCodecIndexLocation) -@dataclass(frozen=True) -class _ShardingByteGetter(ByteGetter): - shard_dict: ShardMapping - chunk_coords: tuple[int, ...] - - async def get( - self, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> Buffer | None: - assert prototype == default_buffer_prototype(), ( - f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}" - ) - value = self.shard_dict.get(self.chunk_coords) - if value is None: - return None - if byte_range is None: - return value - start, stop = _normalize_byte_range_index(value, byte_range) - return value[start:stop] - - -@dataclass(frozen=True) -class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): - shard_dict: ShardMutableMapping - - async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None: - assert byte_range is None, "byte_range is not supported within shards" - self.shard_dict[self.chunk_coords] = value - - async def delete(self) -> None: - del self.shard_dict[self.chunk_coords] - - async def set_if_not_exists(self, default: Buffer) -> None: - self.shard_dict.setdefault(self.chunk_coords, default) - - class _ShardIndex(NamedTuple): # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2) offsets_and_lengths: npt.NDArray[np.uint64] @@ -354,9 +316,8 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration(data, "sharding_indexed") return cls(**configuration_parsed) # type: ignore[arg-type] - @property - def codec_pipeline(self) -> CodecPipeline: - return get_pipeline_class().from_codecs(self.codecs) + def _get_chunk_transform(self, chunk_spec: ArraySpec) -> ChunkTransform: + return ChunkTransform(codecs=self.codecs, array_spec=chunk_spec) def to_dict(self) -> dict[str, JSON]: return { @@ -430,20 +391,15 @@ async def _decode_single( out.fill(shard_spec.fill_value) return out - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, - ) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + for chunk_coords, chunk_selection, out_selection, _ in indexer: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = await transform.decode_chunk_async(chunk_bytes) + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value return out @@ -502,20 +458,16 @@ async def _decode_partial_single( if chunk_bytes: shard_dict[chunk_coords] = chunk_bytes - # decoding chunks and writing them into the output buffer - await self.codec_pipeline.read( - [ - ( - _ShardingByteGetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - out, - ) + # decode chunks and write them into the output buffer + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + for chunk_coords, chunk_selection, out_selection, _ in indexed_chunks: + chunk_bytes = shard_dict.get(chunk_coords) + if chunk_bytes is not None: + chunk_array = await transform.decode_chunk_async(chunk_bytes) + out[out_selection] = chunk_array[chunk_selection] + else: + out[out_selection] = fill_value if hasattr(indexer, "sel_shape"): return out.reshape(indexer.sel_shape) @@ -532,29 +484,23 @@ async def _encode_single( chunks_per_shard = self._get_chunks_per_shard(shard_spec) chunk_spec = self._get_chunk_spec(shard_spec) - indexer = list( - BasicIndexer( - tuple(slice(0, s) for s in shard_shape), - shape=shard_shape, - chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), - ) + indexer = BasicIndexer( + tuple(slice(0, s) for s in shard_shape), + shape=shard_shape, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), ) - shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard)) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + shard_builder: dict[tuple[int, ...], Buffer | None] = {} - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_builder, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, - ) + for chunk_coords, _, out_selection, _is_complete in indexer: + chunk_array = shard_array[out_selection] + if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(fill_value): + continue + encoded = await transform.encode_chunk_async(chunk_array) + if encoded is not None: + shard_builder[chunk_coords] = encoded return await self._encode_shard_dict( shard_builder, @@ -581,7 +527,9 @@ async def _encode_partial_single( ) shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard))) + shard_dict: dict[tuple[int, ...], Buffer | None] = shard_reader.to_dict_vectorized( + np.asarray(_morton_order(chunks_per_shard)) + ) indexer = list( get_indexer( @@ -589,19 +537,37 @@ async def _encode_partial_single( ) ) - await self.codec_pipeline.write( - [ - ( - _ShardingByteSetter(shard_dict, chunk_coords), - chunk_spec, - chunk_selection, - out_selection, - is_complete_shard, - ) - for chunk_coords, chunk_selection, out_selection, is_complete_shard in indexer - ], - shard_array, - ) + transform = self._get_chunk_transform(chunk_spec) + fill_value = fill_value_or_default(chunk_spec) + + is_scalar = len(shard_array.shape) == 0 + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer: + value = shard_array if is_scalar else shard_array[out_selection] + if is_complete_chunk and not is_scalar and value.shape == chunk_spec.shape: + # Complete overwrite with matching shape — use value directly + chunk_data = value + else: + # Read-modify-write: decode existing or create new, merge data + if is_complete_chunk: + existing_bytes = None + else: + existing_bytes = shard_dict.get(chunk_coords) + if existing_bytes is not None: + chunk_data = (await transform.decode_chunk_async(existing_bytes)).copy() + else: + chunk_data = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=fill_value, + ) + chunk_data[chunk_selection] = value + + if not chunk_spec.config.write_empty_chunks and chunk_data.all_equal(fill_value): + shard_dict[chunk_coords] = None + else: + shard_dict[chunk_coords] = await transform.encode_chunk_async(chunk_data) + buf = await self._encode_shard_dict( shard_dict, chunks_per_shard=chunks_per_shard, diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 0c3cccb1d9..aeeb76f032 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -168,6 +168,51 @@ def encode_chunk( return bb_out # type: ignore[no-any-return] + async def decode_chunk_async( + self, + chunk_bytes: Buffer, + ) -> NDBuffer: + """Decode a single chunk through the full codec chain, asynchronously. + + Needed when the codec chain contains async-only codecs (e.g. nested sharding). + """ + bb_out: Any = chunk_bytes + for bb_codec in reversed(self._bb_codecs): + bb_out = await bb_codec._decode_single(bb_out, self._ab_spec) + + ab_out: Any = await self._ab_codec._decode_single(bb_out, self._ab_spec) + + for aa_codec, spec in reversed(self.layers): + ab_out = await aa_codec._decode_single(ab_out, spec) + + return ab_out # type: ignore[no-any-return] + + async def encode_chunk_async( + self, + chunk_array: NDBuffer, + ) -> Buffer | None: + """Encode a single chunk through the full codec chain, asynchronously. + + Needed when the codec chain contains async-only codecs (e.g. nested sharding). + """ + aa_out: Any = chunk_array + + for aa_codec, spec in self.layers: + if aa_out is None: + return None + aa_out = await aa_codec._encode_single(aa_out, spec) + + if aa_out is None: + return None + bb_out: Any = await self._ab_codec._encode_single(aa_out, self._ab_spec) + + for bb_codec in self._bb_codecs: + if bb_out is None: + return None + bb_out = await bb_codec._encode_single(bb_out, self._ab_spec) + + return bb_out # type: ignore[no-any-return] + def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: for codec in self.codecs: byte_length = codec.compute_encoded_size(byte_length, array_spec) From 56dbe612ce80044d9f61d0e01aeaefcf72028b8e Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 18:45:20 -0500 Subject: [PATCH 13/14] refactor is_complete_chunk usage, add chunkrequest --- src/zarr/abc/codec.py | 25 +++---- src/zarr/codecs/sharding.py | 10 +-- src/zarr/core/array.py | 25 +++---- src/zarr/core/codec_pipeline.py | 119 +++++++++++++++----------------- src/zarr/core/indexing.py | 65 +++++++++++------ tests/test_config.py | 7 +- 6 files changed, 133 insertions(+), 118 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 23b9e3a851..50af59166c 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -19,6 +19,7 @@ from zarr.abc.store import ByteGetter, ByteSetter, Store from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid + from zarr.core.codec_pipeline import ChunkRequest from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType from zarr.core.indexing import ChunkProjection, SelectorTuple from zarr.core.metadata import ArrayMetadata @@ -751,7 +752,7 @@ async def encode( @abstractmethod async def read( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -760,12 +761,10 @@ async def read( Parameters ---------- - batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]] - Ordered set of information about the chunks. - The first slice selection determines which parts of the chunk will be fetched. - The second slice selection determines where in the output array the chunk data will be written. - The ByteGetter is used to fetch the necessary bytes. - The chunk spec contains information about the construction of an array from the bytes. + batch_info : Iterable[ChunkRequest] + Ordered set of chunk requests. Each ``ChunkRequest`` carries the + store path (``byte_setter``), the ``ArraySpec`` for that chunk, + chunk and output selections, and whether the chunk is complete. If the Store returns ``None`` for a chunk, then the chunk was not written and the implementation must set the values of that chunk (or @@ -778,7 +777,7 @@ async def read( @abstractmethod async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -788,12 +787,10 @@ async def write( Parameters ---------- - batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]] - Ordered set of information about the chunks. - The first slice selection determines which parts of the chunk will be encoded. - The second slice selection determines where in the value array the chunk data is located. - The ByteSetter is used to fetch and write the necessary bytes. - The chunk spec contains information about the chunk. + batch_info : Iterable[ChunkRequest] + Ordered set of chunk requests. Each ``ChunkRequest`` carries the + store path (``byte_setter``), the ``ArraySpec`` for that chunk, + chunk and output selections, and whether the chunk is complete. value : NDBuffer """ ... diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 4c6a19c504..e8537a187f 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -393,7 +393,7 @@ async def _decode_single( transform = self._get_chunk_transform(chunk_spec) fill_value = fill_value_or_default(chunk_spec) - for chunk_coords, chunk_selection, out_selection, _ in indexer: + for chunk_coords, chunk_selection, out_selection, _is_complete in indexer: chunk_bytes = shard_dict.get(chunk_coords) if chunk_bytes is not None: chunk_array = await transform.decode_chunk_async(chunk_bytes) @@ -461,7 +461,7 @@ async def _decode_partial_single( # decode chunks and write them into the output buffer transform = self._get_chunk_transform(chunk_spec) fill_value = fill_value_or_default(chunk_spec) - for chunk_coords, chunk_selection, out_selection, _ in indexed_chunks: + for chunk_coords, chunk_selection, out_selection, _is_complete in indexed_chunks: chunk_bytes = shard_dict.get(chunk_coords) if chunk_bytes is not None: chunk_array = await transform.decode_chunk_async(chunk_bytes) @@ -541,14 +541,14 @@ async def _encode_partial_single( fill_value = fill_value_or_default(chunk_spec) is_scalar = len(shard_array.shape) == 0 - for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer: + for chunk_coords, chunk_selection, out_selection, is_complete in indexer: value = shard_array if is_scalar else shard_array[out_selection] - if is_complete_chunk and not is_scalar and value.shape == chunk_spec.shape: + if is_complete and not is_scalar and value.shape == chunk_spec.shape: # Complete overwrite with matching shape — use value directly chunk_data = value else: # Read-modify-write: decode existing or create new, merge data - if is_complete_chunk: + if is_complete: existing_bytes = None else: existing_bytes = shard_dict.get(chunk_coords) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 564d0e915a..d1107d2ee8 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -48,6 +48,7 @@ V2ChunkKeyEncoding, parse_chunk_key_encoding, ) +from zarr.core.codec_pipeline import ChunkRequest from zarr.core.common import ( JSON, ZARR_JSON, @@ -5602,12 +5603,12 @@ async def _get_selection( # reading chunks and decoding them await codec_pipeline.read( [ - ( - store_path / metadata.encode_chunk_key(chunk_coords), - metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), - chunk_selection, - out_selection, - is_complete_chunk, + ChunkRequest( + byte_setter=store_path / metadata.encode_chunk_key(chunk_coords), + chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_chunk, ) for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer ], @@ -5912,12 +5913,12 @@ async def _set_selection( # merging with existing data and encoding chunks await codec_pipeline.write( [ - ( - store_path / metadata.encode_chunk_key(chunk_coords), - metadata.get_chunk_spec(chunk_coords, _config, prototype), - chunk_selection, - out_selection, - is_complete_chunk, + ChunkRequest( + byte_setter=store_path / metadata.encode_chunk_key(chunk_coords), + chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype), + chunk_selection=chunk_selection, + out_selection=out_selection, + is_complete_chunk=is_complete_chunk, ) for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer ], diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index aeeb76f032..64fcc1ecf5 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -220,6 +220,20 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: return byte_length +@dataclass(slots=True) +class ChunkRequest: + """A single chunk's worth of metadata for a pipeline read or write. + + Replaces the anonymous 5-tuples formerly threaded through ``batch_info``. + """ + + byte_setter: ByteSetter + chunk_spec: ArraySpec + chunk_selection: SelectorTuple + out_selection: SelectorTuple + is_complete_chunk: bool + + @dataclass(frozen=True) class BatchedCodecPipeline(CodecPipeline): """Default codec pipeline. @@ -400,48 +414,40 @@ async def encode_partial_batch( async def read_batch( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info = list(batch_info) if self.supports_partial_decode: chunk_array_batch = await self.decode_partial_batch( - [ - (byte_getter, chunk_selection, chunk_spec) - for byte_getter, chunk_spec, chunk_selection, *_ in batch_info - ] + [(req.byte_setter, req.chunk_selection, req.chunk_spec) for req in batch_info] ) - for chunk_array, (_, chunk_spec, _, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: - out[out_selection] = chunk_array + out[req.out_selection] = chunk_array else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[req.out_selection] = fill_value_or_default(req.chunk_spec) else: chunk_bytes_batch = await concurrent_map( - [(byte_getter, chunk_spec.prototype) for byte_getter, chunk_spec, *_ in batch_info], + [(req.byte_setter, req.chunk_spec.prototype) for req in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), config.get("async.concurrency"), ) chunk_array_batch = await self.decode_batch( [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (chunk_bytes, req.chunk_spec) + for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) ], ) - for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip( - chunk_array_batch, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: - tmp = chunk_array[chunk_selection] + tmp = chunk_array[req.chunk_selection] if drop_axes != (): tmp = tmp.squeeze(axis=drop_axes) - out[out_selection] = tmp + out[req.out_selection] = tmp else: - out[out_selection] = fill_value_or_default(chunk_spec) + out[req.out_selection] = fill_value_or_default(req.chunk_spec) def _merge_chunk_array( self, @@ -450,13 +456,11 @@ def _merge_chunk_array( out_selection: SelectorTuple, chunk_spec: ArraySpec, chunk_selection: SelectorTuple, - is_complete_chunk: bool, drop_axes: tuple[int, ...], ) -> NDBuffer: if ( - is_complete_chunk + existing_chunk_array is None and value.shape == chunk_spec.shape - # Guard that this is not a partial chunk at the end with is_complete_chunk=True and value[out_selection].shape == chunk_spec.shape ): return value @@ -489,24 +493,30 @@ def _merge_chunk_array( async def write_batch( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: + batch_info = list(batch_info) if self.supports_partial_encode: # Pass scalar values as is if len(value.shape) == 0: await self.encode_partial_batch( [ - (byte_setter, value, chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + (req.byte_setter, value, req.chunk_selection, req.chunk_spec) + for req in batch_info ], ) else: await self.encode_partial_batch( [ - (byte_setter, value[out_selection], chunk_selection, chunk_spec) - for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info + ( + req.byte_setter, + value[req.out_selection], + req.chunk_selection, + req.chunk_spec, + ) + for req in batch_info ], ) @@ -523,20 +533,18 @@ async def _read_key( chunk_bytes_batch = await concurrent_map( [ ( - None if is_complete_chunk else byte_setter, - chunk_spec.prototype, + None if req.is_complete_chunk else req.byte_setter, + req.chunk_spec.prototype, ) - for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info + for req in batch_info ], _read_key, config.get("async.concurrency"), ) chunk_array_decoded = await self.decode_batch( [ - (chunk_bytes, chunk_spec) - for chunk_bytes, (_, chunk_spec, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (chunk_bytes, req.chunk_spec) + for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) ], ) @@ -544,29 +552,20 @@ async def _read_key( self._merge_chunk_array( chunk_array, value, - out_selection, - chunk_spec, - chunk_selection, - is_complete_chunk, + req.out_selection, + req.chunk_spec, + req.chunk_selection, drop_axes, ) - for chunk_array, ( - _, - chunk_spec, - chunk_selection, - out_selection, - is_complete_chunk, - ) in zip(chunk_array_decoded, batch_info, strict=False) + for chunk_array, req in zip(chunk_array_decoded, batch_info, strict=False) ] chunk_array_batch: list[NDBuffer | None] = [] - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_merged, batch_info, strict=False - ): + for chunk_array, req in zip(chunk_array_merged, batch_info, strict=False): if chunk_array is None: chunk_array_batch.append(None) # type: ignore[unreachable] else: - if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( - fill_value_or_default(chunk_spec) + if not req.chunk_spec.config.write_empty_chunks and chunk_array.all_equal( + fill_value_or_default(req.chunk_spec) ): chunk_array_batch.append(None) else: @@ -574,10 +573,8 @@ async def _read_key( chunk_bytes_batch = await self.encode_batch( [ - (chunk_array, chunk_spec) - for chunk_array, (_, chunk_spec, *_) in zip( - chunk_array_batch, batch_info, strict=False - ) + (chunk_array, req.chunk_spec) + for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False) ], ) @@ -589,10 +586,8 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non await concurrent_map( [ - (byte_setter, chunk_bytes) - for chunk_bytes, (byte_setter, *_) in zip( - chunk_bytes_batch, batch_info, strict=False - ) + (req.byte_setter, chunk_bytes) + for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) ], _write_key, config.get("async.concurrency"), @@ -618,7 +613,7 @@ async def encode( async def read( self, - batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -633,7 +628,7 @@ async def read( async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 73fd53087d..27e05740d9 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -373,7 +373,6 @@ class ChunkDimProjection(NamedTuple): dim_chunk_ix: int dim_chunk_sel: Selector dim_out_sel: Selector | None - is_complete_chunk: bool @dataclass(frozen=True) @@ -393,8 +392,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: dim_offset = dim_chunk_ix * self.dim_chunk_len dim_chunk_sel = self.dim_sel - dim_offset dim_out_sel = None - is_complete_chunk = self.dim_chunk_len == 1 - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) @dataclass(frozen=True) @@ -468,10 +466,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: dim_out_sel = slice(dim_out_offset, dim_out_offset + dim_chunk_nitems) - is_complete_chunk = ( - dim_chunk_sel_start == 0 and (self.stop >= dim_limit) and self.step in [1, None] - ) - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) def check_selection_length(selection: SelectionNormalized, shape: tuple[int, ...]) -> None: @@ -531,6 +526,35 @@ def ensure_tuple(v: Any) -> SelectionNormalized: return cast("SelectionNormalized", v) +def _is_complete_chunk( + dim_indexers: Sequence[ + IntDimIndexer | SliceDimIndexer | BoolArrayDimIndexer | IntArrayDimIndexer + ], + dim_projections: tuple[ChunkDimProjection, ...], +) -> bool: + """Return True if the combined chunk selection covers the full actual extent. + + The actual extent of each chunk dimension accounts for edge chunks + (where the array length is not a multiple of the chunk length). + """ + for dim_indexer, p in zip(dim_indexers, dim_projections, strict=True): + chunk_ix = p.dim_chunk_ix + actual_extent = ( + min(dim_indexer.dim_len, (chunk_ix + 1) * dim_indexer.dim_chunk_len) + - chunk_ix * dim_indexer.dim_chunk_len + ) + s = p.dim_chunk_sel + if isinstance(s, slice): + if not (s.start in (0, None) and s.stop >= actual_extent and s.step in (1, None)): + return False + elif isinstance(dim_indexer, IntDimIndexer): + if actual_extent != 1: + return False + else: + return False + return True + + class ChunkProjection(NamedTuple): """A mapping of items from chunk to output array. Can be used to extract items from the chunk array for loading into an output array. Can also be used to extract items from a @@ -544,8 +568,8 @@ class ChunkProjection(NamedTuple): Selection of items from chunk array. out_selection Selection of items in target (output) array. - is_complete_chunk: - True if a complete chunk is indexed + is_complete_chunk + True if the selection covers the full actual extent of the chunk. """ chunk_coords: tuple[int, ...] @@ -627,8 +651,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: out_selection = tuple( p.dim_out_sel for p in dim_projections if p.dim_out_sel is not None ) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -696,9 +720,8 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: start = self.chunk_nitems_cumsum[dim_chunk_ix - 1] stop = self.chunk_nitems_cumsum[dim_chunk_ix] dim_out_sel = slice(start, stop) - is_complete_chunk = False # TODO - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) class Order(Enum): @@ -838,8 +861,7 @@ def __iter__(self) -> Iterator[ChunkDimProjection]: # find region in chunk dim_offset = dim_chunk_ix * self.dim_chunk_len dim_chunk_sel = self.dim_sel[start:stop] - dim_offset - is_complete_chunk = False # TODO - yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel, is_complete_chunk) + yield ChunkDimProjection(dim_chunk_ix, dim_chunk_sel, dim_out_sel) def slice_to_range(s: slice, length: int) -> range: @@ -976,8 +998,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: if not is_basic_selection(out_selection): out_selection = ix_(out_selection, self.shape) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -1106,8 +1128,8 @@ def __iter__(self) -> Iterator[ChunkProjection]: out_selection = tuple( p.dim_out_sel for p in dim_projections if p.dim_out_sel is not None ) - is_complete_chunk = all(p.is_complete_chunk for p in dim_projections) - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + is_complete = _is_complete_chunk(self.dim_indexers, dim_projections) + yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete) @dataclass(frozen=True) @@ -1274,8 +1296,9 @@ def __iter__(self) -> Iterator[ChunkProjection]: for (dim_sel, dim_chunk_offset) in zip(self.selection, chunk_offsets, strict=True) ) - is_complete_chunk = False # TODO - yield ChunkProjection(chunk_coords, chunk_selection, out_selection, is_complete_chunk) + yield ChunkProjection( + chunk_coords, chunk_selection, out_selection, is_complete_chunk=False + ) @dataclass(frozen=True) diff --git a/tests/test_config.py b/tests/test_config.py index c3102e8efe..b0084b6a63 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,7 +10,7 @@ import zarr from zarr import zeros from zarr.abc.codec import CodecPipeline -from zarr.abc.store import ByteSetter, Store +from zarr.abc.store import Store from zarr.codecs import ( BloscCodec, BytesCodec, @@ -20,9 +20,8 @@ from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.buffer.core import Buffer -from zarr.core.codec_pipeline import BatchedCodecPipeline +from zarr.core.codec_pipeline import BatchedCodecPipeline, ChunkRequest from zarr.core.config import BadConfigError, config -from zarr.core.indexing import SelectorTuple from zarr.errors import ZarrUserWarning from zarr.registry import ( fully_qualified_name, @@ -140,7 +139,7 @@ def test_config_codec_pipeline_class(store: Store) -> None: class MockCodecPipeline(BatchedCodecPipeline): async def write( self, - batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]], + batch_info: Iterable[ChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: From 9473b2af6d678d5f5e9f296df3ea4f5c0238c976 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 27 Feb 2026 23:25:39 -0500 Subject: [PATCH 14/14] add sync paths for codecpipeline --- src/zarr/abc/codec.py | 59 ++++- src/zarr/core/array.py | 19 +- src/zarr/core/codec_pipeline.py | 387 +++++++++++++++++++++++++++--- src/zarr/core/config.py | 5 +- tests/test_config.py | 9 +- tests/test_sync_codec_pipeline.py | 292 +++++++++++++++++++++- 6 files changed, 711 insertions(+), 60 deletions(-) diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py index 50af59166c..1ae1162a31 100644 --- a/src/zarr/abc/codec.py +++ b/src/zarr/abc/codec.py @@ -19,7 +19,7 @@ from zarr.abc.store import ByteGetter, ByteSetter, Store from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import ChunkGrid - from zarr.core.codec_pipeline import ChunkRequest + from zarr.core.codec_pipeline import ChunkTransform, ReadChunkRequest, WriteChunkRequest from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType from zarr.core.indexing import ChunkProjection, SelectorTuple from zarr.core.metadata import ArrayMetadata @@ -711,6 +711,20 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: """ ... + @abstractmethod + def get_chunk_transform(self, array_spec: ArraySpec) -> ChunkTransform: + """Creates a ChunkTransform for the given array spec. + + Parameters + ---------- + array_spec : ArraySpec + + Returns + ------- + ChunkTransform + """ + ... + @abstractmethod async def decode( self, @@ -752,7 +766,7 @@ async def encode( @abstractmethod async def read( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -761,10 +775,10 @@ async def read( Parameters ---------- - batch_info : Iterable[ChunkRequest] - Ordered set of chunk requests. Each ``ChunkRequest`` carries the - store path (``byte_setter``), the ``ArraySpec`` for that chunk, - chunk and output selections, and whether the chunk is complete. + batch_info : Iterable[ReadChunkRequest] + Ordered set of read requests. Each carries a ``byte_getter``, + a ``ChunkTransform`` (codec chain + spec), and chunk/output + selections. If the Store returns ``None`` for a chunk, then the chunk was not written and the implementation must set the values of that chunk (or @@ -777,7 +791,7 @@ async def read( @abstractmethod async def write( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -787,14 +801,37 @@ async def write( Parameters ---------- - batch_info : Iterable[ChunkRequest] - Ordered set of chunk requests. Each ``ChunkRequest`` carries the - store path (``byte_setter``), the ``ArraySpec`` for that chunk, - chunk and output selections, and whether the chunk is complete. + batch_info : Iterable[WriteChunkRequest] + Ordered set of write requests. Each carries a ``byte_setter``, + a ``ChunkTransform`` (codec chain + spec), chunk/output + selections, and whether the chunk is complete. value : NDBuffer """ ... + @property + def supports_sync_io(self) -> bool: + """Whether this pipeline can run read/write entirely on the calling thread.""" + return False + + def read_sync( + self, + batch_info: Iterable[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous read: fetch bytes from store, decode, scatter into *out*.""" + raise NotImplementedError + + def write_sync( + self, + batch_info: Iterable[WriteChunkRequest], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + """Synchronous write: gather from *value*, encode, persist to store.""" + raise NotImplementedError + async def _batching_helper( func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]], diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index d1107d2ee8..479416277d 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -48,7 +48,7 @@ V2ChunkKeyEncoding, parse_chunk_key_encoding, ) -from zarr.core.codec_pipeline import ChunkRequest +from zarr.core.codec_pipeline import ReadChunkRequest, WriteChunkRequest from zarr.core.common import ( JSON, ZARR_JSON, @@ -5603,14 +5603,15 @@ async def _get_selection( # reading chunks and decoding them await codec_pipeline.read( [ - ChunkRequest( - byte_setter=store_path / metadata.encode_chunk_key(chunk_coords), - chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype), + ReadChunkRequest( + byte_getter=store_path / metadata.encode_chunk_key(chunk_coords), + transform=codec_pipeline.get_chunk_transform( + metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype) + ), chunk_selection=chunk_selection, out_selection=out_selection, - is_complete_chunk=is_complete_chunk, ) - for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer + for chunk_coords, chunk_selection, out_selection, _is_complete in indexer ], out_buffer, drop_axes=indexer.drop_axes, @@ -5913,9 +5914,11 @@ async def _set_selection( # merging with existing data and encoding chunks await codec_pipeline.write( [ - ChunkRequest( + WriteChunkRequest( byte_setter=store_path / metadata.encode_chunk_key(chunk_coords), - chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype), + transform=codec_pipeline.get_chunk_transform( + metadata.get_chunk_spec(chunk_coords, _config, prototype) + ), chunk_selection=chunk_selection, out_selection=out_selection, is_complete_chunk=is_complete_chunk, diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py index 64fcc1ecf5..b2b39e2963 100644 --- a/src/zarr/core/codec_pipeline.py +++ b/src/zarr/core/codec_pipeline.py @@ -1,5 +1,8 @@ from __future__ import annotations +import os +import threading +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from itertools import islice, pairwise from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -15,14 +18,14 @@ CodecPipeline, SupportsSyncCodec, ) -from zarr.core.common import concurrent_map +from zarr.core.common import concurrent_map, product from zarr.core.config import config from zarr.core.indexing import SelectorTuple, is_scalar from zarr.errors import ZarrUserWarning from zarr.registry import register_pipeline if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Callable, Iterable, Iterator from typing import Self from zarr.abc.store import ByteGetter, ByteSetter @@ -69,6 +72,77 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any: return fill_value +# --------------------------------------------------------------------------- +# Thread-pool infrastructure for synchronous codec paths +# --------------------------------------------------------------------------- + +_MIN_CHUNK_NBYTES_FOR_POOL = 100_000 # 100 KB + + +def _get_codec_worker_config() -> tuple[bool, int, int]: + """Read the ``threading.codec_workers`` config. + + Returns + ------- + tuple[bool, int, int] + ``(enabled, min_workers, max_workers)`` + """ + codec_workers = config.get("threading.codec_workers") + enabled: bool = codec_workers.get("enabled", True) + min_workers: int = codec_workers.get("min", 0) + max_workers: int = max(codec_workers.get("max") or os.cpu_count() or 4, min_workers) + return enabled, min_workers, max_workers + + +def _choose_workers(n_chunks: int, chunk_nbytes: int, codecs: Iterable[Codec]) -> int: + """Decide how many thread-pool workers to use (0 = don't use pool).""" + if getattr(_thread_local, "in_pool_worker", False): + return 0 + + enabled, min_workers, max_workers = _get_codec_worker_config() + if not enabled: + return 0 + + if n_chunks < 2: + return min_workers + + if not any(isinstance(c, BytesBytesCodec) for c in codecs) and min_workers == 0: + return 0 + if chunk_nbytes < _MIN_CHUNK_NBYTES_FOR_POOL and min_workers == 0: + return 0 + + return max(min_workers, min(n_chunks, max_workers)) + + +def _get_pool() -> ThreadPoolExecutor: + """Get the module-level thread pool, creating it lazily.""" + global _pool + if _pool is None: + _, _, max_workers = _get_codec_worker_config() + _pool = ThreadPoolExecutor(max_workers=max_workers) + return _pool + + +_pool: ThreadPoolExecutor | None = None +_thread_local = threading.local() + + +def _mark_pool_worker(fn: Callable[..., T]) -> Callable[..., T]: + """Wrap *fn* so that ``_thread_local.in_pool_worker`` is ``True`` while it runs.""" + + def wrapper(*args: Any, **kwargs: Any) -> T: + _thread_local.in_pool_worker = True + try: + return fn(*args, **kwargs) + finally: + _thread_local.in_pool_worker = False + + return wrapper + + +_DELETED = object() + + @dataclass(slots=True, kw_only=True) class ChunkTransform: """A stored chunk, modeled as a layered array. @@ -221,14 +295,21 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int: @dataclass(slots=True) -class ChunkRequest: - """A single chunk's worth of metadata for a pipeline read or write. +class ReadChunkRequest: + """A request to read and decode a single chunk.""" + + byte_getter: ByteGetter + transform: ChunkTransform + chunk_selection: SelectorTuple + out_selection: SelectorTuple - Replaces the anonymous 5-tuples formerly threaded through ``batch_info``. - """ + +@dataclass(slots=True) +class WriteChunkRequest: + """A request to encode and write a single chunk.""" byte_setter: ByteSetter - chunk_spec: ArraySpec + transform: ChunkTransform chunk_selection: SelectorTuple out_selection: SelectorTuple is_complete_chunk: bool @@ -299,6 +380,9 @@ def __iter__(self) -> Iterator[Codec]: yield self.array_bytes_codec yield from self.bytes_bytes_codecs + def get_chunk_transform(self, array_spec: ArraySpec) -> ChunkTransform: + return ChunkTransform(codecs=tuple(self), array_spec=array_spec) + def validate( self, *, @@ -414,31 +498,41 @@ async def encode_partial_batch( async def read_batch( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: batch_info = list(batch_info) if self.supports_partial_decode: chunk_array_batch = await self.decode_partial_batch( - [(req.byte_setter, req.chunk_selection, req.chunk_spec) for req in batch_info] + [ + (req.byte_getter, req.chunk_selection, req.transform.array_spec) + for req in batch_info + ] ) for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: out[req.out_selection] = chunk_array else: - out[req.out_selection] = fill_value_or_default(req.chunk_spec) + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) else: chunk_bytes_batch = await concurrent_map( - [(req.byte_setter, req.chunk_spec.prototype) for req in batch_info], + [(req.byte_getter, req.transform.array_spec.prototype) for req in batch_info], lambda byte_getter, prototype: byte_getter.get(prototype), config.get("async.concurrency"), ) - chunk_array_batch = await self.decode_batch( - [ - (chunk_bytes, req.chunk_spec) - for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) - ], + + async def _decode_one( + chunk_bytes: Buffer | None, req: ReadChunkRequest + ) -> NDBuffer | None: + if chunk_bytes is None: + return None + return await req.transform.decode_chunk_async(chunk_bytes) + + chunk_array_batch = await concurrent_map( + list(zip(chunk_bytes_batch, batch_info, strict=False)), + _decode_one, + config.get("async.concurrency"), ) for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False): if chunk_array is not None: @@ -447,7 +541,7 @@ async def read_batch( tmp = tmp.squeeze(axis=drop_axes) out[req.out_selection] = tmp else: - out[req.out_selection] = fill_value_or_default(req.chunk_spec) + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) def _merge_chunk_array( self, @@ -493,7 +587,7 @@ def _merge_chunk_array( async def write_batch( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -503,7 +597,12 @@ async def write_batch( if len(value.shape) == 0: await self.encode_partial_batch( [ - (req.byte_setter, value, req.chunk_selection, req.chunk_spec) + ( + req.byte_setter, + value, + req.chunk_selection, + req.transform.array_spec, + ) for req in batch_info ], ) @@ -514,7 +613,7 @@ async def write_batch( req.byte_setter, value[req.out_selection], req.chunk_selection, - req.chunk_spec, + req.transform.array_spec, ) for req in batch_info ], @@ -534,18 +633,25 @@ async def _read_key( [ ( None if req.is_complete_chunk else req.byte_setter, - req.chunk_spec.prototype, + req.transform.array_spec.prototype, ) for req in batch_info ], _read_key, config.get("async.concurrency"), ) - chunk_array_decoded = await self.decode_batch( - [ - (chunk_bytes, req.chunk_spec) - for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False) - ], + + async def _decode_one( + chunk_bytes: Buffer | None, req: WriteChunkRequest + ) -> NDBuffer | None: + if chunk_bytes is None: + return None + return await req.transform.decode_chunk_async(chunk_bytes) + + chunk_array_decoded = await concurrent_map( + list(zip(chunk_bytes_batch, batch_info, strict=False)), + _decode_one, + config.get("async.concurrency"), ) chunk_array_merged = [ @@ -553,7 +659,7 @@ async def _read_key( chunk_array, value, req.out_selection, - req.chunk_spec, + req.transform.array_spec, req.chunk_selection, drop_axes, ) @@ -561,21 +667,28 @@ async def _read_key( ] chunk_array_batch: list[NDBuffer | None] = [] for chunk_array, req in zip(chunk_array_merged, batch_info, strict=False): + chunk_spec = req.transform.array_spec if chunk_array is None: chunk_array_batch.append(None) # type: ignore[unreachable] else: - if not req.chunk_spec.config.write_empty_chunks and chunk_array.all_equal( - fill_value_or_default(req.chunk_spec) + if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal( + fill_value_or_default(chunk_spec) ): chunk_array_batch.append(None) else: chunk_array_batch.append(chunk_array) - chunk_bytes_batch = await self.encode_batch( - [ - (chunk_array, req.chunk_spec) - for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False) - ], + async def _encode_one( + chunk_array: NDBuffer | None, req: WriteChunkRequest + ) -> Buffer | None: + if chunk_array is None: + return None + return await req.transform.encode_chunk_async(chunk_array) + + chunk_bytes_batch = await concurrent_map( + list(zip(chunk_array_batch, batch_info, strict=False)), + _encode_one, + config.get("async.concurrency"), ) async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None: @@ -613,7 +726,7 @@ async def encode( async def read( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[ReadChunkRequest], out: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -628,7 +741,7 @@ async def read( async def write( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: @@ -641,6 +754,208 @@ async def write( config.get("async.concurrency"), ) + # ------------------------------------------------------------------ + # Synchronous read / write + # ------------------------------------------------------------------ + + @property + def supports_sync_io(self) -> bool: + return all(isinstance(c, SupportsSyncCodec) for c in self) + + def _scatter( + self, + chunk_arrays: list[NDBuffer | None], + batch_info_list: list[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...], + ) -> None: + """Assign decoded chunk arrays into the output buffer.""" + for chunk_array, req in zip(chunk_arrays, batch_info_list, strict=False): + if chunk_array is not None: + tmp = chunk_array[req.chunk_selection] + if drop_axes != (): + tmp = tmp.squeeze(axis=drop_axes) + out[req.out_selection] = tmp + else: + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) + + def read_sync( + self, + batch_info: Iterable[ReadChunkRequest], + out: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + batch_info_list = list(batch_info) + if not batch_info_list: + return + + first_spec = batch_info_list[0].transform.array_spec + chunk_nbytes = product(first_spec.shape) * first_spec.dtype.to_native_dtype().itemsize + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + + if n_workers > 0: + # Threaded: fetch all bytes, decode in parallel, scatter. + chunk_bytes_list: list[Buffer | None] = [ + req.byte_getter.get_sync(prototype=req.transform.array_spec.prototype) # type: ignore[attr-defined] + for req in batch_info_list + ] + pool = _get_pool() + chunk_arrays: list[NDBuffer | None] = list( + pool.map( + _mark_pool_worker( + lambda cb, req: req.transform.decode_chunk(cb) if cb is not None else None + ), + chunk_bytes_list, + batch_info_list, + ) + ) + self._scatter(chunk_arrays, batch_info_list, out, drop_axes) + else: + # Non-threaded: prepare_read_sync handles IO + decode per chunk. + ab_codec = self.array_bytes_codec + for req in batch_info_list: + result = ab_codec.prepare_read_sync( + req.byte_getter, + req.chunk_selection, + req.transform, + ) + if result is not None: + if drop_axes != (): + result = result.squeeze(axis=drop_axes) + out[req.out_selection] = result + else: + out[req.out_selection] = fill_value_or_default(req.transform.array_spec) + + @staticmethod + def _write_chunk_compute( + existing_bytes: Buffer | None, + req: WriteChunkRequest, + value: NDBuffer, + drop_axes: tuple[int, ...], + ) -> Buffer | None | object: + """Per-chunk compute for the threaded write path. + + Returns encoded bytes, ``None`` for a no-op, or ``_DELETED`` + to signal that the key should be removed. + """ + chunk_spec = req.transform.array_spec + existing_chunk_array: NDBuffer | None = None + if existing_bytes is not None: + existing_chunk_array = req.transform.decode_chunk(existing_bytes) + + # Merge + if ( + existing_chunk_array is None + and value.shape == chunk_spec.shape + and value[req.out_selection].shape == chunk_spec.shape + ): + merged = value + else: + if existing_chunk_array is None: + chunk_array = chunk_spec.prototype.nd_buffer.create( + shape=chunk_spec.shape, + dtype=chunk_spec.dtype.to_native_dtype(), + order=chunk_spec.order, + fill_value=fill_value_or_default(chunk_spec), + ) + else: + chunk_array = existing_chunk_array.copy() + + if req.chunk_selection == () or is_scalar( + value.as_ndarray_like(), chunk_spec.dtype.to_native_dtype() + ): + chunk_value = value + else: + chunk_value = value[req.out_selection] + if drop_axes != (): + item = tuple( + None if idx in drop_axes else slice(None) for idx in range(chunk_spec.ndim) + ) + chunk_value = chunk_value[item] + chunk_array[req.chunk_selection] = chunk_value + merged = chunk_array + + # Check write_empty_chunks + if not chunk_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(chunk_spec) + ): + return _DELETED + + return req.transform.encode_chunk(merged) + + def write_sync( + self, + batch_info: Iterable[WriteChunkRequest], + value: NDBuffer, + drop_axes: tuple[int, ...] = (), + ) -> None: + batch_info_list = list(batch_info) + if not batch_info_list: + return + + first_spec = batch_info_list[0].transform.array_spec + chunk_nbytes = product(first_spec.shape) * first_spec.dtype.to_native_dtype().itemsize + n_workers = _choose_workers(len(batch_info_list), chunk_nbytes, self) + + if n_workers > 0: + # Threaded: fetch existing, compute in parallel, write results. + existing_bytes_list: list[Buffer | None] = [ + req.byte_setter.get_sync(prototype=req.transform.array_spec.prototype) # type: ignore[attr-defined] + if not req.is_complete_chunk + else None + for req in batch_info_list + ] + pool = _get_pool() + n = len(batch_info_list) + encoded_list: list[Buffer | None | object] = list( + pool.map( + _mark_pool_worker(self._write_chunk_compute), + existing_bytes_list, + batch_info_list, + [value] * n, + [drop_axes] * n, + ) + ) + for encoded, req in zip(encoded_list, batch_info_list, strict=False): + if encoded is _DELETED: + req.byte_setter.delete_sync() # type: ignore[attr-defined] + elif encoded is not None: + req.byte_setter.set_sync(encoded) # type: ignore[attr-defined] + else: + # Non-threaded: prepare/compute/finalize per chunk. + ab_codec = self.array_bytes_codec + for req in batch_info_list: + prepared = ab_codec.prepare_write_sync( + req.byte_setter, + req.transform, + req.chunk_selection, + req.out_selection, + req.is_complete_chunk, + ) + for coords, chunk_sel, out_sel, _is_complete in prepared.indexer: + existing_inner = prepared.chunk_dict.get(coords) + if existing_inner is not None: + existing_array = req.transform.decode_chunk(existing_inner) + else: + existing_array = None + merged = self._merge_chunk_array( + existing_array, + value, + out_sel, + req.transform.array_spec, + chunk_sel, + drop_axes, + ) + inner_spec = req.transform.array_spec + if not inner_spec.config.write_empty_chunks and merged.all_equal( + fill_value_or_default(inner_spec) + ): + prepared.chunk_dict[coords] = None + else: + prepared.chunk_dict[coords] = req.transform.encode_chunk(merged) + + ab_codec.finalize_write_sync(prepared, req.transform, req.byte_setter) + def codecs_from_list( codecs: Iterable[Codec], diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index f8f8ea4f5f..e7d9e7ec20 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -99,7 +99,10 @@ def enable_gpu(self) -> ConfigSet: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", diff --git a/tests/test_config.py b/tests/test_config.py index b0084b6a63..7a1c9f3f4e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -20,7 +20,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.buffer import NDBuffer from zarr.core.buffer.core import Buffer -from zarr.core.codec_pipeline import BatchedCodecPipeline, ChunkRequest +from zarr.core.codec_pipeline import BatchedCodecPipeline, WriteChunkRequest from zarr.core.config import BadConfigError, config from zarr.errors import ZarrUserWarning from zarr.registry import ( @@ -55,7 +55,10 @@ def test_config_defaults_set() -> None: "target_shard_size_bytes": None, }, "async": {"concurrency": 10, "timeout": None}, - "threading": {"max_workers": None}, + "threading": { + "max_workers": None, + "codec_workers": {"enabled": True, "min": 0, "max": None}, + }, "json_indent": 2, "codec_pipeline": { "path": "zarr.core.codec_pipeline.BatchedCodecPipeline", @@ -139,7 +142,7 @@ def test_config_codec_pipeline_class(store: Store) -> None: class MockCodecPipeline(BatchedCodecPipeline): async def write( self, - batch_info: Iterable[ChunkRequest], + batch_info: Iterable[WriteChunkRequest], value: NDBuffer, drop_axes: tuple[int, ...] = (), ) -> None: diff --git a/tests/test_sync_codec_pipeline.py b/tests/test_sync_codec_pipeline.py index e9d05dcec6..d241adee49 100644 --- a/tests/test_sync_codec_pipeline.py +++ b/tests/test_sync_codec_pipeline.py @@ -3,6 +3,7 @@ from typing import Any import numpy as np +import pytest from zarr.codecs.bytes import BytesCodec from zarr.codecs.gzip import GzipCodec @@ -10,8 +11,16 @@ from zarr.codecs.zstd import ZstdCodec from zarr.core.array_spec import ArrayConfig, ArraySpec from zarr.core.buffer import NDBuffer, default_buffer_prototype -from zarr.core.codec_pipeline import ChunkTransform +from zarr.core.codec_pipeline import ( + BatchedCodecPipeline, + ChunkTransform, + ReadChunkRequest, + WriteChunkRequest, + _choose_workers, +) from zarr.core.dtype import get_data_type_from_native_dtype +from zarr.storage._common import StorePath +from zarr.storage._memory import MemoryStore def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec: @@ -108,3 +117,284 @@ def test_encode_decode_roundtrip_with_transpose(self) -> None: assert encoded is not None decoded = chain.decode_chunk(encoded) np.testing.assert_array_equal(arr, decoded.as_numpy_array()) + + +# --------------------------------------------------------------------------- +# Helpers for sync pipeline tests +# --------------------------------------------------------------------------- + + +def _make_pipeline( + codecs: tuple[Any, ...], +) -> BatchedCodecPipeline: + return BatchedCodecPipeline.from_codecs(codecs) + + +def _make_store_path(key: str = "chunk/0") -> StorePath: + store = MemoryStore() + return StorePath(store, key) + + +# --------------------------------------------------------------------------- +# Sync pipeline tests +# --------------------------------------------------------------------------- + + +class TestSyncPipeline: + def test_supports_sync_io(self) -> None: + pipeline = _make_pipeline((BytesCodec(),)) + assert pipeline.supports_sync_io is True + + def test_supports_sync_io_with_compression(self) -> None: + pipeline = _make_pipeline((BytesCodec(), ZstdCodec())) + assert pipeline.supports_sync_io is True + + def test_write_sync_read_sync_roundtrip(self) -> None: + """Write data via write_sync, read it back via read_sync.""" + arr = np.arange(100, dtype="float64") + spec = _make_array_spec(arr.shape, arr.dtype) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(arr) + + # Write + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + is_complete_chunk=True, + ) + ], + value, + ) + + # Read + out = default_buffer_prototype().nd_buffer.create( + shape=arr.shape, + dtype=arr.dtype, + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + np.testing.assert_array_equal(arr, out.as_numpy_array()) + + def test_write_sync_read_sync_with_compression(self) -> None: + """Round-trip with a compression codec.""" + arr = np.arange(200, dtype="float32").reshape(10, 20) + spec = _make_array_spec(arr.shape, arr.dtype) + pipeline = _make_pipeline((BytesCodec(), GzipCodec(level=1))) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(arr) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(None), slice(None)), + out_selection=(slice(None), slice(None)), + is_complete_chunk=True, + ) + ], + value, + ) + + out = default_buffer_prototype().nd_buffer.create( + shape=arr.shape, + dtype=arr.dtype, + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None), slice(None)), + out_selection=(slice(None), slice(None)), + ) + ], + out, + ) + + np.testing.assert_array_equal(arr, out.as_numpy_array()) + + def test_write_sync_partial_chunk(self) -> None: + """Write a partial chunk (is_complete_chunk=False), read back full chunk.""" + shape = (10,) + spec = _make_array_spec(shape, np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + # Write only first 5 elements + value = _make_nd_buffer(np.arange(5, dtype="float64")) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=store_path, + transform=transform, + chunk_selection=(slice(0, 5),), + out_selection=(slice(None),), + is_complete_chunk=False, + ) + ], + value, + ) + + # Read back full chunk + out = default_buffer_prototype().nd_buffer.create( + shape=shape, + dtype=np.dtype("float64"), + order="C", + fill_value=-1, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + result = out.as_numpy_array() + np.testing.assert_array_equal(result[:5], np.arange(5, dtype="float64")) + # Remaining elements should be fill value (0) + np.testing.assert_array_equal(result[5:], 0) + + def test_read_sync_missing_chunk(self) -> None: + """Reading a non-existent chunk should fill with fill value.""" + spec = _make_array_spec((10,), np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store_path = _make_store_path("nonexistent/chunk") + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + out = default_buffer_prototype().nd_buffer.create( + shape=(10,), + dtype=np.dtype("float64"), + order="C", + fill_value=-1, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=store_path, + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(None),), + ) + ], + out, + ) + + # Should be filled with the spec's fill value (0) + np.testing.assert_array_equal(out.as_numpy_array(), 0) + + def test_write_sync_multiple_chunks(self) -> None: + """Write and read multiple chunks in one batch.""" + spec = _make_array_spec((10,), np.dtype("float64")) + pipeline = _make_pipeline((BytesCodec(),)) + store = MemoryStore() + transform = ChunkTransform(codecs=tuple(pipeline), array_spec=spec) + + value = _make_nd_buffer(np.arange(20, dtype="float64")) + + pipeline.write_sync( + [ + WriteChunkRequest( + byte_setter=StorePath(store, "c/0"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(0, 10),), + is_complete_chunk=True, + ), + WriteChunkRequest( + byte_setter=StorePath(store, "c/1"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(10, 20),), + is_complete_chunk=True, + ), + ], + value, + ) + + out = default_buffer_prototype().nd_buffer.create( + shape=(20,), + dtype=np.dtype("float64"), + order="C", + fill_value=0, + ) + pipeline.read_sync( + [ + ReadChunkRequest( + byte_getter=StorePath(store, "c/0"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(0, 10),), + ), + ReadChunkRequest( + byte_getter=StorePath(store, "c/1"), + transform=transform, + chunk_selection=(slice(None),), + out_selection=(slice(10, 20),), + ), + ], + out, + ) + + np.testing.assert_array_equal(out.as_numpy_array(), np.arange(20, dtype="float64")) + + +class TestChooseWorkers: + def test_returns_zero_for_single_chunk(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + assert _choose_workers(1, 1_000_000, codecs) == 0 + + def test_returns_nonzero_for_large_compressed_batch(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + n_workers = _choose_workers(10, 1_000_000, codecs) + assert n_workers > 0 + + def test_returns_zero_without_bb_codecs(self) -> None: + codecs: tuple[Any, ...] = (BytesCodec(),) + assert _choose_workers(10, 1_000_000, codecs) == 0 + + def test_returns_zero_for_small_chunks(self) -> None: + codecs = (BytesCodec(), ZstdCodec()) + assert _choose_workers(10, 100, codecs) == 0 + + @pytest.mark.parametrize("enabled", [True, False]) + def test_config_enabled(self, enabled: bool) -> None: + from zarr.core.config import config + + with config.set({"threading.codec_workers.enabled": enabled}): + codecs = (BytesCodec(), ZstdCodec()) + result = _choose_workers(10, 1_000_000, codecs) + if enabled: + assert result > 0 + else: + assert result == 0