diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index a8b2ab2c..027ac8ce 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -316,7 +316,7 @@ def _write_raster(
             **metadata,
         )
     elif isinstance(raster_data, DataTree):
-        _write_raster_datatree(
+        group = _write_raster_datatree(
             raster_type,
             group,
             name,
@@ -409,7 +409,7 @@ def _write_raster_datatree(
     raster_format: RasterFormatType,
     storage_options: JSONDict | list[JSONDict] | None = None,
     **metadata: str | JSONDict | list[JSONDict],
-) -> None:
+) -> zarr.Group:
     """Write raster data of type DataTree to disk.
 
     Parameters
@@ -460,6 +460,15 @@ def _write_raster_datatree(
     # os.replace is called. These can also be alleviated by using 'single-threaded' scheduler.
     da.compute(*dask_delayed, optimize_graph=False)
 
+    # Workaround for https://github.com/scverse/spatialdata/issues/1024.
+    # ome-zarr-py bundles write_multiscales_metadata() as a dask.delayed task in the compute=False
+    # code path (see https://github.com/ome/ome-zarr-py/issues/580). When da.compute() runs with
+    # the 'processes' scheduler that task executes in a subprocess: the on-disk zarr.json is written
+    # correctly, but the zarr.Group held in this process keeps its original in-memory GroupMetadata
+    # and never sees the update. Re-opening the group forces a fresh read from the store.
+    # This workaround should not be needed once https://github.com/ome/ome-zarr-py/issues/580 is fixed.
+    group = zarr.open_group(store=group.store, path=group.path, mode="r+", use_consolidated=False)
+
     trans_group = group["labels"][element_name] if raster_type == "labels" else group
     overwrite_coordinate_transformations_raster(
         group=trans_group,
@@ -467,6 +476,7 @@ def _write_raster_datatree(
         axes=tuple(input_axes),
         raster_format=raster_format,
     )
+    return group
 
 
 def write_image(
diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py
index 31633f63..639d1051 100644
--- a/tests/io/test_pyramids_performance.py
+++ b/tests/io/test_pyramids_performance.py
@@ -10,10 +10,11 @@
 import xarray as xr
 import zarr
 
-from spatialdata import SpatialData
+from spatialdata import SpatialData, read_zarr
 from spatialdata._io import write_image
 from spatialdata._io.format import CurrentRasterFormat
 from spatialdata.models import Image2DModel
+from spatialdata.testing import assert_spatial_data_objects_are_identical
 
 if TYPE_CHECKING:
     import _pytest.fixtures
@@ -95,3 +96,23 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p
     # In addition, we could do use a mock side effect to check that the entry points from within spatialdata are within
     # the expected range.
     assert actual_num_chunk_reads in range(0, num_chunks_scale0.item() * 2 + 1)
+
+
+@pytest.mark.parametrize("scheduler", ["threads", "processes"])
+def test_write_multiscale_image_dask_scheduler(tmp_path: Path, scheduler: str) -> None:
+    # Regression test for https://github.com/scverse/spatialdata/issues/1024.
+    # Writing a multiscale image with the 'processes' Dask scheduler previously raised
+    # KeyError: 'ome' because ome-zarr-py runs write_multiscales_metadata() as a
+    # dask.delayed task (https://github.com/ome/ome-zarr-py/issues/580): the metadata
+    # write occurs in a subprocess and the zarr.Group in the main process is never
+    # refreshed, so subsequent metadata reads fail.
+    rng = np.random.default_rng(0)
+    arr = dask.array.from_array(rng.random((3, 64, 64)).astype("float32"), chunks=(3, 32, 32))
+    image = Image2DModel.parse(arr, dims=["c", "y", "x"], scale_factors=[2])
+    sdata = SpatialData(images={"img": image})
+
+    store_path = tmp_path / "test.zarr"
+    with dask.config.set(scheduler=scheduler):
+        sdata.write(store_path)
+
+    assert_spatial_data_objects_are_identical(sdata, read_zarr(store_path))
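
For reference, the stale-metadata behavior that the re-open workaround addresses can be reproduced outside spatialdata. The sketch below is illustrative only and not part of the patch: it assumes zarr-python v3 semantics (group metadata cached on the zarr.Group object, zarr.json in the store), and the second handle named `writer` stands in for the subprocess that the 'processes' scheduler runs the delayed metadata task in.

# Minimal sketch of the stale in-memory GroupMetadata problem, assuming
# zarr-python v3. All names here (store, writer) are illustrative.
import zarr

store = zarr.storage.MemoryStore()
group = zarr.open_group(store=store, mode="w")

# The "subprocess" updates the group's attributes through its own handle;
# the store now holds the new metadata.
writer = zarr.open_group(store=store, mode="r+")
writer.attrs["multiscales"] = [{"version": "0.4"}]

# The original handle still serves its cached metadata, so the new key is
# invisible here even though it is present in the store.
assert "multiscales" not in group.attrs

# The workaround used in the patch: re-open the group to force a fresh read
# from the store instead of trusting the cached in-memory metadata.
group = zarr.open_group(store=store, path=group.path, mode="r+", use_consolidated=False)
assert group.attrs["multiscales"] == [{"version": "0.4"}]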