Skip to content

Commit acb25b1

Browse files
LucaMarconato and claude
committed
perf: skip re-validation when building SpatialData from already-valid elements
Add skip_element_validation() context manager (backed by a ContextVar) that makes __setitem__ call get_model(validate=False) — type inference only, no schema.validate(). Use it in every code path that constructs a SpatialData from elements that originated from an existing SpatialData and were never externally mutated: bounding_box_query, polygon_query, filter_by_coordinate_system, transform_to_coordinate_system, subset, and init_from_elements. test_query_spatial_data: 0.77s → 0.64s (the remaining time is the query work itself — filtering, shapely ops, raster cropping). Also inline a minimal 2-image SpatialData in test_transformations_between_coordinate_systems instead of relying on the full 8-element images fixture; the test only ever uses image2d and image2d_multiscale, so writing the other 6 to disk was pure waste. test_transformations_between_coordinate_systems: 0.61s → 0.44s. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ac97136 commit acb25b1

4 files changed

Lines changed: 49 additions & 14 deletions

File tree

src/spatialdata/_core/_elements.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5+
import contextlib
56
from collections import UserDict
6-
from collections.abc import Iterable, KeysView, ValuesView
7+
from collections.abc import Iterable, Iterator, KeysView, ValuesView
8+
from contextvars import ContextVar
79
from typing import TypeVar
810

911
from anndata import AnnData
@@ -24,6 +26,25 @@
2426
get_model,
2527
)
2628

29+
_skip_element_validation: ContextVar[bool] = ContextVar("_skip_element_validation", default=False)
30+
31+
32+
@contextlib.contextmanager
33+
def skip_element_validation() -> Iterator[None]:
34+
"""
35+
Context manager to skip schema validation when inserting elements into SpatialData containers.
36+
37+
Use this only when inserting elements that are already known to be valid (e.g. elements
38+
taken directly from an existing SpatialData object). Skipping validation is unsafe for
39+
externally-sourced data.
40+
"""
41+
token = _skip_element_validation.set(True)
42+
try:
43+
yield
44+
finally:
45+
_skip_element_validation.reset(token)
46+
47+
2748
T = TypeVar("T")
2849

2950

@@ -69,7 +90,7 @@ def values(self) -> ValuesView[T]:
6990
class Images(Elements[DataArray | DataTree]):
7091
def __setitem__(self, key: str, value: Raster_T) -> None:
7192
self._check_key(key, self.keys(), self._shared_keys)
72-
schema = get_model(value)
93+
schema = get_model(value, validate=not _skip_element_validation.get())
7394
if schema not in (Image2DModel, Image3DModel):
7495
raise TypeError(f"Unknown element type with schema: {schema!r}.")
7596
super().__setitem__(key, value)
@@ -78,7 +99,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
7899
class Labels(Elements[DataArray | DataTree]):
79100
def __setitem__(self, key: str, value: Raster_T) -> None:
80101
self._check_key(key, self.keys(), self._shared_keys)
81-
schema = get_model(value)
102+
schema = get_model(value, validate=not _skip_element_validation.get())
82103
if schema not in (Labels2DModel, Labels3DModel):
83104
raise TypeError(f"Unknown element type with schema: {schema!r}.")
84105
super().__setitem__(key, value)
@@ -87,7 +108,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
87108
class Shapes(Elements[GeoDataFrame]):
88109
def __setitem__(self, key: str, value: GeoDataFrame) -> None:
89110
self._check_key(key, self.keys(), self._shared_keys)
90-
schema = get_model(value)
111+
schema = get_model(value, validate=not _skip_element_validation.get())
91112
if schema != ShapesModel:
92113
raise TypeError(f"Unknown element type with schema: {schema!r}.")
93114
super().__setitem__(key, value)
@@ -96,7 +117,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
96117
class Points(Elements[DaskDataFrame]):
97118
def __setitem__(self, key: str, value: DaskDataFrame) -> None:
98119
self._check_key(key, self.keys(), self._shared_keys)
99-
schema = get_model(value)
120+
schema = get_model(value, validate=not _skip_element_validation.get())
100121
if schema != PointsModel:
101122
raise TypeError(f"Unknown element type with schema: {schema!r}.")
102123
super().__setitem__(key, value)
@@ -105,7 +126,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
105126
class Tables(Elements[AnnData]):
106127
def __setitem__(self, key: str, value: AnnData) -> None:
107128
self._check_key(key, self.keys(), self._shared_keys)
108-
schema = get_model(value)
129+
schema = get_model(value, validate=not _skip_element_validation.get())
109130
if schema != TableModel:
110131
raise TypeError(f"Unknown element type with schema: {schema!r}.")
111132
super().__setitem__(key, value)

