diff --git a/src/zarr/core/codec_pipeline.py b/src/zarr/core/codec_pipeline.py
index fd557ac43e..eed49556d3 100644
--- a/src/zarr/core/codec_pipeline.py
+++ b/src/zarr/core/codec_pipeline.py
@@ -263,6 +263,8 @@ async def read_batch(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
+                    if drop_axes:
+                        chunk_array = chunk_array.squeeze(axis=drop_axes)
                     out[out_selection] = chunk_array
                 else:
                     out[out_selection] = fill_value_or_default(chunk_spec)
@@ -285,7 +287,7 @@ async def read_batch(
             ):
                 if chunk_array is not None:
                     tmp = chunk_array[chunk_selection]
-                    if drop_axes != ():
+                    if drop_axes:
                         tmp = tmp.squeeze(axis=drop_axes)
                     out[out_selection] = tmp
                 else:
@@ -324,7 +326,7 @@ def _merge_chunk_array(
         else:
             chunk_value = value[out_selection]
             # handle missing singleton dimensions
-            if drop_axes != ():
+            if drop_axes:
                 item = tuple(
                     None  # equivalent to np.newaxis
                     if idx in drop_axes
diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py
index d0e2d09b7c..d7cbeb5bdb 100644
--- a/tests/test_codecs/test_sharding.py
+++ b/tests/test_codecs/test_sharding.py
@@ -490,7 +490,8 @@ def test_invalid_shard_shape() -> None:
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "The array's `chunk_shape` (got (16, 16)) needs to be divisible by the shard's inner `chunk_shape` (got (9,))."
+            "The array's `chunk_shape` (got (16, 16)) needs to be divisible "
+            "by the shard's inner `chunk_shape` (got (9,))."
         ),
     ):
         zarr.create_array(
@@ -501,3 +502,58 @@ def test_invalid_shard_shape() -> None:
             dtype=np.dtype("uint8"),
             fill_value=0,
         )
+
+
+@pytest.mark.parametrize("store", ["local"], indirect=["store"])
+def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
+    """Regression test for https://github.com/zarr-developers/zarr-python/issues/3691.
+
+    Mixed integer/list indexing on sharded arrays should return the same
+    shape and data as on equivalent chunked arrays, and both should match
+    the same selection taken from the in-memory array that was written.
+    """
+    data = np.arange(200 * 100 * 10, dtype=np.uint8).reshape(200, 100, 10)
+
+    chunked = zarr.create_array(
+        store,
+        name="chunked",
+        shape=(200, 100, 10),
+        dtype=np.uint8,
+        chunks=(200, 100, 1),
+        overwrite=True,
+    )
+    chunked[:, :, :] = data
+
+    sharded = zarr.create_array(
+        store,
+        name="sharded",
+        shape=(200, 100, 10),
+        dtype=np.uint8,
+        chunks=(200, 100, 1),
+        shards=(200, 100, 10),
+        overwrite=True,
+    )
+    sharded[:, :, :] = data
+
+    # Mixed integer + list indexing
+    c = chunked[0:10, 0, [0, 1]]  # type: ignore[index]
+    s = sharded[0:10, 0, [0, 1]]  # type: ignore[index]
+    assert c.shape == s.shape == (10, 2), (  # type: ignore[union-attr]
+        f"Expected (10, 2), got chunked={c.shape}, sharded={s.shape}"  # type: ignore[union-attr]
+    )
+    np.testing.assert_array_equal(c, s)
+    np.testing.assert_array_equal(s, data[0:10, 0, [0, 1]])
+
+    # Multiple integer axes
+    c2 = chunked[0, 0, [0, 1, 2]]  # type: ignore[index]
+    s2 = sharded[0, 0, [0, 1, 2]]  # type: ignore[index]
+    assert c2.shape == s2.shape == (3,)  # type: ignore[union-attr]
+    np.testing.assert_array_equal(c2, s2)
+    np.testing.assert_array_equal(s2, data[0, 0, [0, 1, 2]])
+
+    # Slice + integer + slice
+    c3 = chunked[0:5, 1, 0:3]
+    s3 = sharded[0:5, 1, 0:3]
+    assert c3.shape == s3.shape == (5, 3)  # type: ignore[union-attr]
+    np.testing.assert_array_equal(c3, s3)
+    np.testing.assert_array_equal(s3, data[0:5, 1, 0:3])