Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/zarr/core/codec_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ async def read_batch(
chunk_array_batch, batch_info, strict=False
):
if chunk_array is not None:
if drop_axes:
chunk_array = chunk_array.squeeze(axis=drop_axes)
out[out_selection] = chunk_array
else:
out[out_selection] = fill_value_or_default(chunk_spec)
Expand All @@ -285,7 +287,7 @@ async def read_batch(
):
if chunk_array is not None:
tmp = chunk_array[chunk_selection]
if drop_axes != ():
if drop_axes:
tmp = tmp.squeeze(axis=drop_axes)
out[out_selection] = tmp
else:
Expand Down Expand Up @@ -324,7 +326,7 @@ def _merge_chunk_array(
else:
chunk_value = value[out_selection]
# handle missing singleton dimensions
if drop_axes != ():
if drop_axes:
item = tuple(
None # equivalent to np.newaxis
if idx in drop_axes
Expand Down
56 changes: 55 additions & 1 deletion tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,8 @@ def test_invalid_shard_shape() -> None:
with pytest.raises(
ValueError,
match=re.escape(
"The array's `chunk_shape` (got (16, 16)) needs to be divisible by the shard's inner `chunk_shape` (got (9,))."
"The array's `chunk_shape` (got (16, 16)) needs to be divisible "
"by the shard's inner `chunk_shape` (got (9,))."
),
):
zarr.create_array(
Expand All @@ -501,3 +502,56 @@ def test_invalid_shard_shape() -> None:
dtype=np.dtype("uint8"),
fill_value=0,
)


@pytest.mark.parametrize("store", ["local"], indirect=["store"])
def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
    """Regression test for https://github.com/zarr-developers/zarr-python/issues/3691.

    Mixed integer/list indexing on sharded arrays should return the same
    shape and data as on equivalent chunked arrays.

    Each selection is additionally compared against the in-memory NumPy
    array that was written to both arrays, so a defect shared by the
    chunked and the sharded read paths cannot go unnoticed.
    """
    # Reference data; both zarr arrays below are filled with exactly this.
    data = np.arange(200 * 100 * 10, dtype=np.uint8).reshape(200, 100, 10)

    # Plain chunked array: one chunk per index along the last axis.
    chunked = zarr.create_array(
        store,
        name="chunked",
        shape=(200, 100, 10),
        dtype=np.uint8,
        chunks=(200, 100, 1),
        overwrite=True,
    )
    chunked[:, :, :] = data

    # Same shape and chunking, but with a single shard spanning the whole array.
    sharded = zarr.create_array(
        store,
        name="sharded",
        shape=(200, 100, 10),
        dtype=np.uint8,
        chunks=(200, 100, 1),
        shards=(200, 100, 10),
        overwrite=True,
    )
    sharded[:, :, :] = data

    # Mixed integer + list indexing
    c = chunked[0:10, 0, [0, 1]]  # type: ignore[index]
    s = sharded[0:10, 0, [0, 1]]  # type: ignore[index]
    assert c.shape == s.shape == (10, 2), (  # type: ignore[union-attr]
        f"Expected (10, 2), got chunked={c.shape}, sharded={s.shape}"  # type: ignore[union-attr]
    )
    np.testing.assert_array_equal(c, s)
    np.testing.assert_array_equal(s, data[0:10, 0, [0, 1]])

    # Multiple integer axes
    c2 = chunked[0, 0, [0, 1, 2]]  # type: ignore[index]
    s2 = sharded[0, 0, [0, 1, 2]]  # type: ignore[index]
    assert c2.shape == s2.shape == (3,)  # type: ignore[union-attr]
    np.testing.assert_array_equal(c2, s2)
    np.testing.assert_array_equal(s2, data[0, 0, [0, 1, 2]])

    # Slice + integer + slice
    c3 = chunked[0:5, 1, 0:3]
    s3 = sharded[0:5, 1, 0:3]
    assert c3.shape == s3.shape == (5, 3)  # type: ignore[union-attr]
    np.testing.assert_array_equal(c3, s3)
    np.testing.assert_array_equal(s3, data[0:5, 1, 0:3])
Loading