14 changes: 12 additions & 2 deletions src/spatialdata/_io/io_raster.py
@@ -316,7 +316,7 @@ def _write_raster(
             **metadata,
         )
     elif isinstance(raster_data, DataTree):
-        _write_raster_datatree(
+        group = _write_raster_datatree(
             raster_type,
             group,
             name,
@@ -409,7 +409,7 @@ def _write_raster_datatree(
     raster_format: RasterFormatType,
     storage_options: JSONDict | list[JSONDict] | None = None,
     **metadata: str | JSONDict | list[JSONDict],
-) -> None:
+) -> zarr.Group:
     """Write raster data of type DataTree to disk.

     Parameters
@@ -460,13 +460,23 @@ def _write_raster_datatree(
     # os.replace is called. These can also be alleviated by using 'single-threaded' scheduler.
     da.compute(*dask_delayed, optimize_graph=False)

+    # Workaround for https://github.com/scverse/spatialdata/issues/1024.
+    # ome-zarr-py bundles write_multiscales_metadata() as a dask.delayed task in the compute=False
+    # code path (see https://github.com/ome/ome-zarr-py/issues/580). When da.compute() runs with
+    # the 'processes' scheduler, that task executes in a subprocess: the on-disk zarr.json is written
+    # correctly, but the zarr.Group held in this process keeps its original in-memory GroupMetadata
+    # and never sees the update. Re-opening the group forces a fresh read from the store.
+    # This workaround should no longer be needed once https://github.com/ome/ome-zarr-py/issues/580 is fixed.
+    group = zarr.open_group(store=group.store, path=group.path, mode="r+", use_consolidated=False)
+
     trans_group = group["labels"][element_name] if raster_type == "labels" else group
     overwrite_coordinate_transformations_raster(
         group=trans_group,
         transformations=transformations,
         axes=tuple(input_axes),
         raster_format=raster_format,
     )
+    return group


 def write_image(
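To illustrate the workaround above: a minimal, self-contained sketch of the staleness problem, assuming zarr-python v3. The store path "demo.zarr" and the "ome" attribute payload are hypothetical stand-ins for illustration, not what ome-zarr-py actually writes.

```python
# Sketch of the stale-GroupMetadata problem (assumes zarr-python v3; the path and
# the "ome" payload are hypothetical stand-ins for ome-zarr-py's real output).
import zarr

store = zarr.storage.LocalStore("demo.zarr")
stale = zarr.open_group(store=store, mode="a")  # caches GroupMetadata at open time

# Stand-in for the subprocess: an independent handle updates zarr.json on disk.
writer = zarr.open_group(store=store, mode="r+")
writer.update_attributes({"ome": {"version": "0.5"}})

print(dict(stale.attrs))  # first handle still shows the attributes read at open time

# The workaround from the diff: re-open to force a fresh read of zarr.json.
fresh = zarr.open_group(store=store, path=stale.path, mode="r+", use_consolidated=False)
print(dict(fresh.attrs))  # now includes "ome"
```

The diff applies exactly this re-open inside `_write_raster_datatree`, so that `overwrite_coordinate_transformations_raster` reads current metadata, and returns the refreshed group to the caller.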
23 changes: 22 additions & 1 deletion tests/io/test_pyramids_performance.py
@@ -10,10 +10,11 @@
 import xarray as xr
 import zarr

-from spatialdata import SpatialData
+from spatialdata import SpatialData, read_zarr
 from spatialdata._io import write_image
 from spatialdata._io.format import CurrentRasterFormat
 from spatialdata.models import Image2DModel
+from spatialdata.testing import assert_spatial_data_objects_are_identical

 if TYPE_CHECKING:
     import _pytest.fixtures
@@ -95,3 +96,23 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p
     # In addition, we could use a mock side effect to check that the entry points from within spatialdata are within
     # the expected range.
     assert actual_num_chunk_reads in range(0, num_chunks_scale0.item() * 2 + 1)
+
+
+@pytest.mark.parametrize("scheduler", ["threads", "processes"])
+def test_write_multiscale_image_dask_scheduler(tmp_path: Path, scheduler: str) -> None:
+    # Regression test for https://github.com/scverse/spatialdata/issues/1024.
+    # Writing a multiscale image with the 'processes' Dask scheduler previously raised
+    # KeyError: 'ome' because ome-zarr-py runs write_multiscales_metadata() as a
+    # dask.delayed task (https://github.com/ome/ome-zarr-py/issues/580): the metadata
+    # write occurs in a subprocess and the zarr.Group in the main process is never
+    # refreshed, so subsequent metadata reads fail.
+    rng = np.random.default_rng(0)
+    arr = dask.array.from_array(rng.random((3, 64, 64)).astype("float32"), chunks=(3, 32, 32))
+    image = Image2DModel.parse(arr, dims=["c", "y", "x"], scale_factors=[2])
+    sdata = SpatialData(images={"img": image})
+
+    store_path = tmp_path / "test.zarr"
+    with dask.config.set(scheduler=scheduler):
+        sdata.write(store_path)
+
+    assert_spatial_data_objects_are_identical(sdata, read_zarr(store_path))
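A brief usage note on the scheduler selection exercised by this test: `dask.config.set` works both as a context manager (scoped to the block, as in the test) and as a global setting. A minimal sketch; the `compute()` target here is an arbitrary example.

```python
# Two ways to select the Dask scheduler exercised by the parametrized test above.
import dask
import dask.array as da

x = da.ones((4, 4), chunks=(2, 2))

# Context-local: only computations inside the block use the 'processes' scheduler.
with dask.config.set(scheduler="processes"):
    x.sum().compute()

# Global: subsequent computations default to the threaded scheduler.
dask.config.set(scheduler="threads")
x.sum().compute()
```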