Skip to content

Commit b95c710

Browse files
committed
fix: add contextmanager preventing settings leakage between tests
fix: pass raster_write_kwargs recursively when writing list of elements
1 parent bbb4bb6 commit b95c710

3 files changed

Lines changed: 27 additions & 20 deletions

File tree

src/spatialdata/_core/spatialdata.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,6 +1359,7 @@ def write_element(
13591359
overwrite=overwrite,
13601360
sdata_formats=sdata_formats,
13611361
shapes_geometry_encoding=shapes_geometry_encoding,
1362+
raster_write_kwargs=raster_write_kwargs,
13621363
)
13631364
return
13641365

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4+
from contextlib import contextmanager
5+
from dataclasses import replace
46
from pathlib import Path
57
from typing import Any
68

@@ -19,6 +21,7 @@
1921
from skimage import data
2022
from xarray import DataArray, DataTree
2123

24+
from spatialdata import settings
2225
from spatialdata._core._deepcopy import deepcopy
2326
from spatialdata._core.spatialdata import SpatialData
2427
from spatialdata._types import ArrayLike
@@ -653,3 +656,16 @@ def settings_cls(tmp_path, monkeypatch):
653656

654657
monkeypatch.setattr("spatialdata.config._config_path", lambda: tmp_path / "default_settings.json")
655658
return Settings
659+
660+
661+
@contextmanager
662+
def temporary_settings(**kwargs):
663+
old = replace(settings)
664+
try:
665+
for k, v in kwargs.items():
666+
setattr(settings, k, v)
667+
settings.save()
668+
yield
669+
finally:
670+
settings.__dict__.update(old.__dict__)
671+
settings.save()

tests/io/test_readwrite.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
_get_shapes,
5050
_get_table,
5151
_get_tables,
52+
temporary_settings,
5253
)
5354

5455
RNG = default_rng(0)
@@ -806,28 +807,17 @@ def test_write_raster_sharding(
806807

807808

808809
def test_write_raster_sharding_with_settings(tmp_path: Path) -> None:
809-
from dataclasses import replace
810+
with temporary_settings(raster_chunks=(1, 100, 100)):
811+
data = da.from_array(RNG.random((1, 1000, 1000)), chunks=(1, 200, 200))
812+
element = Image2DModel.parse(data, dims=("c", "y", "x"))
813+
name = "element"
814+
sdata = SpatialData(images={name: element})
815+
path = tmp_path / "data.zarr"
810816

811-
from spatialdata import settings
812-
813-
old_settings = replace(settings)
814-
settings.raster_chunks = (1, 100, 100)
815-
settings.save()
816-
817-
data = da.from_array(RNG.random((1, 1000, 1000)), chunks=(1, 200, 200))
818-
element = Image2DModel.parse(data, dims=("c", "y", "x"))
819-
name = "element"
820-
sdata = SpatialData(images={name: element})
821-
path = tmp_path / "data.zarr"
817+
sdata.write(path)
822818

823-
sdata.write(
824-
path,
825-
)
826-
arr = zarr.open_group(path / "images" / name, mode="r")["s0"]
827-
assert arr.chunks == (1, 100, 100)
828-
old_settings.save()
829-
s = settings.load()
830-
assert s.raster_chunks == old_settings.raster_chunks
819+
arr = zarr.open_group(path / "images" / name, mode="r")["s0"]
820+
assert arr.chunks == (1, 100, 100)
831821

832822

833823
@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE)

0 commit comments

Comments
 (0)