src/spatialdata/_core/query/spatial_query.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from xarray import DataArray, DataTree
1616

1717
from spatialdata import to_polygons
18+
from spatialdata._core._elements import skip_element_validation
1819
from spatialdata._core.query._utils import _get_filtered_or_unfiltered_tables, get_bounding_box_corners
1920
from spatialdata._core.spatialdata import SpatialData
2021
from spatialdata._docs import docstring_parameter
@@ -538,7 +539,8 @@ def _(
538539

539540
tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)
540541

541-
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)
542+
with skip_element_validation():
543+
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)
542544

543545

544546
@bounding_box_query.register(DataArray)
@@ -874,7 +876,8 @@ def _(
874876

875877
tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)
876878

877-
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)
879+
with skip_element_validation():
880+
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)
878881

879882

880883
@polygon_query.register(DataArray)

src/spatialdata/_core/spatialdata.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from xarray import DataArray, DataTree
2222
from zarr.errors import GroupNotFoundError
2323

24-
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
24+
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables, skip_element_validation
2525
from spatialdata._core.validation import (
2626
check_all_keys_case_insensitively_unique,
2727
check_target_region_column_symmetry,
@@ -641,7 +641,8 @@ def filter_by_coordinate_system(
641641
element_names=element_names_in_coordinate_system,
642642
)
643643

644-
return SpatialData(**elements, tables=tables, attrs=self.attrs)
644+
with skip_element_validation():
645+
return SpatialData(**elements, tables=tables, attrs=self.attrs)
645646

646647
# TODO: move to relational query with refactor
647648
def _filter_tables(
@@ -885,7 +886,8 @@ def transform_to_coordinate_system(
885886
if element_type not in elements:
886887
elements[element_type] = {}
887888
elements[element_type][element_name] = transformed
888-
return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs)
889+
with skip_element_validation():
890+
return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs)
889891

890892
def elements_are_self_contained(self) -> dict[str, bool]:
891893
"""
@@ -2242,7 +2244,8 @@ def init_from_elements(
22422244
assert model == ShapesModel
22432245
element_type = "shapes"
22442246
elements_dict.setdefault(element_type, {})[name] = element
2245-
return cls(**elements_dict, attrs=attrs)
2247+
with skip_element_validation():
2248+
return cls(**elements_dict, attrs=attrs)
22462249

22472250
def subset(
22482251
self,
@@ -2284,7 +2287,8 @@ def subset(
22842287
include_orphan_tables,
22852288
elements_dict=elements_dict,
22862289
)
2287-
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)
2290+
with skip_element_validation():
2291+
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)
22882292

22892293
def __getitem__(self, item: str) -> SpatialElement | AnnData:
22902294
"""

tests/core/operations/test_transform.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,16 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop(
694694
assert np.allclose(affine, np.array([[1, 0], [0, 1]]))
695695

696696

697-
def test_transformations_between_coordinate_systems(images):
697+
def test_transformations_between_coordinate_systems():
698698
# just a test that all the code is executed without errors and a quick test that the affine matrix is correct.
699699
# For a full test the notebooks are more exhaustive
700+
# Use a minimal sdata with only the two images needed — avoids writing 6 unused elements to disk.
701+
images = SpatialData(
702+
images={
703+
"image2d": Image2DModel.parse(np.zeros((3, 8, 8)), dims=("c", "y", "x")),
704+
"image2d_multiscale": Image2DModel.parse(np.zeros((3, 8, 8)), dims=("c", "y", "x"), scale_factors=[2]),
705+
}
706+
)
700707
with tempfile.TemporaryDirectory() as tmpdir:
701708
images.write(Path(tmpdir) / "sdata.zarr")
702709
el0 = images.images["image2d"]

0 commit comments

Comments
 (0)