From 6821df4504d4417a79e3c5b67e891068a9e9a1dd Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Tue, 27 Jan 2026 11:58:07 +0100 Subject: [PATCH 1/8] feat: add lazy table loading via anndata.experimental.read_lazy Add a `lazy` parameter to `SpatialData.read()` and `read_zarr()` that enables lazy loading of tables using anndata's experimental `read_lazy()` function. This is particularly useful for large datasets (e.g., Mass Spectrometry Imaging with millions of pixels) where loading tables into memory is not feasible. Changes: - Add `lazy: bool = False` parameter to `read_zarr()` in io_zarr.py - Add `lazy: bool = False` parameter to `_read_table()` in io_table.py - Add `lazy: bool = False` parameter to `SpatialData.read()` in spatialdata.py - Add `_is_lazy_anndata()` helper to detect lazy AnnData objects - Skip eager validation for lazy tables to preserve lazy loading benefits - Add tests for lazy loading functionality Requires anndata >= 0.12 for lazy loading support. Falls back to eager loading with a warning if anndata version does not support read_lazy. --- src/spatialdata/_core/spatialdata.py | 401 +++++++++++++++++---------- src/spatialdata/_io/io_table.py | 45 ++- src/spatialdata/_io/io_zarr.py | 59 ++-- src/spatialdata/models/models.py | 238 ++++++++++++---- tests/io/test_readwrite.py | 254 ++++++++++++++--- 5 files changed, 726 insertions(+), 271 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c06e62b74..869683ba8 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -22,40 +22,19 @@ from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables -from spatialdata._core.validation import ( - check_all_keys_case_insensitively_unique, - check_target_region_column_symmetry, - check_valid_name, - raise_validation_errors, - validate_table_attr_keys, -) +from spatialdata._core.validation import (check_all_keys_case_insensitively_unique, check_target_region_column_symmetry, + check_valid_name, raise_validation_errors, validate_table_attr_keys) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import _deprecation_alias -from spatialdata.models import ( - Image2DModel, - Image3DModel, - Labels2DModel, - Labels3DModel, - PointsModel, - ShapesModel, - TableModel, - get_model, - get_table_keys, -) -from spatialdata.models._utils import ( - SpatialElement, - convert_region_column_to_categorical, - get_axes_names, - set_channel_names, -) +from spatialdata.models import (Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, + TableModel, get_model, get_table_keys) +from spatialdata.models._utils import (SpatialElement, convert_region_column_to_categorical, get_axes_names, + set_channel_names) if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import ( - SpatialDataContainerFormatType, - SpatialDataFormatType, - ) + from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType # schema for elements Label2D_s = Labels2DModel() @@ -140,7 +119,11 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) + 
element_names = list( + chain.from_iterable( + [e.keys() for e in [images, labels, points, shapes] if e is not None] + ) + ) if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} @@ -241,9 +224,7 @@ def get_annotated_regions(table: AnnData) -> list[str]: ------- The annotated regions. """ - from spatialdata.models.models import ( - _get_region_metadata_from_region_key_column, - ) + from spatialdata.models.models import _get_region_metadata_from_region_key_column return _get_region_metadata_from_region_key_column(table) @@ -268,7 +249,9 @@ def get_region_key_column(table: AnnData) -> pd.Series: _, region_key, _ = get_table_keys(table) if table.obs.get(region_key) is not None: return table.obs[region_key] - raise KeyError(f"{region_key} is set as region key column. However the column is not found in table.obs.") + raise KeyError( + f"{region_key} is set as region key column. However the column is not found in table.obs." + ) @staticmethod def get_instance_key_column(table: AnnData) -> pd.Series: @@ -293,9 +276,13 @@ def get_instance_key_column(table: AnnData) -> pd.Series: _, _, instance_key = get_table_keys(table) if table.obs.get(instance_key) is not None: return table.obs[instance_key] - raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") + raise KeyError( + f"{instance_key} is set as instance key column. However the column is not found in table.obs." + ) - def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: + def set_channel_names( + self, element_name: str, channel_names: str | list[str], write: bool = False + ) -> None: """Set the channel names for an image `SpatialElement` in the `SpatialData` object. This method will overwrite the element in memory with the same element, but with new channel names. @@ -313,7 +300,9 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel data itself (heavy operation). """ - self.images[element_name] = set_channel_names(self.images[element_name], channel_names) + self.images[element_name] = set_channel_names( + self.images[element_name], channel_names + ) if write: self.write_channel_names(element_name) @@ -391,7 +380,9 @@ def _change_table_annotation_target( If provided region_key is not present in table.obs. """ attrs = table.uns[TableModel.ATTRS_KEY] - table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + table_region_key = ( + region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + ) TableModel()._validate_set_region_key(table, region_key) TableModel()._validate_set_instance_key(table, instance_key) @@ -399,7 +390,9 @@ def _change_table_annotation_target( attrs[TableModel.REGION_KEY] = region @staticmethod - def update_annotated_regions_metadata(table: AnnData, region_key: str | None = None) -> AnnData: + def update_annotated_regions_metadata( + table: AnnData, region_key: str | None = None + ) -> AnnData: """ Update the annotation target of the table using the region_key column in table.obs. @@ -422,7 +415,9 @@ def update_annotated_regions_metadata(table: AnnData, region_key: str | None = N """ attrs = table.uns.get(TableModel.ATTRS_KEY) if attrs is None: - raise ValueError("The table has no annotation metadata. 
Please parse the table using `TableModel.parse`.") + raise ValueError( + "The table has no annotation metadata. Please parse the table using `TableModel.parse`." + ) region_key = region_key if region_key else attrs[TableModel.REGION_KEY_KEY] if attrs[TableModel.REGION_KEY_KEY] != region_key: attrs[TableModel.REGION_KEY_KEY] = region_key @@ -470,14 +465,20 @@ def set_table_annotates_spatialelement( isinstance(region, list | pd.Series) and not all(region_element in element_names for region_element in region) ): - raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") + raise ValueError( + f"Annotation target '{region}' not present as SpatialElement in SpatialData object." + ) if table.uns.get(TableModel.ATTRS_KEY): - self._change_table_annotation_target(table, region, region_key, instance_key) + self._change_table_annotation_target( + table, region, region_key, instance_key + ) elif isinstance(region_key, str) and isinstance(instance_key, str): self._set_table_annotation_target(table, region, region_key, instance_key) else: - raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") + raise TypeError( + "No current annotation metadata found. Please specify both region_key and instance_key." + ) convert_region_column_to_categorical(table) @property @@ -588,9 +589,17 @@ def locate_element(self, element: SpatialElement) -> list[str]: found_element_name.append(element_name) if len(found) == 0: return [] - if any("/" in found_element_name[i] or "/" in found_element_type[i] for i in range(len(found))): - raise ValueError("Found an element name with a '/' character. This is not allowed.") - return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] + if any( + "/" in found_element_name[i] or "/" in found_element_type[i] + for i in range(len(found)) + ): + raise ValueError( + "Found an element name with a '/' character. This is not allowed." + ) + return [ + f"{found_element_type[i]}/{found_element_name[i]}" + for i in range(len(found)) + ] def filter_by_coordinate_system( self, @@ -688,28 +697,28 @@ def _filter_tables( if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): tables[table_name] = table continue - if not include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): + if not include_orphan_tables and not table.uns.get( + TableModel.ATTRS_KEY + ): continue if table_name in names_tables_to_keep: tables[table_name] = table continue # each mode here requires paths or elements, using assert here to avoid mypy errors. 
if by == "cs": - from spatialdata._core.query.relational_query import ( - _filter_table_by_element_names, - ) + from spatialdata._core.query.relational_query import _filter_table_by_element_names assert element_names is not None table = _filter_table_by_element_names(table, element_names) if table is not None and len(table) != 0: tables[table_name] = table elif by == "elements": - from spatialdata._core.query.relational_query import ( - _filter_table_by_elements, - ) + from spatialdata._core.query.relational_query import _filter_table_by_elements assert elements_dict is not None - table = _filter_table_by_elements(table, elements_dict=elements_dict) + table = _filter_table_by_elements( + table, elements_dict=elements_dict + ) if table is not None and len(table) != 0: tables[table_name] = table else: @@ -731,10 +740,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: The method does not allow to rename a coordinate system into an existing one, unless the existing one is also renamed in the same call. """ - from spatialdata.transformations.operations import ( - get_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import get_transformation, set_transformation # check that the rename_dict is valid old_names = self.coordinate_systems @@ -760,7 +766,9 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: for old_cs, new_cs in rename_dict.items(): if old_cs in transformations: random_suffix = hashlib.sha1(os.urandom(128)).hexdigest()[:8] - transformations[new_cs + random_suffix] = transformations.pop(old_cs) + transformations[new_cs + random_suffix] = transformations.pop( + old_cs + ) suffixes_to_replace.add(new_cs + random_suffix) # remove the random suffixes @@ -771,10 +779,14 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_transformations[cs] = transformations[cs_with_suffix] suffixes_to_replace.remove(cs_with_suffix) else: - new_transformations[cs_with_suffix] = transformations[cs_with_suffix] + new_transformations[cs_with_suffix] = transformations[ + cs_with_suffix + ] # set the new transformations - set_transformation(element=element, transformation=new_transformations, set_all=True) + set_transformation( + element=element, transformation=new_transformations, set_all=True + ) def transform_element_to_coordinate_system( self, @@ -802,17 +814,18 @@ def transform_element_to_coordinate_system( """ from spatialdata import transform from spatialdata.transformations import Sequence - from spatialdata.transformations.operations import ( - get_transformation, - get_transformation_between_coordinate_systems, - remove_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import (get_transformation, + get_transformation_between_coordinate_systems, + remove_transformation, set_transformation) element = self.get(element_name) - t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) + t = get_transformation_between_coordinate_systems( + self, element, target_coordinate_system + ) if maintain_positioning: - transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) + transformed = transform( + element, transformation=t, maintain_positioning=maintain_positioning + ) else: d = get_transformation(element, get_all=True) assert isinstance(d, dict) @@ -848,7 +861,9 @@ def transform_element_to_coordinate_system( # since target_coordinate_system is in d, we have that t is a Sequence with 
only one transformation. assert isinstance(t, Sequence) assert len(t.transformations) == 1 - seq = get_transformation(transformed, to_coordinate_system=target_coordinate_system) + seq = get_transformation( + transformed, to_coordinate_system=target_coordinate_system + ) assert isinstance(seq, Sequence) assert len(seq.transformations) == 2 assert seq.transformations[1] is t.transformations[0] @@ -877,7 +892,9 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) + sdata = self.filter_by_coordinate_system( + target_coordinate_system, filter_tables=False + ) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": @@ -913,7 +930,9 @@ def elements_are_self_contained(self) -> dict[str, bool]: description = {} for element_type, element_name, element in self.gen_elements(): element_path = self.path / element_type / element_name - description[element_name] = _is_element_self_contained(element, element_path) + description[element_name] = _is_element_self_contained( + element, element_path + ) return description def is_self_contained(self, element_name: str | None = None) -> bool: @@ -1030,8 +1049,12 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: elements_in_sdata = self.elements_paths_in_memory() elements_in_zarr = self.elements_paths_on_disk() - elements_only_in_sdata = list(set(elements_in_sdata).difference(set(elements_in_zarr))) - elements_only_in_zarr = list(set(elements_in_zarr).difference(set(elements_in_sdata))) + elements_only_in_sdata = list( + set(elements_in_sdata).difference(set(elements_in_zarr)) + ) + elements_only_in_zarr = list( + set(elements_in_zarr).difference(set(elements_in_sdata)) + ) return elements_only_in_sdata, elements_only_in_zarr def _validate_can_safely_write_to_path( @@ -1046,7 +1069,9 @@ def _validate_can_safely_write_to_path( file_path = Path(file_path) if not isinstance(file_path, Path): - raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + raise ValueError( + f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}." + ) # TODO: add test for this if os.path.exists(file_path): @@ -1068,24 +1093,28 @@ def _validate_can_safely_write_to_path( "Cannot overwrite. The target path of the write operation is in use. Please save the data to a " "different location. " ) - WORKAROUND = ( - "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." - ) + WORKAROUND = "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." if any(_backed_elements_contained_in_path(path=file_path, object=self)): raise ValueError( - ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " + ERROR_MSG + + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." 
+ WORKAROUND ) if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + _is_subfolder(parent=self.path, child=file_path) + or _is_subfolder(parent=file_path, child=self.path) ): - if saving_an_element and _is_subfolder(parent=self.path, child=file_path): + if saving_an_element and _is_subfolder( + parent=self.path, child=file_path + ): raise ValueError( - ERROR_MSG + "\nDetails: the target path in which to save an element is a subfolder " + ERROR_MSG + + "\nDetails: the target path in which to save an element is a subfolder " "of the current Zarr store." + WORKAROUND ) raise ValueError( - ERROR_MSG + "\nDetails: the target path either contains, coincides or is contained in" + ERROR_MSG + + "\nDetails: the target path either contains, coincides or is contained in" " the current Zarr store." + WORKAROUND ) @@ -1110,7 +1139,9 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: ( + SpatialDataFormatType | list[SpatialDataFormatType] | None + ) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1172,7 +1203,9 @@ def write( store = _resolve_zarr_store(file_path) zarr_format = parsed["SpatialData"].zarr_format - zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) + zarr_group = zarr.create_group( + store=store, overwrite=overwrite, zarr_format=zarr_format + ) self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() @@ -1215,15 +1248,12 @@ def _write_element( ) root_group, element_type_group, element_group = _get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False - ) - from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, + zarr_path=zarr_container_path, + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats if parsed_formats is None: @@ -1270,7 +1300,9 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: ( + SpatialDataFormatType | list[SpatialDataFormatType] | None + ) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1315,7 +1347,9 @@ def write_element( self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) if self.path is None: raise ValueError( @@ -1329,11 +1363,15 @@ def write_element( element_type = _element_type break if element_type is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) if element_type == "tables": validate_table_attr_keys(element) - self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=element_type, element_name=element_name + ) self._write_element( element=element, @@ -1393,10 +1431,16 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: raise ValueError("The SpatialData object is not backed by a Zarr store.") on_disk = self.elements_paths_on_disk() - one_disk_names = [self._element_type_and_name_from_element_path(path)[1] for path in on_disk] + one_disk_names = [ + self._element_type_and_name_from_element_path(path)[1] for path in on_disk + ] in_memory = self.elements_paths_in_memory() - in_memory_names = [self._element_type_and_name_from_element_path(path)[1] for path in in_memory] - only_in_memory_names = list(set(in_memory_names).difference(set(one_disk_names))) + in_memory_names = [ + self._element_type_and_name_from_element_path(path)[1] for path in in_memory + ] + only_in_memory_names = list( + set(in_memory_names).difference(set(one_disk_names)) + ) only_on_disk_names = list(set(one_disk_names).difference(set(in_memory_names))) ERROR_MESSAGE = f"Element {element_name} is not found in the Zarr store associated with the SpatialData object." @@ -1409,19 +1453,25 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if found: _element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type(element_type=_element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=_element_type, element_name=element_name + ) element_type = None on_disk = self.elements_paths_on_disk() for path in on_disk: - _element_type, _element_name = self._element_type_and_name_from_element_path(path) + _element_type, _element_name = ( + self._element_type_and_name_from_element_path(path) + ) if _element_name == element_name: element_type = _element_type break assert element_type is not None file_path_of_element = self.path / element_type / element_name - if any(_backed_elements_contained_in_path(path=file_path_of_element, object=self)): + if any( + _backed_elements_contained_in_path(path=file_path_of_element, object=self) + ): raise ValueError( "The file path specified is a parent directory of one or more files used for backing for one or " "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." 
@@ -1438,10 +1488,14 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if self.has_consolidated_metadata(): self.write_consolidated_metadata() - def _check_element_not_on_disk_with_different_type(self, element_type: str, element_name: str) -> None: + def _check_element_not_on_disk_with_different_type( + self, element_type: str, element_name: str + ) -> None: only_on_disk = self.elements_paths_on_disk() for disk_path in only_on_disk: - disk_element_type, disk_element_name = self._element_type_and_name_from_element_path(disk_path) + disk_element_type, disk_element_name = ( + self._element_type_and_name_from_element_path(disk_path) + ) if disk_element_name == element_name and disk_element_type != element_type: raise ValueError( f"Element {element_name} is found in the Zarr store as a {disk_element_type}, but it is found " @@ -1466,7 +1520,9 @@ def has_consolidated_metadata(self) -> bool: store.close() return return_value - def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: + def _validate_can_write_metadata_on_element( + self, element_name: str + ) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained from spatialdata._io.io_zarr import _group_for_element_exists @@ -1489,7 +1545,9 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) + self._check_element_not_on_disk_with_different_type( + element_type=element_type, element_name=element_name + ) # check if the element exists in the Zarr storage if not _group_for_element_exists( @@ -1508,7 +1566,9 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element element_zarr_path = Path(self.path) / element_type / element_name - if not _is_element_self_contained(element=element, element_path=element_zarr_path): + if not _is_element_self_contained( + element=element, element_path=element_zarr_path + ): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" " saved to the Zarr group of the element in the SpatialData Zarr store. The data outside the element " @@ -1531,7 +1591,9 @@ def write_channel_names(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1548,14 +1610,19 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) from spatialdata._io._utils import overwrite_channel_names overwrite_channel_names(element_group, element) else: - raise ValueError(f"Can't set channel names for element of type '{element_type}'.") + raise ValueError( + f"Can't set channel names for element of type '{element_type}'." + ) def write_transformations(self, element_name: str | None = None) -> None: """ @@ -1571,7 +1638,9 @@ def write_transformations(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1599,19 +1668,20 @@ def write_transformations(self, element_name: str | None = None) -> None: ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats - raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] + raster_format = RasterFormats[ + element_group.metadata.attributes["spatialdata_attrs"]["version"] + ] overwrite_coordinate_transformations_raster( - group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + group=element_group, + axes=axes, + transformations=transformations, + raster_format=raster_format, ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster overwrite_coordinate_transformations_non_raster( group=element_group, @@ -1625,7 +1695,9 @@ def _element_type_from_element_name(self, element_name: str) -> str: self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." 
+ ) located = self.locate_element(element) element_type = None @@ -1639,7 +1711,9 @@ def _element_type_from_element_name(self, element_name: str) -> str: assert element_type is not None return element_type - def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[str, str]: + def _element_type_and_name_from_element_path( + self, element_path: str + ) -> tuple[str, str]: element_type, element_name = element_path.split("/") return element_type, element_name @@ -1652,19 +1726,27 @@ def write_attrs( from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() + sdata_format = ( + sdata_format + if sdata_format is not None + else CurrentSpatialDataContainerFormat() + ) assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None if zarr_group is None: - assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." + assert ( + self.is_backed() + ), "The SpatialData object must be backed by a Zarr store to write attrs." store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") version = sdata_format.spatialdata_format_version version_specific_attrs = sdata_format.attrs_to_dict() - attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs + attrs_to_write = { + "spatialdata_attrs": {"version": version} | version_specific_attrs + } | self.attrs try: zarr_group.attrs.put(attrs_to_write) @@ -1713,7 +1795,9 @@ def write_metadata( if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + raise ValueError( + f"Element with name {element_name} not found in SpatialData object." + ) if write_attrs: self.write_attrs(sdata_format=sdata_format) @@ -1756,7 +1840,9 @@ def get_attrs( the value of `return_as`. """ - def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") -> dict[str, Any]: + def _flatten_mapping( + m: Mapping[str, Any], parent_key: str = "", sep: str = "_" + ) -> dict[str, Any]: items: list[tuple[str, Any]] = [] for k, v in m.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -1801,7 +1887,9 @@ def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") except Exception as e: raise ValueError(f"Failed to convert data to DataFrame: {e}") from e - raise ValueError(f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None.") + raise ValueError( + f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None." + ) @property def tables(self) -> Tables: @@ -1830,6 +1918,7 @@ def read( file_path: str | Path | UPath | zarr.Group, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1842,6 +1931,11 @@ def read( The elements to read (images, labels, points, shapes, table). If None, all elements are read. reconsolidate_metadata If the consolidated metadata store got corrupted this can lead to errors when trying to read the data. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. 
+ Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. Returns ------- @@ -1854,7 +1948,7 @@ def read( _write_consolidated_metadata(file_path) - return read_zarr(file_path, selection=selection) + return read_zarr(file_path, selection=selection, lazy=lazy) @property def images(self) -> Images: @@ -1933,7 +2027,8 @@ def _non_empty_elements(self) -> list[str]: return [ element for element in all_elements - if (getattr(self, element) is not None) and (len(getattr(self, element)) > 0) + if (getattr(self, element) is not None) + and (len(getattr(self, element)) > 0) ] def __repr__(self) -> str: @@ -1971,7 +2066,9 @@ def h(s: str) -> str: descr += f"\n{h('level0')}{attr.capitalize()}" unsorted_elements = attribute.items() - sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) + sorted_elements = sorted( + unsorted_elements, key=lambda x: _natural_keys(x[0]) + ) for k, v in sorted_elements: descr += f"{h('empty_line')}" descr_class = v.__class__.__name__ @@ -2003,7 +2100,16 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "") for dim in v.shape]) + + ", ".join( + [ + ( + str(dim) + if not isinstance(dim, Scalar) + else "" + ) + for dim in v.shape + ] + ) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2079,7 +2185,9 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: if not self.is_self_contained(): assert self.path is not None - descr += "\nwith the following Dask-backed elements not being self-contained:" + descr += ( + "\nwith the following Dask-backed elements not being self-contained:" + ) description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: @@ -2087,7 +2195,9 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: descr += f"\n ▸ {element_name}: {backing_files}" if self.path is not None: - elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() + elements_only_in_sdata, elements_only_in_zarr = ( + self._symmetric_difference_with_zarr_store() + ) if len(elements_only_in_sdata) > 0: descr += "\nwith the following elements not in the Zarr store:" for element_path in elements_only_in_sdata: @@ -2174,9 +2284,13 @@ def _validate_element_names_are_unique(self) -> None: ValueError If the element names are not unique. """ - check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()], location=()) + check_all_keys_case_insensitively_unique( + [name for _, name, _ in self.gen_elements()], location=() + ) - def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]: + def _find_element( + self, element_name: str + ) -> tuple[str, str, SpatialElement | AnnData]: """ Retrieve SpatialElement or Table from the SpatialData instance matching element_name. 
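
A minimal usage sketch of the `lazy` flag added to `SpatialData.read()` above (the store path and table name are hypothetical; anndata >= 0.12 is assumed to be installed):

    from spatialdata import SpatialData

    # Tables stay on disk until their values are accessed; images, labels
    # and points are Dask-backed exactly as before.
    sdata = SpatialData.read("large_msi_dataset.zarr", lazy=True)

    table = sdata.tables["table"]   # lazy AnnData: obs/var are xarray-backed, not pandas
    in_memory = table.to_memory()   # materialize only when the values are actually needed

With anndata < 0.12 the same call still succeeds: `_read_table()` (io_table.py, below) emits a UserWarning and falls back to eager loading.
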
@@ -2275,7 +2389,9 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements(include_tables=True): + for element_type, element_name, element in self._gen_elements( + include_tables=True + ): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2308,11 +2424,16 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) + element_name: element_value + for _, element_name, element_value in self._gen_elements( + include_tables=True + ) } return key in element_dict - def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData | None: + def get( + self, key: str, default_value: SpatialElement | AnnData | None = None + ) -> SpatialElement | AnnData | None: """ Get element from SpatialData object based on corresponding name. @@ -2420,7 +2541,9 @@ def filter_by_table_query( obs_names_expr: Predicates | None = None, var_names_expr: Predicates | None = None, layer: str | None = None, - how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", + how: Literal[ + "left", "left_exclusive", "inner", "right", "right_exclusive" + ] = "right", ) -> SpatialData: """ Filter the SpatialData object based on a set of table queries. diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 719c9c1a9..8bb1de0d0 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -7,18 +7,45 @@ from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format -from spatialdata._io.format import ( - CurrentTablesFormat, - TablesFormats, - TablesFormatV01, - TablesFormatV02, - _parse_version, -) +from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version from spatialdata.models import TableModel, get_table_keys -def _read_table(store: str | Path) -> AnnData: - table = read_anndata_zarr(str(store)) +def _read_table(store: str | Path, lazy: bool = False) -> AnnData: + """ + Read a table from a zarr store. + + Parameters + ---------- + store + Path to the zarr store containing the table. + lazy + If True, read the table lazily using anndata.experimental.read_lazy. + This requires anndata >= 0.12. If the installed version does not support + lazy reading, a warning is raised and the table is read eagerly. + + Returns + ------- + The AnnData table, either lazily loaded or in-memory. + """ + if lazy: + try: + from anndata.experimental import read_lazy + + table = read_lazy(str(store)) + except ImportError: + import warnings + + warnings.warn( + "Lazy reading of tables requires anndata >= 0.12. " + "Falling back to eager reading. 
To enable lazy reading, " + "upgrade anndata with: pip install 'anndata>=0.12'", + UserWarning, + stacklevel=2, + ) + table = read_anndata_zarr(str(store)) + else: + table = read_anndata_zarr(str(store)) f = zarr.open(store, mode="r") version = _parse_version(f, expect_attrs_key=False) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0312dc965..c24c257ca 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,6 +1,7 @@ import os import warnings from collections.abc import Callable +from functools import partial from json import JSONDecodeError from pathlib import Path from typing import Any, Literal, cast @@ -15,11 +16,7 @@ from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ( - BadFileHandleMethod, - _resolve_zarr_store, - handle_read_errors, -) +from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -36,7 +33,12 @@ def _read_zarr_group_spatialdata_element( read_func: Callable[..., Any], group_name: Literal["images", "labels", "shapes", "points", "tables"], element_type: Literal["image", "labels", "shapes", "points", "tables"], - element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]), + element_container: ( + dict[str, Raster_T] + | dict[str, DaskDataFrame] + | dict[str, GeoDataFrame] + | dict[str, AnnData] + ), on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], ) -> None: with handle_read_errors( @@ -66,7 +68,9 @@ def _read_zarr_group_spatialdata_element( ), ): if element_type in ["image", "labels"]: - reader_format = get_raster_format_for_read(elem_group, sdata_version) + reader_format = get_raster_format_for_read( + elem_group, sdata_version + ) element = read_func( elem_group_path, cast(Literal["image", "labels"], element_type), @@ -104,10 +108,7 @@ def get_raster_format_for_read( ------- The ome-zarr format to use for reading the raster element. """ - from spatialdata._io.format import ( - sdata_zarr_version_to_ome_zarr_format, - sdata_zarr_version_to_raster_format, - ) + from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format if sdata_version == "0.1": group_version = group.metadata.attributes["multiscales"][0]["version"] @@ -123,7 +124,10 @@ def get_raster_format_for_read( def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, - on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, + on_bad_files: Literal[ + BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN + ] = BadFileHandleMethod.ERROR, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData dataset from a zarr store (on-disk or remote). @@ -147,6 +151,12 @@ def read_zarr( object is returned containing only elements that could be read. Failures can only be determined from the warnings. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. + Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. + Returns ------- A SpatialData object. 
@@ -176,7 +186,11 @@ def read_zarr( shapes: dict[str, GeoDataFrame] = {} tables: dict[str, AnnData] = {} - selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) + selector = ( + {"images", "labels", "points", "shapes", "tables"} + if not selection + else set(selection or []) + ) logger.debug(f"Reading selection {selector}") # we could make this more readable. One can get lost when looking at this dict and iteration over the items @@ -185,7 +199,10 @@ def read_zarr( tuple[ Callable[..., Any], Literal["image", "labels", "shapes", "points", "tables"], - dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], + dict[str, Raster_T] + | dict[str, DaskDataFrame] + | dict[str, GeoDataFrame] + | dict[str, AnnData], ], ] = { # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" @@ -193,7 +210,7 @@ def read_zarr( "labels": (_read_multiscale, "labels", labels), "points": (_read_points, "points", points), "shapes": (_read_shapes, "shapes", shapes), - "tables": (_read_table, "tables", tables), + "tables": (partial(_read_table, lazy=lazy), "tables", tables), } for group_name, ( read_func, @@ -279,15 +296,21 @@ def _get_groups_for_element( # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) + root_group = zarr.open_group( + store=resolved_store, mode="r+", use_consolidated=use_consolidated + ) element_type_group = root_group.require_group(element_type) - element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) + element_type_group = zarr.open_group( + element_type_group.store_path, mode="a", use_consolidated=use_consolidated + ) element_name_group = element_type_group.require_group(element_name) return root_group, element_type_group, element_name_group -def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: +def _group_for_element_exists( + zarr_path: Path, element_type: str, element_name: str +) -> bool: """ Check if the group for an element exists. 
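
For reference, a sketch of how the lazy path threads through `read_zarr()` (the store path is hypothetical; whether the table actually comes back lazy depends on the installed anndata version, as handled by the fallback in `_read_table()` above):

    import pandas as pd
    from spatialdata import read_zarr

    sdata = read_zarr("large_msi_dataset.zarr", lazy=True)
    table = next(iter(sdata.tables.values()))

    # With anndata >= 0.12 the table is lazy: obs is not a plain pandas DataFrame
    # but an xarray-backed Dataset2D -- the same signal that _is_lazy_anndata()
    # (added to models.py below) uses to skip eager validation.
    print("lazy table:", not isinstance(table.obs, pd.DataFrame))
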
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index e834ad78d..fa3920c2e 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -23,12 +23,7 @@ from shapely.io import from_geojson, from_ragged_array from spatial_image import to_spatial_image from xarray import DataArray, DataTree -from xarray_schema.components import ( - ArrayTypeSchema, - AttrSchema, - AttrsSchema, - DimsSchema, -) +from xarray_schema.components import ArrayTypeSchema, AttrSchema, AttrsSchema, DimsSchema from xarray_schema.dataarray import DataArraySchema from spatialdata._core.validation import validate_table_attr_keys @@ -37,30 +32,49 @@ from spatialdata._utils import _check_match_length_channels_c_dim from spatialdata.config import settings from spatialdata.models import C, X, Y, Z, get_axes_names -from spatialdata.models._utils import ( - DEFAULT_COORDINATE_SYSTEM, - TRANSFORM_KEY, - MappingToCoordinateSystem_t, - SpatialElement, - _validate_mapping_to_coordinate_system_type, - convert_region_column_to_categorical, -) -from spatialdata.transformations._utils import ( - _get_transformations, - _set_transformations, - compute_coordinates, -) +from spatialdata.models._utils import (DEFAULT_COORDINATE_SYSTEM, TRANSFORM_KEY, MappingToCoordinateSystem_t, + SpatialElement, _validate_mapping_to_coordinate_system_type, + convert_region_column_to_categorical) +from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates from spatialdata.transformations.transformations import BaseTransformation, Identity # Types -Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] +Chunks_t: TypeAlias = ( + int + | tuple[int, ...] + | tuple[tuple[int, ...], ...] + | Mapping[Any, None | int | tuple[int, ...]] +) ScaleFactors_t = Sequence[dict[str, int] | int] Transform_s = AttrSchema(BaseTransformation, None) ATTRS_KEY = "spatialdata_attrs" -def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: +def _is_lazy_anndata(adata: AnnData) -> bool: + """Check if an AnnData object is lazily loaded. + + Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var + stored as xarray Dataset2D instead of pandas DataFrame. + + Parameters + ---------- + adata + The AnnData object to check. + + Returns + ------- + True if the AnnData is lazily loaded, False otherwise. + """ + # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) + if not isinstance(adata.obs, pd.DataFrame): + return True + return False + + +def _parse_transformations( + element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None +) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) if ( @@ -166,7 +180,9 @@ def parse( if transformations: transformations = transformations.copy() if "name" in kwargs: - raise ValueError("The `name` argument is not (yet) supported for raster data.") + raise ValueError( + "The `name` argument is not (yet) supported for raster data." 
+ ) # if dims is specified inside the data, get the value of dims from the data if isinstance(data, DataArray): if not isinstance(data.data, DaskArray): # numpy -> dask @@ -214,13 +230,18 @@ def parse( if c_coords is not None: c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims.dims) - if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]: + if ( + c_coords is not None + and len(c_coords) != data.shape[cls.dims.dims.index("c")] + ): raise ValueError( f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'" f" with length {data.shape[cls.dims.dims.index('c')]}." ) - data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs) + data = to_spatial_image( + array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs + ) # parse transformations _parse_transformations(data, transformations) # convert to multiscale if needed @@ -275,12 +296,18 @@ def _(self, data: DataArray) -> None: @validate.register(DataTree) def _(self, data: DataTree) -> None: - for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True): + for j, k in zip( + data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True + ): if j != k: - raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.") + raise ValueError( + f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`." + ) name = {list(data[i].data_vars.keys())[0] for i in data} if len(name) != 1: - raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.") + raise ValueError( + f"Expected exactly one data variable for the datatree: found `{name}`." + ) name = list(name)[0] for d in data: super().validate(data[d][name]) @@ -448,9 +475,14 @@ def validate(cls, data: GeoDataFrame) -> None: """ SUGGESTION = " Please use ShapesModel.parse() to construct data that is guaranteed to be valid." if cls.GEOMETRY_KEY not in data: - raise KeyError(f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + SUGGESTION) + raise KeyError( + f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + + SUGGESTION + ) if not isinstance(data[cls.GEOMETRY_KEY], GeoSeries): - raise ValueError(f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION) + raise ValueError( + f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION + ) if len(data[cls.GEOMETRY_KEY]) == 0: raise ValueError(f"Column `{cls.GEOMETRY_KEY}` is empty." + SUGGESTION) geom_ = data[cls.GEOMETRY_KEY].values[0] @@ -475,7 +507,10 @@ def validate(cls, data: GeoDataFrame) -> None: "please correct the radii of the circles before calling the parser function.", ) if cls.TRANSFORM_KEY not in data.attrs: - raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) + raise ValueError( + f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + + SUGGESTION + ) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -572,7 +607,9 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame: def _( cls, data: np.ndarray, # type: ignore[type-arg] - geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] + geometry: Literal[ + 0, 3, 6 + ], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] offsets: tuple[ArrayLike, ...] 
| None = None, radius: float | ArrayLike | None = None, index: ArrayLike | None = None, @@ -583,7 +620,9 @@ def _( geo_df = GeoDataFrame({"geometry": data}) if GeometryType(geometry).name == "POINT": if radius is None: - raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") + raise ValueError( + "If `geometry` is `Circles`, `radius` must be provided." + ) geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -610,7 +649,9 @@ def _( geo_df = GeoDataFrame({"geometry": gc.geoms}) if isinstance(geo_df["geometry"].iloc[0], Point): if radius is None: - raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") + raise ValueError( + "If `geometry` is `Circles`, `radius` must be provided." + ) geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -627,7 +668,10 @@ def _( ) -> GeoDataFrame: if "geometry" not in data.columns: raise ValueError("`geometry` column not found in `GeoDataFrame`.") - if isinstance(data["geometry"].iloc[0], Point) and cls.RADIUS_KEY not in data.columns: + if ( + isinstance(data["geometry"].iloc[0], Point) + and cls.RADIUS_KEY not in data.columns + ): raise ValueError(f"Column `{cls.RADIUS_KEY}` not found.") _parse_transformations(data, transformations) cls.validate(data) @@ -667,7 +711,8 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") if cls.TRANSFORM_KEY not in data.attrs: raise ValueError( - f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION + f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + + SUGGESTION ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] @@ -750,11 +795,15 @@ def _( if annotation is not None: if feature_key is not None: - df_dict[feature_key] = annotation[feature_key].astype(str).astype("category") + df_dict[feature_key] = ( + annotation[feature_key].astype(str).astype("category") + ) if instance_key is not None: df_dict[instance_key] = annotation[instance_key] if Z not in axes and Z in annotation.columns: - logger.info(f"Column `{Z}` in `annotation` will be ignored since the data is 2D.") + logger.info( + f"Column `{Z}` in `annotation` will be ignored since the data is 2D." + ) for c in set(annotation.columns) - {feature_key, instance_key, X, Y, Z}: df_dict[c] = annotation[c] @@ -793,7 +842,9 @@ def _( if "sort" not in kwargs: index_monotonically_increasing = data.index.is_monotonic_increasing if not isinstance(index_monotonically_increasing, bool): - index_monotonically_increasing = index_monotonically_increasing.compute() + index_monotonically_increasing = ( + index_monotonically_increasing.compute() + ) sort = index_monotonically_increasing else: sort = kwargs["sort"] @@ -831,7 +882,9 @@ def _( if data[feature_key].dtype.name == "category": table[feature_key] = data[feature_key] else: - table[feature_key] = data[feature_key].astype(str).astype("category") + table[feature_key] = ( + data[feature_key].astype(str).astype("category") + ) if instance_key is not None: table[instance_key] = data[instance_key] for c in [X, Y, Z]: @@ -891,9 +944,13 @@ def _add_metadata_and_validate( # It also just changes the state of the series, so it is not a big deal. 
if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: - data[c] = data[c].cat.set_categories(data[c].compute().cat.categories) + data[c] = data[c].cat.set_categories( + data[c].compute().cat.categories + ) except ValueError: - logger.info(f"Column `{c}` contains unknown categories. Consider casting it.") + logger.info( + f"Column `{c}` contains unknown categories. Consider casting it." + ) _parse_transformations(data, transformations) cls.validate(data) @@ -907,7 +964,9 @@ class TableModel: INSTANCE_KEY = "instance_key" ATTRS_KEY = ATTRS_KEY - def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None: + def _validate_set_region_key( + self, data: AnnData, region_key: str | None = None + ) -> None: """ Validate the region key in table.uns or set a new region key as the region key column. @@ -947,7 +1006,9 @@ def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) raise ValueError(f"'{region_key}' column not present in table.obs") attrs[self.REGION_KEY_KEY] = region_key - def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None: + def _validate_set_instance_key( + self, data: AnnData, instance_key: str | None = None + ) -> None: """ Validate the instance_key in table.uns or set a new instance_key as the instance_key column. @@ -991,7 +1052,9 @@ def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = N if instance_key in data.obs: attrs[self.INSTANCE_KEY] = instance_key else: - raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") + raise ValueError( + f"Instance key column '{instance_key}' not found in table.obs." + ) def _validate_table_annotation_metadata(self, data: AnnData) -> None: """ @@ -1026,16 +1089,33 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: attr = data.uns[ATTRS_KEY] if "region" not in attr: - raise ValueError(f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if "region_key" not in attr: - raise ValueError(f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if "instance_key" not in attr: - raise ValueError(f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) + raise ValueError( + f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION + ) if attr[self.REGION_KEY_KEY] not in data.obs: - raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") + raise ValueError( + f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column." + ) if attr[self.INSTANCE_KEY] not in data.obs: - raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") + raise ValueError( + f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column." 
+ ) + + # Skip detailed dtype/value validation for lazy-loaded AnnData + # These checks would trigger data loading, defeating the purpose of lazy loading + # Validation will occur when data is actually computed/accessed + if _is_lazy_anndata(data): + return + if ( (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) not in [ @@ -1049,26 +1129,41 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: "O", ] and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]]) - or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str) + or ( + dtype == "O" + and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) + is not str + ) ): dtype = dtype if dtype != "O" else val_dtype raise TypeError( f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " f"instance_key column in obs. Dtype found to be {dtype}" ) - expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] + expected_regions = ( + attr[self.REGION_KEY] + if isinstance(attr[self.REGION_KEY], list) + else [attr[self.REGION_KEY]] + ) found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist() if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: - raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") + raise ValueError( + f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match." + ) # Warning for object/string columns with NaN in region_key or instance_key instance_key = attr[self.INSTANCE_KEY] region_key = attr[self.REGION_KEY_KEY] - for key_name, key_value in [("region_key", region_key), ("instance_key", instance_key)]: + for key_name, key_value in [ + ("region_key", region_key), + ("instance_key", instance_key), + ]: if key_value in data.obs: col = data.obs[key_value] col_dtype = col.dtype - if (col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype)) and col.isna().any(): + if ( + col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype) + ) and col.isna().any(): logger.warning( f"The {key_name} column '{key_value}' is of {col_dtype} type and contains NaN values. " "After writing and reading with AnnData, NaN values may (depending on the AnnData version) " @@ -1099,6 +1194,10 @@ def validate( if ATTRS_KEY not in data.uns: return data + # Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy) + # Lazy AnnData has xarray-based obs/var, which requires different validation + is_lazy = _is_lazy_anndata(data) + _, region_key, instance_key = get_table_keys(data) if region_key is not None: if region_key not in data.obs: @@ -1106,7 +1205,10 @@ def validate( f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse " f"using TableModel.parse(adata)." ) - if not isinstance(data.obs[region_key].dtype, CategoricalDtype): + # Skip dtype validation for lazy tables (would require loading data) + if not is_lazy and not isinstance( + data.obs[region_key].dtype, CategoricalDtype + ): raise ValueError( f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." ) @@ -1116,8 +1218,11 @@ def validate( f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse" f" using TableModel.parse(adata)." 
) - if data.obs[instance_key].isnull().values.any(): - raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") + # Skip null check for lazy tables (would require loading data) + if not is_lazy and data.obs[instance_key].isnull().values.any(): + raise ValueError( + "`table.obs[instance_key]` must not contain null values, but it does." + ) self._validate_table_annotation_metadata(data) @@ -1154,7 +1259,9 @@ def parse( """ validate_table_attr_keys(adata) # either all live in adata.uns or all be passed in as argument - n_args = sum([region is not None, region_key is not None, instance_key is not None]) + n_args = sum( + [region is not None, region_key is not None, instance_key is not None] + ) if n_args == 0: if cls.ATTRS_KEY not in adata.uns: # table not annotating any element @@ -1183,7 +1290,9 @@ def parse( region = region.tolist() region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): - raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") + raise ValueError( + f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values." + ) adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key @@ -1194,7 +1303,9 @@ def parse( grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() grouped_nunique = grouped.nunique() - not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() + not_unique = grouped_size[ + grouped_size != grouped_nunique[instance_key] + ].index.tolist() if not_unique: raise ValueError( f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique values" @@ -1305,6 +1416,11 @@ def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: ) annotated_regions = region_key_column.unique().tolist() else: - annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() + annotated_regions = ( + table.obs[region_key] + .cat.remove_unused_categories() + .cat.categories.unique() + .tolist() + ) assert isinstance(annotated_regions, list) return annotated_regions diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index af028d29c..241336df7 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -94,7 +94,11 @@ def test_shapes( # add a mixed Polygon + MultiPolygon element shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]]) - shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding) + shapes.write( + tmpdir, + sdata_formats=sdata_container_format, + shapes_geometry_encoding=geometry_encoding, + ) sdata = SpatialData.read(tmpdir) if geometry_encoding == "WKB": @@ -102,7 +106,9 @@ def test_shapes( else: # convert each Polygon to a MultiPolygon mixed_multipolygon = shapes["mixed"].assign( - geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g) + geometry=lambda df: df.geometry.apply( + lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g + ) ) assert sdata["mixed"].equals(mixed_multipolygon) assert not sdata["mixed"].equals(shapes["mixed"]) @@ -139,7 +145,9 @@ def test_shapes_geometry_encoding_write_element( # Write each shape element - should use global setting for shape_name in shapes.shapes: - empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format) + 
empty_sdata.write_element( + shape_name, sdata_formats=sdata_container_format + ) # Verify the encoding metadata in the parquet file parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet" @@ -220,8 +228,12 @@ def test_multiple_tables( tables: list[AnnData], sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) - self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) + sdata_tables = SpatialData( + tables={str(i): tables[i] for i in range(len(tables))} + ) + self._test_table( + tmp_path, sdata_tables, sdata_container_format=sdata_container_format + ) def test_roundtrip( self, @@ -252,7 +264,9 @@ def test_incremental_io_list_of_elements( assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) + shapes.write_element( + ["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format + ) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -367,7 +381,9 @@ def test_incremental_io_on_disk( ValueError, match=match, ): - sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) + sdata.write_element( + name, overwrite=True, sdata_formats=sdata_container_format + ) if workaround == 1: new_name = f"{name}_new_place" @@ -398,7 +414,9 @@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name, sdata_formats=sdata_container_format) - def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): + def test_io_and_lazy_loading_points( + self, points, sdata_container_format: SpatialDataContainerFormatType + ): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") points.write(f, sdata_formats=sdata_container_format) @@ -407,7 +425,9 @@ def test_io_and_lazy_loading_points(self, points, sdata_container_format: Spatia sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): + def test_io_and_lazy_loading_raster( + self, images, labels, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -457,9 +477,13 @@ def test_replace_transformation_on_disk_non_raster( with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") sdata.write(f, sdata_formats=sdata_container_format) - t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) + t0 = get_transformation( + SpatialData.read(f).__getattribute__(k)[elem_name] + ) assert isinstance(t0, Identity) - set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) + set_transformation( + sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata + ) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) @@ -470,10 +494,16 @@ def test_write_overwrite_fails_when_no_zarr_store( f = Path(tmpdir) / "data.zarr" f.mkdir() old_data = SpatialData() - with pytest.raises(ValueError, match="The target file path specified already exists"): + with pytest.raises( + ValueError, match="The target file path specified already exists" + ): 
old_data.write(f, sdata_formats=sdata_container_format) - with pytest.raises(ValueError, match="The target file path specified already exists"): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + with pytest.raises( + ValueError, match="The target file path specified already exists" + ): + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( self, @@ -506,7 +536,9 @@ def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( match=r"Details: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object", ): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) def test_overwrite_fails_when_zarr_store_present( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType @@ -526,7 +558,9 @@ def test_overwrite_fails_when_zarr_store_present( ValueError, match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f, overwrite=True, sdata_formats=sdata_container_format + ) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -547,9 +581,7 @@ def test_overwrite_fails_when_zarr_store_present( def test_overwrite_fails_onto_non_zarr_file( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType ): - ERROR_MESSAGE = ( - "The target file path specified already exists, and it has been detected to not be a Zarr store." - ) + ERROR_MESSAGE = "The target file path specified already exists, and it has been detected to not be a Zarr store." 
with tempfile.TemporaryDirectory() as tmpdir: f0 = os.path.join(tmpdir, "test.txt") with open(f0, "w"): @@ -562,13 +594,17 @@ def test_overwrite_fails_onto_non_zarr_file( ValueError, match=ERROR_MESSAGE, ): - full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f0, overwrite=True, sdata_formats=sdata_container_format + ) f1 = os.path.join(tmpdir, "test.zarr") os.mkdir(f1) with pytest.raises(ValueError, match=ERROR_MESSAGE): full_sdata.write(f1, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format) + full_sdata.write( + f1, overwrite=True, sdata_formats=sdata_container_format + ) def test_incremental_io_in_memory( @@ -606,7 +642,9 @@ def test_bug_rechunking_after_queried_raster(): # https://github.com/scverse/spatialdata-io/issues/117 ## single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5)) - multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5)) + multi_scale = Image2DModel.parse( + RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5) + ) images = {"single_scale": single_scale, "multi_scale": multi_scale} sdata = SpatialData(images=images) queried = sdata.query.bounding_box( @@ -621,7 +659,9 @@ def test_bug_rechunking_after_queried_raster(): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_self_contained( + full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -645,7 +685,10 @@ def test_self_contained(full_sdata: SpatialData, sdata_container_format: Spatial # because of the images, labels and points description = sdata2.elements_are_self_contained() for element_name, self_contained in description.items(): - if any(element_name.startswith(prefix) for prefix in ["image", "labels", "points"]): + if any( + element_name.startswith(prefix) + for prefix in ["image", "labels", "points"] + ): assert not self_contained else: assert self_contained @@ -678,7 +721,11 @@ def test_self_contained(full_sdata: SpatialData, sdata_container_format: Spatial assert not sdata2.is_self_contained() description = sdata2.elements_are_self_contained() assert description["combined"] is False - assert all(description[element_name] for element_name in description if element_name != "combined") + assert all( + description[element_name] + for element_name in description + if element_name != "combined" + ) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) @@ -690,7 +737,9 @@ def test_symmetric_difference_with_zarr_store( full_sdata.write(f, sdata_formats=sdata_container_format) # the list of element on-disk and in-memory is the same - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) assert len(only_in_memory) == 0 assert len(only_on_disk) == 0 @@ -706,7 +755,9 @@ def test_symmetric_difference_with_zarr_store( del full_sdata.tables["table"] # now the list of element on-disk and in-memory is different - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + 
only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) assert set(only_in_memory) == { "images/new_image2d", "labels/new_labels2d", @@ -724,13 +775,17 @@ def test_symmetric_difference_with_zarr_store( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_change_path_of_subset( + full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) - subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"]) + subset = full_sdata.subset( + ["image2d", "labels2d", "points_0", "circles", "table"] + ) assert subset.path is None subset.path = Path(f) @@ -795,7 +850,9 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: +def test_incremental_io_attrs( + points: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") my_attrs = {"a": "b", "c": 1} @@ -822,7 +879,9 @@ def test_incremental_io_attrs(points: SpatialData, sdata_container_format: Spati cached_sdata_blobs = blobs() -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize( + "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] +) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_delete_element_from_disk( full_sdata, @@ -830,7 +889,9 @@ def test_delete_element_from_disk( sdata_container_format: SpatialDataContainerFormatType, ) -> None: # can't delete an element for a SpatialData object without associated Zarr store - with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): + with pytest.raises( + ValueError, match="The SpatialData object is not backed by a Zarr store." 
+ ): full_sdata.delete_element_from_disk("image2d") with tempfile.TemporaryDirectory() as tmpdir: @@ -858,7 +919,9 @@ def test_delete_element_from_disk( # can delete an element present both in-memory and on-disk full_sdata.delete_element_from_disk(element_name) - only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() + only_in_memory, only_on_disk = ( + full_sdata._symmetric_difference_with_zarr_store() + ) element_type = full_sdata._element_type_from_element_name(element_name) element_path = f"{element_type}/{element_name}" assert element_path in only_in_memory @@ -873,7 +936,9 @@ def test_delete_element_from_disk( assert element_path not in on_disk -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize( + "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] +) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_element_already_on_disk_different_type( full_sdata, @@ -927,7 +992,9 @@ def test_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( + iter(_get_shapes().values()) + ) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") with pytest.raises(ValueError, match="Name (must|cannot)"): @@ -938,7 +1005,9 @@ def test_writing_valid_table_name_invalid_table(tmp_path: Path): # also try with a valid table name but invalid table # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() - invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) + invalid_sdata.tables.data["valid_name"] = AnnData( + np.array([[0]]), layers={"invalid name": np.array([[0]])} + ) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write(tmp_path / "data.zarr") @@ -951,7 +1020,9 @@ def test_incremental_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( + iter(_get_shapes().values()) + ) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") for element_type in ["images", "labels", "points", "shapes", "tables"]: @@ -966,7 +1037,9 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() invalid_sdata.write(tmp_path / "data2.zarr") - invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) + invalid_sdata.tables.data["valid_name"] = AnnData( + np.array([[0]]), layers={"invalid name": np.array([[0]])} + ) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write_element("valid_name") @@ -986,13 +1059,19 @@ def test_reading_invalid_name(tmp_path: Path): ) valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at 
construction time and check validation happens again at writing time. - (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") + (tmp_path / "data.zarr/points" / points_name).rename( + tmp_path / "data.zarr/points" / "has whitespace" + ) # This one is not allowed on windows - (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@") + (tmp_path / "data.zarr/shapes" / shapes_name).rename( + tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@" + ) # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. valid_sdata.write_consolidated_metadata() - with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: + with pytest.raises( + ValidationError, match="Cannot construct SpatialData" + ) as exc_info: read_zarr(tmp_path / "data.zarr") actual_message = str(exc_info.value) @@ -1005,10 +1084,14 @@ def test_reading_invalid_name(tmp_path: Path): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType): +def test_write_store_unconsolidated_and_read( + full_sdata, sdata_container_format: SpatialDataContainerFormatType +): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format) + full_sdata.write( + path, consolidate_metadata=False, sdata_formats=sdata_container_format + ) group = zarr.open_group(path, mode="r") assert group.metadata.consolidated_metadata is None @@ -1017,7 +1100,9 @@ def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType): +def test_can_read_sdata_with_reconsolidation( + full_sdata, sdata_container_format: SpatialDataContainerFormatType +): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" full_sdata.write(path, sdata_formats=sdata_container_format) @@ -1107,3 +1192,84 @@ def test_sdata_with_nan_in_obs() -> None: # After round-trip, NaN in object-dtype column becomes string "nan" assert sdata2["table"].obs["column_only_region1"].iloc[1] == "nan" assert np.isnan(sdata2["table"].obs["column_only_region2"].iloc[0]) + + +class TestLazyTableLoading: + """Tests for lazy table loading functionality. + + Lazy loading uses anndata.experimental.read_lazy() to keep large tables + out of memory until needed. This is particularly useful for MSI data + where tables can contain millions of pixels. 
+ """ + + @pytest.fixture + def sdata_with_table(self) -> SpatialData: + """Create a SpatialData object with a simple table for testing.""" + table = TableModel.parse( + AnnData( + X=np.random.rand(100, 50), + obs=pd.DataFrame( + { + "region": pd.Categorical(["region1"] * 100), + "instance": np.arange(100), + } + ), + ), + region_key="region", + instance_key="instance", + region="region1", + ) + return SpatialData(tables={"test_table": table}) + + def test_lazy_read_basic(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=True reads tables without loading into memory.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=True + try: + sdata_lazy = SpatialData.read(path, lazy=True) + + # Table should be present + assert "test_table" in sdata_lazy.tables + + # Check that X is a lazy array (dask or similar) + # Lazy AnnData from read_lazy uses dask arrays + table = sdata_lazy.tables["test_table"] + assert hasattr(table, "X") + + except ImportError: + # If anndata.experimental.read_lazy is not available, skip + pytest.skip("anndata.experimental.read_lazy not available") + + def test_lazy_false_loads_normally(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=False (default) loads tables into memory normally.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=False (default) + sdata_normal = SpatialData.read(path, lazy=False) + + # Table should be present and loaded normally + assert "test_table" in sdata_normal.tables + table = sdata_normal.tables["test_table"] + + # X should be a numpy array or scipy sparse matrix (in-memory) + import scipy.sparse as sp + + assert isinstance(table.X, np.ndarray | sp.spmatrix) + + def test_read_zarr_lazy_parameter(self, sdata_with_table: SpatialData) -> None: + """Test that read_zarr function accepts lazy parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Test read_zarr directly with lazy parameter + try: + sdata = read_zarr(path, lazy=True) + assert "test_table" in sdata.tables + except ImportError: + pytest.skip("anndata.experimental.read_lazy not available") From 44d4f45045f362b46cfcbabf26df51c012442247 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:59:20 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/spatialdata.py | 340 +++++++++------------------ src/spatialdata/_io/io_zarr.py | 38 +-- src/spatialdata/models/models.py | 184 ++++----------- tests/io/test_readwrite.py | 167 ++++--------- 4 files changed, 204 insertions(+), 525 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 869683ba8..079ca28b6 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -22,15 +22,33 @@ from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables -from spatialdata._core.validation import (check_all_keys_case_insensitively_unique, check_target_region_column_symmetry, - check_valid_name, raise_validation_errors, validate_table_attr_keys) +from spatialdata._core.validation import ( + check_all_keys_case_insensitively_unique, + 
check_target_region_column_symmetry, + check_valid_name, + raise_validation_errors, + validate_table_attr_keys, +) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import _deprecation_alias -from spatialdata.models import (Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, - TableModel, get_model, get_table_keys) -from spatialdata.models._utils import (SpatialElement, convert_region_column_to_categorical, get_axes_names, - set_channel_names) +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + PointsModel, + ShapesModel, + TableModel, + get_model, + get_table_keys, +) +from spatialdata.models._utils import ( + SpatialElement, + convert_region_column_to_categorical, + get_axes_names, + set_channel_names, +) if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest @@ -119,11 +137,7 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - element_names = list( - chain.from_iterable( - [e.keys() for e in [images, labels, points, shapes] if e is not None] - ) - ) + element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} @@ -249,9 +263,7 @@ def get_region_key_column(table: AnnData) -> pd.Series: _, region_key, _ = get_table_keys(table) if table.obs.get(region_key) is not None: return table.obs[region_key] - raise KeyError( - f"{region_key} is set as region key column. However the column is not found in table.obs." - ) + raise KeyError(f"{region_key} is set as region key column. However the column is not found in table.obs.") @staticmethod def get_instance_key_column(table: AnnData) -> pd.Series: @@ -276,13 +288,9 @@ def get_instance_key_column(table: AnnData) -> pd.Series: _, _, instance_key = get_table_keys(table) if table.obs.get(instance_key) is not None: return table.obs[instance_key] - raise KeyError( - f"{instance_key} is set as instance key column. However the column is not found in table.obs." - ) + raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") - def set_channel_names( - self, element_name: str, channel_names: str | list[str], write: bool = False - ) -> None: + def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: """Set the channel names for an image `SpatialElement` in the `SpatialData` object. This method will overwrite the element in memory with the same element, but with new channel names. @@ -300,9 +308,7 @@ def set_channel_names( Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel data itself (heavy operation). """ - self.images[element_name] = set_channel_names( - self.images[element_name], channel_names - ) + self.images[element_name] = set_channel_names(self.images[element_name], channel_names) if write: self.write_channel_names(element_name) @@ -380,9 +386,7 @@ def _change_table_annotation_target( If provided region_key is not present in table.obs. 
""" attrs = table.uns[TableModel.ATTRS_KEY] - table_region_key = ( - region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) - ) + table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) TableModel()._validate_set_region_key(table, region_key) TableModel()._validate_set_instance_key(table, instance_key) @@ -390,9 +394,7 @@ def _change_table_annotation_target( attrs[TableModel.REGION_KEY] = region @staticmethod - def update_annotated_regions_metadata( - table: AnnData, region_key: str | None = None - ) -> AnnData: + def update_annotated_regions_metadata(table: AnnData, region_key: str | None = None) -> AnnData: """ Update the annotation target of the table using the region_key column in table.obs. @@ -415,9 +417,7 @@ def update_annotated_regions_metadata( """ attrs = table.uns.get(TableModel.ATTRS_KEY) if attrs is None: - raise ValueError( - "The table has no annotation metadata. Please parse the table using `TableModel.parse`." - ) + raise ValueError("The table has no annotation metadata. Please parse the table using `TableModel.parse`.") region_key = region_key if region_key else attrs[TableModel.REGION_KEY_KEY] if attrs[TableModel.REGION_KEY_KEY] != region_key: attrs[TableModel.REGION_KEY_KEY] = region_key @@ -465,20 +465,14 @@ def set_table_annotates_spatialelement( isinstance(region, list | pd.Series) and not all(region_element in element_names for region_element in region) ): - raise ValueError( - f"Annotation target '{region}' not present as SpatialElement in SpatialData object." - ) + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.") if table.uns.get(TableModel.ATTRS_KEY): - self._change_table_annotation_target( - table, region, region_key, instance_key - ) + self._change_table_annotation_target(table, region, region_key, instance_key) elif isinstance(region_key, str) and isinstance(instance_key, str): self._set_table_annotation_target(table, region, region_key, instance_key) else: - raise TypeError( - "No current annotation metadata found. Please specify both region_key and instance_key." - ) + raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") convert_region_column_to_categorical(table) @property @@ -589,17 +583,9 @@ def locate_element(self, element: SpatialElement) -> list[str]: found_element_name.append(element_name) if len(found) == 0: return [] - if any( - "/" in found_element_name[i] or "/" in found_element_type[i] - for i in range(len(found)) - ): - raise ValueError( - "Found an element name with a '/' character. This is not allowed." - ) - return [ - f"{found_element_type[i]}/{found_element_name[i]}" - for i in range(len(found)) - ] + if any("/" in found_element_name[i] or "/" in found_element_type[i] for i in range(len(found))): + raise ValueError("Found an element name with a '/' character. 
This is not allowed.") + return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] def filter_by_coordinate_system( self, @@ -697,9 +683,7 @@ def _filter_tables( if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): tables[table_name] = table continue - if not include_orphan_tables and not table.uns.get( - TableModel.ATTRS_KEY - ): + if not include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): continue if table_name in names_tables_to_keep: tables[table_name] = table @@ -716,9 +700,7 @@ def _filter_tables( from spatialdata._core.query.relational_query import _filter_table_by_elements assert elements_dict is not None - table = _filter_table_by_elements( - table, elements_dict=elements_dict - ) + table = _filter_table_by_elements(table, elements_dict=elements_dict) if table is not None and len(table) != 0: tables[table_name] = table else: @@ -766,9 +748,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: for old_cs, new_cs in rename_dict.items(): if old_cs in transformations: random_suffix = hashlib.sha1(os.urandom(128)).hexdigest()[:8] - transformations[new_cs + random_suffix] = transformations.pop( - old_cs - ) + transformations[new_cs + random_suffix] = transformations.pop(old_cs) suffixes_to_replace.add(new_cs + random_suffix) # remove the random suffixes @@ -779,14 +759,10 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_transformations[cs] = transformations[cs_with_suffix] suffixes_to_replace.remove(cs_with_suffix) else: - new_transformations[cs_with_suffix] = transformations[ - cs_with_suffix - ] + new_transformations[cs_with_suffix] = transformations[cs_with_suffix] # set the new transformations - set_transformation( - element=element, transformation=new_transformations, set_all=True - ) + set_transformation(element=element, transformation=new_transformations, set_all=True) def transform_element_to_coordinate_system( self, @@ -814,18 +790,17 @@ def transform_element_to_coordinate_system( """ from spatialdata import transform from spatialdata.transformations import Sequence - from spatialdata.transformations.operations import (get_transformation, - get_transformation_between_coordinate_systems, - remove_transformation, set_transformation) + from spatialdata.transformations.operations import ( + get_transformation, + get_transformation_between_coordinate_systems, + remove_transformation, + set_transformation, + ) element = self.get(element_name) - t = get_transformation_between_coordinate_systems( - self, element, target_coordinate_system - ) + t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) if maintain_positioning: - transformed = transform( - element, transformation=t, maintain_positioning=maintain_positioning - ) + transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) else: d = get_transformation(element, get_all=True) assert isinstance(d, dict) @@ -861,9 +836,7 @@ def transform_element_to_coordinate_system( # since target_coordinate_system is in d, we have that t is a Sequence with only one transformation. 
assert isinstance(t, Sequence) assert len(t.transformations) == 1 - seq = get_transformation( - transformed, to_coordinate_system=target_coordinate_system - ) + seq = get_transformation(transformed, to_coordinate_system=target_coordinate_system) assert isinstance(seq, Sequence) assert len(seq.transformations) == 2 assert seq.transformations[1] is t.transformations[0] @@ -892,9 +865,7 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system( - target_coordinate_system, filter_tables=False - ) + sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": @@ -930,9 +901,7 @@ def elements_are_self_contained(self) -> dict[str, bool]: description = {} for element_type, element_name, element in self.gen_elements(): element_path = self.path / element_type / element_name - description[element_name] = _is_element_self_contained( - element, element_path - ) + description[element_name] = _is_element_self_contained(element, element_path) return description def is_self_contained(self, element_name: str | None = None) -> bool: @@ -1049,12 +1018,8 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: elements_in_sdata = self.elements_paths_in_memory() elements_in_zarr = self.elements_paths_on_disk() - elements_only_in_sdata = list( - set(elements_in_sdata).difference(set(elements_in_zarr)) - ) - elements_only_in_zarr = list( - set(elements_in_zarr).difference(set(elements_in_sdata)) - ) + elements_only_in_sdata = list(set(elements_in_sdata).difference(set(elements_in_zarr))) + elements_only_in_zarr = list(set(elements_in_zarr).difference(set(elements_in_sdata))) return elements_only_in_sdata, elements_only_in_zarr def _validate_can_safely_write_to_path( @@ -1069,9 +1034,7 @@ def _validate_can_safely_write_to_path( file_path = Path(file_path) if not isinstance(file_path, Path): - raise ValueError( - f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}." - ) + raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") # TODO: add test for this if os.path.exists(file_path): @@ -1093,28 +1056,24 @@ def _validate_can_safely_write_to_path( "Cannot overwrite. The target path of the write operation is in use. Please save the data to a " "different location. " ) - WORKAROUND = "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." + WORKAROUND = ( + "\nWorkaround: please see discussion here https://github.com/scverse/spatialdata/discussions/520 ." + ) if any(_backed_elements_contained_in_path(path=file_path, object=self)): raise ValueError( - ERROR_MSG - + "\nDetails: the target path contains one or more files that Dask use for " + ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." 
+ WORKAROUND ) if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) - or _is_subfolder(parent=file_path, child=self.path) + _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) ): - if saving_an_element and _is_subfolder( - parent=self.path, child=file_path - ): + if saving_an_element and _is_subfolder(parent=self.path, child=file_path): raise ValueError( - ERROR_MSG - + "\nDetails: the target path in which to save an element is a subfolder " + ERROR_MSG + "\nDetails: the target path in which to save an element is a subfolder " "of the current Zarr store." + WORKAROUND ) raise ValueError( - ERROR_MSG - + "\nDetails: the target path either contains, coincides or is contained in" + ERROR_MSG + "\nDetails: the target path either contains, coincides or is contained in" " the current Zarr store." + WORKAROUND ) @@ -1139,9 +1098,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: ( - SpatialDataFormatType | list[SpatialDataFormatType] | None - ) = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1203,9 +1160,7 @@ def write( store = _resolve_zarr_store(file_path) zarr_format = parsed["SpatialData"].zarr_format - zarr_group = zarr.create_group( - store=store, overwrite=overwrite, zarr_format=zarr_format - ) + zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() @@ -1300,9 +1255,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: ( - SpatialDataFormatType | list[SpatialDataFormatType] | None - ) = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1347,9 +1300,7 @@ def write_element( self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if self.path is None: raise ValueError( @@ -1363,15 +1314,11 @@ def write_element( element_type = _element_type break if element_type is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." 
- ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if element_type == "tables": validate_table_attr_keys(element) - self._check_element_not_on_disk_with_different_type( - element_type=element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) self._write_element( element=element, @@ -1431,16 +1378,10 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: raise ValueError("The SpatialData object is not backed by a Zarr store.") on_disk = self.elements_paths_on_disk() - one_disk_names = [ - self._element_type_and_name_from_element_path(path)[1] for path in on_disk - ] + one_disk_names = [self._element_type_and_name_from_element_path(path)[1] for path in on_disk] in_memory = self.elements_paths_in_memory() - in_memory_names = [ - self._element_type_and_name_from_element_path(path)[1] for path in in_memory - ] - only_in_memory_names = list( - set(in_memory_names).difference(set(one_disk_names)) - ) + in_memory_names = [self._element_type_and_name_from_element_path(path)[1] for path in in_memory] + only_in_memory_names = list(set(in_memory_names).difference(set(one_disk_names))) only_on_disk_names = list(set(one_disk_names).difference(set(in_memory_names))) ERROR_MESSAGE = f"Element {element_name} is not found in the Zarr store associated with the SpatialData object." @@ -1453,25 +1394,19 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if found: _element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type( - element_type=_element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=_element_type, element_name=element_name) element_type = None on_disk = self.elements_paths_on_disk() for path in on_disk: - _element_type, _element_name = ( - self._element_type_and_name_from_element_path(path) - ) + _element_type, _element_name = self._element_type_and_name_from_element_path(path) if _element_name == element_name: element_type = _element_type break assert element_type is not None file_path_of_element = self.path / element_type / element_name - if any( - _backed_elements_contained_in_path(path=file_path_of_element, object=self) - ): + if any(_backed_elements_contained_in_path(path=file_path_of_element, object=self)): raise ValueError( "The file path specified is a parent directory of one or more files used for backing for one or " "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." 
@@ -1488,14 +1423,10 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: if self.has_consolidated_metadata(): self.write_consolidated_metadata() - def _check_element_not_on_disk_with_different_type( - self, element_type: str, element_name: str - ) -> None: + def _check_element_not_on_disk_with_different_type(self, element_type: str, element_name: str) -> None: only_on_disk = self.elements_paths_on_disk() for disk_path in only_on_disk: - disk_element_type, disk_element_name = ( - self._element_type_and_name_from_element_path(disk_path) - ) + disk_element_type, disk_element_name = self._element_type_and_name_from_element_path(disk_path) if disk_element_name == element_name and disk_element_type != element_type: raise ValueError( f"Element {element_name} is found in the Zarr store as a {disk_element_type}, but it is found " @@ -1520,9 +1451,7 @@ def has_consolidated_metadata(self) -> bool: store.close() return return_value - def _validate_can_write_metadata_on_element( - self, element_name: str - ) -> tuple[str, SpatialElement | AnnData] | None: + def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained from spatialdata._io.io_zarr import _group_for_element_exists @@ -1545,9 +1474,7 @@ def _validate_can_write_metadata_on_element( element_type = self._element_type_from_element_name(element_name) - self._check_element_not_on_disk_with_different_type( - element_type=element_type, element_name=element_name - ) + self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) # check if the element exists in the Zarr storage if not _group_for_element_exists( @@ -1566,9 +1493,7 @@ def _validate_can_write_metadata_on_element( # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element element_zarr_path = Path(self.path) / element_type / element_name - if not _is_element_self_contained( - element=element, element_path=element_zarr_path - ): + if not _is_element_self_contained(element=element, element_path=element_zarr_path): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" " saved to the Zarr group of the element in the SpatialData Zarr store. The data outside the element " @@ -1591,9 +1516,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1620,9 +1543,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: overwrite_channel_names(element_group, element) else: - raise ValueError( - f"Can't set channel names for element of type '{element_type}'." 
- ) + raise ValueError(f"Can't set channel names for element of type '{element_type}'.") def write_transformations(self, element_name: str | None = None) -> None: """ @@ -1638,9 +1559,7 @@ def write_transformations(self, element_name: str | None = None) -> None: if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") # recursively write the transformation for all the SpatialElement if element_name is None: @@ -1671,9 +1590,7 @@ def write_transformations(self, element_name: str | None = None) -> None: from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats - raster_format = RasterFormats[ - element_group.metadata.attributes["spatialdata_attrs"]["version"] - ] + raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] overwrite_coordinate_transformations_raster( group=element_group, axes=axes, @@ -1695,9 +1612,7 @@ def _element_type_from_element_name(self, element_name: str) -> str: self._validate_element_names_are_unique() element = self.get(element_name) if element is None: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." - ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") located = self.locate_element(element) element_type = None @@ -1711,9 +1626,7 @@ def _element_type_from_element_name(self, element_name: str) -> str: assert element_type is not None return element_type - def _element_type_and_name_from_element_path( - self, element_path: str - ) -> tuple[str, str]: + def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[str, str]: element_type, element_name = element_path.split("/") return element_type, element_name @@ -1726,27 +1639,19 @@ def write_attrs( from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - sdata_format = ( - sdata_format - if sdata_format is not None - else CurrentSpatialDataContainerFormat() - ) + sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None if zarr_group is None: - assert ( - self.is_backed() - ), "The SpatialData object must be backed by a Zarr store to write attrs." + assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") version = sdata_format.spatialdata_format_version version_specific_attrs = sdata_format.attrs_to_dict() - attrs_to_write = { - "spatialdata_attrs": {"version": version} | version_specific_attrs - } | self.attrs + attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: zarr_group.attrs.put(attrs_to_write) @@ -1795,9 +1700,7 @@ def write_metadata( if element_name is not None: check_valid_name(element_name) if element_name not in self: - raise ValueError( - f"Element with name {element_name} not found in SpatialData object." 
- ) + raise ValueError(f"Element with name {element_name} not found in SpatialData object.") if write_attrs: self.write_attrs(sdata_format=sdata_format) @@ -1840,9 +1743,7 @@ def get_attrs( the value of `return_as`. """ - def _flatten_mapping( - m: Mapping[str, Any], parent_key: str = "", sep: str = "_" - ) -> dict[str, Any]: + def _flatten_mapping(m: Mapping[str, Any], parent_key: str = "", sep: str = "_") -> dict[str, Any]: items: list[tuple[str, Any]] = [] for k, v in m.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -1887,9 +1788,7 @@ def _flatten_mapping( except Exception as e: raise ValueError(f"Failed to convert data to DataFrame: {e}") from e - raise ValueError( - f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None." - ) + raise ValueError(f"Invalid 'return_as' value: {return_as}. Expected 'dict', 'json', 'df', or None.") @property def tables(self) -> Tables: @@ -2027,8 +1926,7 @@ def _non_empty_elements(self) -> list[str]: return [ element for element in all_elements - if (getattr(self, element) is not None) - and (len(getattr(self, element)) > 0) + if (getattr(self, element) is not None) and (len(getattr(self, element)) > 0) ] def __repr__(self) -> str: @@ -2066,9 +1964,7 @@ def h(s: str) -> str: descr += f"\n{h('level0')}{attr.capitalize()}" unsorted_elements = attribute.items() - sorted_elements = sorted( - unsorted_elements, key=lambda x: _natural_keys(x[0]) - ) + sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) for k, v in sorted_elements: descr += f"{h('empty_line')}" descr_class = v.__class__.__name__ @@ -2100,16 +1996,7 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join( - [ - ( - str(dim) - if not isinstance(dim, Scalar) - else "" - ) - for dim in v.shape - ] - ) + + ", ".join([(str(dim) if not isinstance(dim, Scalar) else "") for dim in v.shape]) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2185,9 +2072,7 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: if not self.is_self_contained(): assert self.path is not None - descr += ( - "\nwith the following Dask-backed elements not being self-contained:" - ) + descr += "\nwith the following Dask-backed elements not being self-contained:" description = self.elements_are_self_contained() for _, element_name, element in self.gen_elements(): if not description[element_name]: @@ -2195,9 +2080,7 @@ def _element_path_to_element_name_with_type(element_path: str) -> str: descr += f"\n ▸ {element_name}: {backing_files}" if self.path is not None: - elements_only_in_sdata, elements_only_in_zarr = ( - self._symmetric_difference_with_zarr_store() - ) + elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store() if len(elements_only_in_sdata) > 0: descr += "\nwith the following elements not in the Zarr store:" for element_path in elements_only_in_sdata: @@ -2284,13 +2167,9 @@ def _validate_element_names_are_unique(self) -> None: ValueError If the element names are not unique. 
""" - check_all_keys_case_insensitively_unique( - [name for _, name, _ in self.gen_elements()], location=() - ) + check_all_keys_case_insensitively_unique([name for _, name, _ in self.gen_elements()], location=()) - def _find_element( - self, element_name: str - ) -> tuple[str, str, SpatialElement | AnnData]: + def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]: """ Retrieve SpatialElement or Table from the SpatialData instance matching element_name. @@ -2389,9 +2268,7 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements( - include_tables=True - ): + for element_type, element_name, element in self._gen_elements(include_tables=True): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2424,16 +2301,11 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value - for _, element_name, element_value in self._gen_elements( - include_tables=True - ) + element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) } return key in element_dict - def get( - self, key: str, default_value: SpatialElement | AnnData | None = None - ) -> SpatialElement | AnnData | None: + def get(self, key: str, default_value: SpatialElement | AnnData | None = None) -> SpatialElement | AnnData | None: """ Get element from SpatialData object based on corresponding name. @@ -2541,9 +2413,7 @@ def filter_by_table_query( obs_names_expr: Predicates | None = None, var_names_expr: Predicates | None = None, layer: str | None = None, - how: Literal[ - "left", "left_exclusive", "inner", "right", "right_exclusive" - ] = "right", + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", ) -> SpatialData: """ Filter the SpatialData object based on a set of table queries. 
diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index c24c257ca..1cb733882 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -33,12 +33,7 @@ def _read_zarr_group_spatialdata_element( read_func: Callable[..., Any], group_name: Literal["images", "labels", "shapes", "points", "tables"], element_type: Literal["image", "labels", "shapes", "points", "tables"], - element_container: ( - dict[str, Raster_T] - | dict[str, DaskDataFrame] - | dict[str, GeoDataFrame] - | dict[str, AnnData] - ), + element_container: (dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData]), on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], ) -> None: with handle_read_errors( @@ -68,9 +63,7 @@ def _read_zarr_group_spatialdata_element( ), ): if element_type in ["image", "labels"]: - reader_format = get_raster_format_for_read( - elem_group, sdata_version - ) + reader_format = get_raster_format_for_read(elem_group, sdata_version) element = read_func( elem_group_path, cast(Literal["image", "labels"], element_type), @@ -124,9 +117,7 @@ def get_raster_format_for_read( def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, - on_bad_files: Literal[ - BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN - ] = BadFileHandleMethod.ERROR, + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, lazy: bool = False, ) -> SpatialData: """ @@ -186,11 +177,7 @@ def read_zarr( shapes: dict[str, GeoDataFrame] = {} tables: dict[str, AnnData] = {} - selector = ( - {"images", "labels", "points", "shapes", "tables"} - if not selection - else set(selection or []) - ) + selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") # we could make this more readable. One can get lost when looking at this dict and iteration over the items @@ -199,10 +186,7 @@ def read_zarr( tuple[ Callable[..., Any], Literal["image", "labels", "shapes", "points", "tables"], - dict[str, Raster_T] - | dict[str, DaskDataFrame] - | dict[str, GeoDataFrame] - | dict[str, AnnData], + dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], ], ] = { # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" @@ -296,21 +280,15 @@ def _get_groups_for_element( # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root_group = zarr.open_group( - store=resolved_store, mode="r+", use_consolidated=use_consolidated - ) + root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) element_type_group = root_group.require_group(element_type) - element_type_group = zarr.open_group( - element_type_group.store_path, mode="a", use_consolidated=use_consolidated - ) + element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) element_name_group = element_type_group.require_group(element_name) return root_group, element_type_group, element_name_group -def _group_for_element_exists( - zarr_path: Path, element_type: str, element_name: str -) -> bool: +def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: """ Check if the group for an element exists. 
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index fa3920c2e..42be9cddd 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -32,19 +32,19 @@ from spatialdata._utils import _check_match_length_channels_c_dim from spatialdata.config import settings from spatialdata.models import C, X, Y, Z, get_axes_names -from spatialdata.models._utils import (DEFAULT_COORDINATE_SYSTEM, TRANSFORM_KEY, MappingToCoordinateSystem_t, - SpatialElement, _validate_mapping_to_coordinate_system_type, - convert_region_column_to_categorical) +from spatialdata.models._utils import ( + DEFAULT_COORDINATE_SYSTEM, + TRANSFORM_KEY, + MappingToCoordinateSystem_t, + SpatialElement, + _validate_mapping_to_coordinate_system_type, + convert_region_column_to_categorical, +) from spatialdata.transformations._utils import _get_transformations, _set_transformations, compute_coordinates from spatialdata.transformations.transformations import BaseTransformation, Identity # Types -Chunks_t: TypeAlias = ( - int - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] -) +Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]] ScaleFactors_t = Sequence[dict[str, int] | int] Transform_s = AttrSchema(BaseTransformation, None) @@ -72,9 +72,7 @@ def _is_lazy_anndata(adata: AnnData) -> bool: return False -def _parse_transformations( - element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None -) -> None: +def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) if ( @@ -180,9 +178,7 @@ def parse( if transformations: transformations = transformations.copy() if "name" in kwargs: - raise ValueError( - "The `name` argument is not (yet) supported for raster data." - ) + raise ValueError("The `name` argument is not (yet) supported for raster data.") # if dims is specified inside the data, get the value of dims from the data if isinstance(data, DataArray): if not isinstance(data.data, DaskArray): # numpy -> dask @@ -230,18 +226,13 @@ def parse( if c_coords is not None: c_coords = _check_match_length_channels_c_dim(data, c_coords, cls.dims.dims) - if ( - c_coords is not None - and len(c_coords) != data.shape[cls.dims.dims.index("c")] - ): + if c_coords is not None and len(c_coords) != data.shape[cls.dims.dims.index("c")]: raise ValueError( f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'" f" with length {data.shape[cls.dims.dims.index('c')]}." ) - data = to_spatial_image( - array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs - ) + data = to_spatial_image(array_like=data, dims=cls.dims.dims, c_coords=c_coords, **kwargs) # parse transformations _parse_transformations(data, transformations) # convert to multiscale if needed @@ -296,18 +287,12 @@ def _(self, data: DataArray) -> None: @validate.register(DataTree) def _(self, data: DataTree) -> None: - for j, k in zip( - data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True - ): + for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))], strict=True): if j != k: - raise ValueError( - f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`." 
- ) + raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.") name = {list(data[i].data_vars.keys())[0] for i in data} if len(name) != 1: - raise ValueError( - f"Expected exactly one data variable for the datatree: found `{name}`." - ) + raise ValueError(f"Expected exactly one data variable for the datatree: found `{name}`.") name = list(name)[0] for d in data: super().validate(data[d][name]) @@ -475,14 +460,9 @@ def validate(cls, data: GeoDataFrame) -> None: """ SUGGESTION = " Please use ShapesModel.parse() to construct data that is guaranteed to be valid." if cls.GEOMETRY_KEY not in data: - raise KeyError( - f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." - + SUGGESTION - ) + raise KeyError(f"GeoDataFrame must have a column named `{cls.GEOMETRY_KEY}`." + SUGGESTION) if not isinstance(data[cls.GEOMETRY_KEY], GeoSeries): - raise ValueError( - f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION - ) + raise ValueError(f"Column `{cls.GEOMETRY_KEY}` must be a GeoSeries." + SUGGESTION) if len(data[cls.GEOMETRY_KEY]) == 0: raise ValueError(f"Column `{cls.GEOMETRY_KEY}` is empty." + SUGGESTION) geom_ = data[cls.GEOMETRY_KEY].values[0] @@ -507,10 +487,7 @@ def validate(cls, data: GeoDataFrame) -> None: "please correct the radii of the circles before calling the parser function.", ) if cls.TRANSFORM_KEY not in data.attrs: - raise ValueError( - f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." - + SUGGESTION - ) + raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION) if len(data) > 0: n = data.geometry.iloc[0]._ndim if n != 2: @@ -607,9 +584,7 @@ def parse(cls, data: Any, **kwargs: Any) -> GeoDataFrame: def _( cls, data: np.ndarray, # type: ignore[type-arg] - geometry: Literal[ - 0, 3, 6 - ], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] + geometry: Literal[0, 3, 6], # [GeometryType.POINT, GeometryType.POLYGON, GeometryType.MULTIPOLYGON] offsets: tuple[ArrayLike, ...] | None = None, radius: float | ArrayLike | None = None, index: ArrayLike | None = None, @@ -620,9 +595,7 @@ def _( geo_df = GeoDataFrame({"geometry": data}) if GeometryType(geometry).name == "POINT": if radius is None: - raise ValueError( - "If `geometry` is `Circles`, `radius` must be provided." - ) + raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -649,9 +622,7 @@ def _( geo_df = GeoDataFrame({"geometry": gc.geoms}) if isinstance(geo_df["geometry"].iloc[0], Point): if radius is None: - raise ValueError( - "If `geometry` is `Circles`, `radius` must be provided." 
- ) + raise ValueError("If `geometry` is `Circles`, `radius` must be provided.") geo_df[cls.RADIUS_KEY] = radius if index is not None: geo_df.index = index @@ -668,10 +639,7 @@ def _( ) -> GeoDataFrame: if "geometry" not in data.columns: raise ValueError("`geometry` column not found in `GeoDataFrame`.") - if ( - isinstance(data["geometry"].iloc[0], Point) - and cls.RADIUS_KEY not in data.columns - ): + if isinstance(data["geometry"].iloc[0], Point) and cls.RADIUS_KEY not in data.columns: raise ValueError(f"Column `{cls.RADIUS_KEY}` not found.") _parse_transformations(data, transformations) cls.validate(data) @@ -711,8 +679,7 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f"Column `{ax}` must be of type `int` or `float`.") if cls.TRANSFORM_KEY not in data.attrs: raise ValueError( - f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." - + SUGGESTION + f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] @@ -795,15 +762,11 @@ def _( if annotation is not None: if feature_key is not None: - df_dict[feature_key] = ( - annotation[feature_key].astype(str).astype("category") - ) + df_dict[feature_key] = annotation[feature_key].astype(str).astype("category") if instance_key is not None: df_dict[instance_key] = annotation[instance_key] if Z not in axes and Z in annotation.columns: - logger.info( - f"Column `{Z}` in `annotation` will be ignored since the data is 2D." - ) + logger.info(f"Column `{Z}` in `annotation` will be ignored since the data is 2D.") for c in set(annotation.columns) - {feature_key, instance_key, X, Y, Z}: df_dict[c] = annotation[c] @@ -842,9 +805,7 @@ def _( if "sort" not in kwargs: index_monotonically_increasing = data.index.is_monotonic_increasing if not isinstance(index_monotonically_increasing, bool): - index_monotonically_increasing = ( - index_monotonically_increasing.compute() - ) + index_monotonically_increasing = index_monotonically_increasing.compute() sort = index_monotonically_increasing else: sort = kwargs["sort"] @@ -882,9 +843,7 @@ def _( if data[feature_key].dtype.name == "category": table[feature_key] = data[feature_key] else: - table[feature_key] = ( - data[feature_key].astype(str).astype("category") - ) + table[feature_key] = data[feature_key].astype(str).astype("category") if instance_key is not None: table[instance_key] = data[instance_key] for c in [X, Y, Z]: @@ -944,13 +903,9 @@ def _add_metadata_and_validate( # It also just changes the state of the series, so it is not a big deal. if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: - data[c] = data[c].cat.set_categories( - data[c].compute().cat.categories - ) + data[c] = data[c].cat.set_categories(data[c].compute().cat.categories) except ValueError: - logger.info( - f"Column `{c}` contains unknown categories. Consider casting it." - ) + logger.info(f"Column `{c}` contains unknown categories. Consider casting it.") _parse_transformations(data, transformations) cls.validate(data) @@ -964,9 +919,7 @@ class TableModel: INSTANCE_KEY = "instance_key" ATTRS_KEY = ATTRS_KEY - def _validate_set_region_key( - self, data: AnnData, region_key: str | None = None - ) -> None: + def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None: """ Validate the region key in table.uns or set a new region key as the region key column. 
@@ -1006,9 +959,7 @@ def _validate_set_region_key( raise ValueError(f"'{region_key}' column not present in table.obs") attrs[self.REGION_KEY_KEY] = region_key - def _validate_set_instance_key( - self, data: AnnData, instance_key: str | None = None - ) -> None: + def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None: """ Validate the instance_key in table.uns or set a new instance_key as the instance_key column. @@ -1052,9 +1003,7 @@ def _validate_set_instance_key( if instance_key in data.obs: attrs[self.INSTANCE_KEY] = instance_key else: - raise ValueError( - f"Instance key column '{instance_key}' not found in table.obs." - ) + raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") def _validate_table_annotation_metadata(self, data: AnnData) -> None: """ @@ -1089,26 +1038,16 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: attr = data.uns[ATTRS_KEY] if "region" not in attr: - raise ValueError( - f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`region` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if "region_key" not in attr: - raise ValueError( - f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`region_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if "instance_key" not in attr: - raise ValueError( - f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION - ) + raise ValueError(f"`instance_key` not found in `adata.uns['{ATTRS_KEY}']`." + SUGGESTION) if attr[self.REGION_KEY_KEY] not in data.obs: - raise ValueError( - f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column." - ) + raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") if attr[self.INSTANCE_KEY] not in data.obs: - raise ValueError( - f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column." - ) + raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.") # Skip detailed dtype/value validation for lazy-loaded AnnData # These checks would trigger data loading, defeating the purpose of lazy loading @@ -1129,27 +1068,17 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: "O", ] and not pd.api.types.is_string_dtype(data.obs[attr[self.INSTANCE_KEY]]) - or ( - dtype == "O" - and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) - is not str - ) + or (dtype == "O" and (val_dtype := type(data.obs[attr[self.INSTANCE_KEY]].iloc[0])) is not str) ): dtype = dtype if dtype != "O" else val_dtype raise TypeError( f"Only int, np.int16, np.int32, np.int64, uint equivalents or string allowed as dtype for " f"instance_key column in obs. Dtype found to be {dtype}" ) - expected_regions = ( - attr[self.REGION_KEY] - if isinstance(attr[self.REGION_KEY], list) - else [attr[self.REGION_KEY]] - ) + expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]] found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist() if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: - raise ValueError( - f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match." 
- ) + raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") # Warning for object/string columns with NaN in region_key or instance_key instance_key = attr[self.INSTANCE_KEY] @@ -1161,9 +1090,7 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None: if key_value in data.obs: col = data.obs[key_value] col_dtype = col.dtype - if ( - col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype) - ) and col.isna().any(): + if (col_dtype == "object" or pd.api.types.is_string_dtype(col_dtype)) and col.isna().any(): logger.warning( f"The {key_name} column '{key_value}' is of {col_dtype} type and contains NaN values. " "After writing and reading with AnnData, NaN values may (depending on the AnnData version) " @@ -1206,9 +1133,7 @@ def validate( f"using TableModel.parse(adata)." ) # Skip dtype validation for lazy tables (would require loading data) - if not is_lazy and not isinstance( - data.obs[region_key].dtype, CategoricalDtype - ): + if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype): raise ValueError( f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." ) @@ -1220,9 +1145,7 @@ def validate( ) # Skip null check for lazy tables (would require loading data) if not is_lazy and data.obs[instance_key].isnull().values.any(): - raise ValueError( - "`table.obs[instance_key]` must not contain null values, but it does." - ) + raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") self._validate_table_annotation_metadata(data) @@ -1259,9 +1182,7 @@ def parse( """ validate_table_attr_keys(adata) # either all live in adata.uns or all be passed in as argument - n_args = sum( - [region is not None, region_key is not None, instance_key is not None] - ) + n_args = sum([region is not None, region_key is not None, instance_key is not None]) if n_args == 0: if cls.ATTRS_KEY not in adata.uns: # table not annotating any element @@ -1290,9 +1211,7 @@ def parse( region = region.tolist() region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): - raise ValueError( - f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values." 
- ) + raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key @@ -1303,9 +1222,7 @@ def parse( grouped = adata.obs.groupby(region_key, observed=True) grouped_size = grouped.size() grouped_nunique = grouped.nunique() - not_unique = grouped_size[ - grouped_size != grouped_nunique[instance_key] - ].index.tolist() + not_unique = grouped_size[grouped_size != grouped_nunique[instance_key]].index.tolist() if not_unique: raise ValueError( f"Instance key column for region(s) `{', '.join(not_unique)}` does not contain only unique values" @@ -1416,11 +1333,6 @@ def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: ) annotated_regions = region_key_column.unique().tolist() else: - annotated_regions = ( - table.obs[region_key] - .cat.remove_unused_categories() - .cat.categories.unique() - .tolist() - ) + annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() assert isinstance(annotated_regions, list) return annotated_regions diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 241336df7..80e9b5da2 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -106,9 +106,7 @@ def test_shapes( else: # convert each Polygon to a MultiPolygon mixed_multipolygon = shapes["mixed"].assign( - geometry=lambda df: df.geometry.apply( - lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g - ) + geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g) ) assert sdata["mixed"].equals(mixed_multipolygon) assert not sdata["mixed"].equals(shapes["mixed"]) @@ -145,9 +143,7 @@ def test_shapes_geometry_encoding_write_element( # Write each shape element - should use global setting for shape_name in shapes.shapes: - empty_sdata.write_element( - shape_name, sdata_formats=sdata_container_format - ) + empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format) # Verify the encoding metadata in the parquet file parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet" @@ -228,12 +224,8 @@ def test_multiple_tables( tables: list[AnnData], sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata_tables = SpatialData( - tables={str(i): tables[i] for i in range(len(tables))} - ) - self._test_table( - tmp_path, sdata_tables, sdata_container_format=sdata_container_format - ) + sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) + self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) def test_roundtrip( self, @@ -264,9 +256,7 @@ def test_incremental_io_list_of_elements( assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element( - ["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format - ) + shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -381,9 +371,7 @@ def test_incremental_io_on_disk( ValueError, match=match, ): - sdata.write_element( - name, overwrite=True, sdata_formats=sdata_container_format - ) + sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) if workaround == 1: new_name = f"{name}_new_place" @@ -414,9 +402,7 
@@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name, sdata_formats=sdata_container_format) - def test_io_and_lazy_loading_points( - self, points, sdata_container_format: SpatialDataContainerFormatType - ): + def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") points.write(f, sdata_formats=sdata_container_format) @@ -425,9 +411,7 @@ def test_io_and_lazy_loading_points( sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster( - self, images, labels, sdata_container_format: SpatialDataContainerFormatType - ): + def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -477,13 +461,9 @@ def test_replace_transformation_on_disk_non_raster( with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") sdata.write(f, sdata_formats=sdata_container_format) - t0 = get_transformation( - SpatialData.read(f).__getattribute__(k)[elem_name] - ) + t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert isinstance(t0, Identity) - set_transformation( - sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata - ) + set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) @@ -494,16 +474,10 @@ def test_write_overwrite_fails_when_no_zarr_store( f = Path(tmpdir) / "data.zarr" f.mkdir() old_data = SpatialData() - with pytest.raises( - ValueError, match="The target file path specified already exists" - ): + with pytest.raises(ValueError, match="The target file path specified already exists"): old_data.write(f, sdata_formats=sdata_container_format) - with pytest.raises( - ValueError, match="The target file path specified already exists" - ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + with pytest.raises(ValueError, match="The target file path specified already exists"): + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( self, @@ -536,9 +510,7 @@ def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( match=r"Details: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object", ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) def test_overwrite_fails_when_zarr_store_present( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType @@ -558,9 +530,7 @@ def test_overwrite_fails_when_zarr_store_present( ValueError, match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): - full_sdata.write( - f, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -581,7 +551,9 @@ def test_overwrite_fails_when_zarr_store_present( def test_overwrite_fails_onto_non_zarr_file( self, full_sdata, sdata_container_format: 
SpatialDataContainerFormatType ): - ERROR_MESSAGE = "The target file path specified already exists, and it has been detected to not be a Zarr store." + ERROR_MESSAGE = ( + "The target file path specified already exists, and it has been detected to not be a Zarr store." + ) with tempfile.TemporaryDirectory() as tmpdir: f0 = os.path.join(tmpdir, "test.txt") with open(f0, "w"): @@ -594,17 +566,13 @@ def test_overwrite_fails_onto_non_zarr_file( ValueError, match=ERROR_MESSAGE, ): - full_sdata.write( - f0, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format) f1 = os.path.join(tmpdir, "test.zarr") os.mkdir(f1) with pytest.raises(ValueError, match=ERROR_MESSAGE): full_sdata.write(f1, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write( - f1, overwrite=True, sdata_formats=sdata_container_format - ) + full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format) def test_incremental_io_in_memory( @@ -642,9 +610,7 @@ def test_bug_rechunking_after_queried_raster(): # https://github.com/scverse/spatialdata-io/issues/117 ## single_scale = Image2DModel.parse(RNG.random((100, 10, 10)), chunks=(5, 5, 5)) - multi_scale = Image2DModel.parse( - RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5) - ) + multi_scale = Image2DModel.parse(RNG.random((100, 10, 10)), scale_factors=[2, 2], chunks=(5, 5, 5)) images = {"single_scale": single_scale, "multi_scale": multi_scale} sdata = SpatialData(images=images) queried = sdata.query.bounding_box( @@ -659,9 +625,7 @@ def test_bug_rechunking_after_queried_raster(): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_self_contained( - full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -685,10 +649,7 @@ def test_self_contained( # because of the images, labels and points description = sdata2.elements_are_self_contained() for element_name, self_contained in description.items(): - if any( - element_name.startswith(prefix) - for prefix in ["image", "labels", "points"] - ): + if any(element_name.startswith(prefix) for prefix in ["image", "labels", "points"]): assert not self_contained else: assert self_contained @@ -721,11 +682,7 @@ def test_self_contained( assert not sdata2.is_self_contained() description = sdata2.elements_are_self_contained() assert description["combined"] is False - assert all( - description[element_name] - for element_name in description - if element_name != "combined" - ) + assert all(description[element_name] for element_name in description if element_name != "combined") @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) @@ -737,9 +694,7 @@ def test_symmetric_difference_with_zarr_store( full_sdata.write(f, sdata_formats=sdata_container_format) # the list of element on-disk and in-memory is the same - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() assert len(only_in_memory) == 0 assert len(only_on_disk) == 0 @@ -755,9 +710,7 @@ def test_symmetric_difference_with_zarr_store( del 
full_sdata.tables["table"] # now the list of element on-disk and in-memory is different - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() assert set(only_in_memory) == { "images/new_image2d", "labels/new_labels2d", @@ -775,17 +728,13 @@ def test_symmetric_difference_with_zarr_store( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_change_path_of_subset( - full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) - subset = full_sdata.subset( - ["image2d", "labels2d", "points_0", "circles", "table"] - ) + subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"]) assert subset.path is None subset.path = Path(f) @@ -850,9 +799,7 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_incremental_io_attrs( - points: SpatialData, sdata_container_format: SpatialDataContainerFormatType -) -> None: +def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") my_attrs = {"a": "b", "c": 1} @@ -879,9 +826,7 @@ def test_incremental_io_attrs( cached_sdata_blobs = blobs() -@pytest.mark.parametrize( - "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] -) +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_delete_element_from_disk( full_sdata, @@ -889,9 +834,7 @@ def test_delete_element_from_disk( sdata_container_format: SpatialDataContainerFormatType, ) -> None: # can't delete an element for a SpatialData object without associated Zarr store - with pytest.raises( - ValueError, match="The SpatialData object is not backed by a Zarr store." 
- ): + with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): full_sdata.delete_element_from_disk("image2d") with tempfile.TemporaryDirectory() as tmpdir: @@ -919,9 +862,7 @@ def test_delete_element_from_disk( # can delete an element present both in-memory and on-disk full_sdata.delete_element_from_disk(element_name) - only_in_memory, only_on_disk = ( - full_sdata._symmetric_difference_with_zarr_store() - ) + only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() element_type = full_sdata._element_type_from_element_name(element_name) element_path = f"{element_type}/{element_name}" assert element_path in only_in_memory @@ -936,9 +877,7 @@ def test_delete_element_from_disk( assert element_path not in on_disk -@pytest.mark.parametrize( - "element_name", ["image2d", "labels2d", "points_0", "circles", "table"] -) +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_element_already_on_disk_different_type( full_sdata, @@ -992,9 +931,7 @@ def test_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( - iter(_get_shapes().values()) - ) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") with pytest.raises(ValueError, match="Name (must|cannot)"): @@ -1005,9 +942,7 @@ def test_writing_valid_table_name_invalid_table(tmp_path: Path): # also try with a valid table name but invalid table # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() - invalid_sdata.tables.data["valid_name"] = AnnData( - np.array([[0]]), layers={"invalid name": np.array([[0]])} - ) + invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write(tmp_path / "data.zarr") @@ -1020,9 +955,7 @@ def test_incremental_writing_invalid_name(tmp_path: Path): invalid_sdata.images.data[""] = next(iter(_get_images().values())) invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) - invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next( - iter(_get_shapes().values()) - ) + invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") for element_type in ["images", "labels", "points", "shapes", "tables"]: @@ -1037,9 +970,7 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): # testing just one case, all the cases are in test_table_model_invalid_names() invalid_sdata = SpatialData() invalid_sdata.write(tmp_path / "data2.zarr") - invalid_sdata.tables.data["valid_name"] = AnnData( - np.array([[0]]), layers={"invalid name": np.array([[0]])} - ) + invalid_sdata.tables.data["valid_name"] = AnnData(np.array([[0]]), layers={"invalid name": np.array([[0]])}) with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write_element("valid_name") @@ -1059,19 +990,13 @@ def 
test_reading_invalid_name(tmp_path: Path): ) valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at construction time and check validation happens again at writing time. - (tmp_path / "data.zarr/points" / points_name).rename( - tmp_path / "data.zarr/points" / "has whitespace" - ) + (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") # This one is not allowed on windows - (tmp_path / "data.zarr/shapes" / shapes_name).rename( - tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@" - ) + (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@") # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. valid_sdata.write_consolidated_metadata() - with pytest.raises( - ValidationError, match="Cannot construct SpatialData" - ) as exc_info: + with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: read_zarr(tmp_path / "data.zarr") actual_message = str(exc_info.value) @@ -1084,14 +1009,10 @@ def test_reading_invalid_name(tmp_path: Path): @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_write_store_unconsolidated_and_read( - full_sdata, sdata_container_format: SpatialDataContainerFormatType -): +def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write( - path, consolidate_metadata=False, sdata_formats=sdata_container_format - ) + full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format) group = zarr.open_group(path, mode="r") assert group.metadata.consolidated_metadata is None @@ -1100,9 +1021,7 @@ def test_write_store_unconsolidated_and_read( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) -def test_can_read_sdata_with_reconsolidation( - full_sdata, sdata_container_format: SpatialDataContainerFormatType -): +def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" full_sdata.write(path, sdata_formats=sdata_container_format) From 120d11dc7386875f8d36f87f8eec1963b5e097b6 Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:05:03 +0100 Subject: [PATCH 3/8] fix: address pre-commit linting issues - Simplify if/return pattern in _is_lazy_anndata (SIM103) - Add missing TableModel import in test fixture (F821) - Use modern np.random.Generator instead of np.random.rand (NPY002) --- src/spatialdata/models/models.py | 4 +--- tests/io/test_readwrite.py | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 42be9cddd..36a3a973d 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -67,9 +67,7 @@ def _is_lazy_anndata(adata: AnnData) -> bool: True if the AnnData is lazily loaded, False otherwise. 
""" # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) - if not isinstance(adata.obs, pd.DataFrame): - return True - return False + return not isinstance(adata.obs, pd.DataFrame) def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 80e9b5da2..4973a4d7d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1124,9 +1124,12 @@ class TestLazyTableLoading: @pytest.fixture def sdata_with_table(self) -> SpatialData: """Create a SpatialData object with a simple table for testing.""" + from spatialdata.models import TableModel + + rng = default_rng(42) table = TableModel.parse( AnnData( - X=np.random.rand(100, 50), + X=rng.random((100, 50)), obs=pd.DataFrame( { "region": pd.Categorical(["region1"] * 100), From 209097d1491582bec7a28b0b1863d96c797eeda8 Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:23:23 +0100 Subject: [PATCH 4/8] fix: handle lazy AnnData obs conversion in query operations When lazy AnnData objects (from anndata.experimental.read_lazy) are subset, their obs attribute is a Dataset2D object, not a pandas DataFrame. Using pd.DataFrame(table.obs) produces a malformed DataFrame. This fix uses table.obs.to_memory() for lazy tables to properly convert Dataset2D to DataFrame while preserving all column data. Files modified: - relational_query.py: _filter_table_by_element_names, _filter_table_by_elements, get_values - _utils.py: _inplace_fix_subset_categorical_obs --- .../_core/query/relational_query.py | 23 ++++++++++++++++--- src/spatialdata/_utils.py | 16 +++++++++++-- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 01cfb745b..6f4f03cab 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -75,8 +75,14 @@ def _filter_table_by_element_names(table: AnnData | None, element_names: str | l return None table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] - table.obs = pd.DataFrame(table.obs) + # Filter first, then materialize obs to avoid shape mismatch with lazy tables table = table[table.obs[region_key].isin(element_names)].copy() + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(table.obs, pd.DataFrame): + table.obs = pd.DataFrame(table.obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + table.obs = table.obs.to_memory() table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() return table @@ -196,8 +202,14 @@ def _filter_table_by_elements( indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() to_keep = to_keep | indices original_table = table - table.obs = pd.DataFrame(table.obs) + # Subset first, then materialize obs to avoid shape mismatch with lazy tables table = table[to_keep, :] + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(table.obs, pd.DataFrame): + table.obs = pd.DataFrame(table.obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + table.obs = table.obs.to_memory() if match_rows: assert instances is not None assert isinstance(instances, np.ndarray) @@ -1066,7 
+1078,12 @@ def get_values( if origin == "obs": df = obs[value_key_values].copy() if origin == "var": - matched_table.obs = pd.DataFrame(obs) + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(obs, pd.DataFrame): + matched_table.obs = pd.DataFrame(obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + matched_table.obs = obs.to_memory() if table_layer is None: x = matched_table[:, value_key_values].X else: diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index f5bc12579..58a11f856 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -210,11 +210,23 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A """ if not hasattr(subset_adata, "obs") or not hasattr(original_adata, "obs"): return - obs = pd.DataFrame(subset_adata.obs) + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(subset_adata.obs, pd.DataFrame): + obs = pd.DataFrame(subset_adata.obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + obs = subset_adata.obs.to_memory() + + # Also handle lazy original_adata.obs + if isinstance(original_adata.obs, pd.DataFrame): + original_obs = original_adata.obs + else: + original_obs = original_adata.obs.to_memory() + for column in obs.columns: is_categorical = isinstance(obs[column].dtype, pd.CategoricalDtype) if is_categorical: - c = obs[column].cat.set_categories(original_adata.obs[column].cat.categories) + c = obs[column].cat.set_categories(original_obs[column].cat.categories) obs[column] = c subset_adata.obs = obs From e182dde72026524497eae54e662bb2525ffc116a Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:42:33 +0100 Subject: [PATCH 5/8] style: fix line length and use ternary operators per ruff --- src/spatialdata/_utils.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 58a11f856..a0e855502 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -210,18 +210,11 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A """ if not hasattr(subset_adata, "obs") or not hasattr(original_adata, "obs"): return - # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) - if isinstance(subset_adata.obs, pd.DataFrame): - obs = pd.DataFrame(subset_adata.obs) - else: - # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly - obs = subset_adata.obs.to_memory() - - # Also handle lazy original_adata.obs - if isinstance(original_adata.obs, pd.DataFrame): - original_obs = original_adata.obs - else: - original_obs = original_adata.obs.to_memory() + # Handle lazy tables (Dataset2D vs DataFrame). 
Lazy AnnData uses Dataset2D which needs to_memory() + obs = pd.DataFrame(subset_adata.obs) if isinstance(subset_adata.obs, pd.DataFrame) else subset_adata.obs.to_memory() + original_obs = ( + original_adata.obs if isinstance(original_adata.obs, pd.DataFrame) else original_adata.obs.to_memory() + ) for column in obs.columns: is_categorical = isinstance(obs[column].dtype, pd.CategoricalDtype) From 41fa1b561e5b3f3c49dc6e0e772a7c1199d895ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:34:09 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .claude/settings.local.json | 24 +++++++------- spatialdata_pr_context.md | 65 ++++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 41 deletions(-) diff --git a/.claude/settings.local.json b/.claude/settings.local.json index fd009d5d8..71aa0256f 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -1,14 +1,14 @@ { - "permissions": { - "allow": [ - "Bash(gh pr view 1055 --repo scverse/spatialdata --json title,state,url,headRefName,body)", - "Bash(python -c \":*)", - "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short)", - "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short -k \"lazy\")", - "Bash(python -m pytest tests/io/test_readwrite.py::TestReadWrite::test_io_and_lazy_loading_raster -v --tb=long -k \"sdata_container_format1\")", - "Bash(python -m pytest tests/io/test_readwrite.py -v --tb=short -k \"TestLazyTableLoading\")", - "Bash(python -m pytest tests/core/query/ -v --tb=short)", - "Bash(python -W all -c \":*)" - ] - } + "permissions": { + "allow": [ + "Bash(gh pr view 1055 --repo scverse/spatialdata --json title,state,url,headRefName,body)", + "Bash(python -c \":*)", + "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short)", + "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short -k \"lazy\")", + "Bash(python -m pytest tests/io/test_readwrite.py::TestReadWrite::test_io_and_lazy_loading_raster -v --tb=long -k \"sdata_container_format1\")", + "Bash(python -m pytest tests/io/test_readwrite.py -v --tb=short -k \"TestLazyTableLoading\")", + "Bash(python -m pytest tests/core/query/ -v --tb=short)", + "Bash(python -W all -c \":*)" + ] + } } diff --git a/spatialdata_pr_context.md b/spatialdata_pr_context.md index 2b6339f34..c0f5ebb19 100644 --- a/spatialdata_pr_context.md +++ b/spatialdata_pr_context.md @@ -1,6 +1,7 @@ # Context for SpatialData PR #1055: Lazy Table Loading ## PR Link + https://github.com/scverse/spatialdata/pull/1055 ## Background @@ -8,6 +9,7 @@ https://github.com/scverse/spatialdata/pull/1055 This PR adds `lazy: bool = False` to `SpatialData.read()` and `read_zarr()` so that AnnData tables are loaded lazily via dask, keeping large matrices out of memory. This matters for Mass Spectrometry Imaging (MSI) datasets where tables can have millions of pixels and hundreds of thousands of m/z bins (e.g. 179,389 x 460,517 = ~40GB dense, though stored as sparse CSC). 
### What the PR already does (6 commits): + - `io_table.py`: Added lazy parameter, uses `anndata.experimental.read_lazy()` when `lazy=True` - `io_zarr.py`: Passes `lazy` through to `_read_table()` - `spatialdata.py`: Passes `lazy` through to `read_zarr()`; skips eager validation for lazy tables @@ -17,7 +19,9 @@ This PR adds `lazy: bool = False` to `SpatialData.read()` and `read_zarr()` so t - `tests/io/test_readwrite.py`: Added lazy loading tests ### Benchmark (from PR): + 100,000 pixels x 100,000 m/z bins, 3,000 peaks/pixel (~296M non-zeros): + - Memory: 15.4 MB vs 2,270.7 MB (99% savings) - Load time: 0.13s vs 1.57s (12x faster) @@ -32,6 +36,7 @@ We tested the PR against real-world MSI datasets and found three important thing The original concern was that `read_lazy()` wraps the entire sparse matrix as a single dask chunk, defeating lazy loading. **This is NOT true** with anndata 0.13. `read_lazy()` internally calls `read_elem_lazy(elem)` without passing `chunks`, which triggers these defaults in `read_elem_lazy`: + - CSC sparse: `(n_rows, 1000)` -- 1000-column chunks - CSR sparse: `(1000, n_cols)` -- 1000-row chunks - Dense: uses on-disk zarr chunk layout @@ -63,10 +68,10 @@ Error raised while reading key 'y' of from /obs **Root cause**: anndata has two separate IO registries: -| Registry | Readers for zarr.Array | Purpose | -|---|---|---| -| **Eager** (`_REGISTRY`) | 8 readers | Used by `read_elem()`, `read_zarr()` | -| **Lazy** (`_LAZY_REGISTRY`) | 2 readers | Used by `read_elem_lazy()`, `read_lazy()` | +| Registry | Readers for zarr.Array | Purpose | +| --------------------------- | ---------------------- | ----------------------------------------- | +| **Eager** (`_REGISTRY`) | 8 readers | Used by `read_elem()`, `read_zarr()` | +| **Lazy** (`_LAZY_REGISTRY`) | 2 readers | Used by `read_elem_lazy()`, `read_lazy()` | The eager registry has a catch-all reader for `IOSpec('', '')` (plain arrays with no encoding metadata). The lazy registry does NOT. So when obs columns are stored without `encoding-type` attributes, eager reading works but lazy reading crashes. @@ -87,13 +92,13 @@ We tested what anndata itself writes: Then we checked all real datasets: -| Dataset | obs encoding metadata | Created by | -|---|---|---| -| Hippocampus.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | -| mouse_brain.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | -| sample_A.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | -| sample_B.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | -| xenium.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | +| Dataset | obs encoding metadata | Created by | +| ---------------- | ------------------------------------------- | ------------------------------- | +| Hippocampus.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | +| mouse_brain.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | +| sample_A.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | +| sample_B.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | +| xenium.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | **Conclusion**: The datasets created through anndata's standard `write_zarr()` have proper encoding metadata and `read_lazy()` would work on them. 
The datasets created by Thyra's streaming COO converter write raw zarr arrays without the encoding attributes. The fix belongs in the writer. @@ -120,18 +125,19 @@ This is not a bug -- it's how `read_lazy()` works. But it's a gotcha for downstr Fix Thyra's streaming COO converter to write proper anndata encoding attributes on obs columns. When writing a zarr array for an obs column, add: - For numeric arrays (int32, float64, etc.): - ```python - arr = zarr.open_array(path, ...) - arr.attrs['encoding-type'] = 'array' - arr.attrs['encoding-version'] = '0.2.0' - ``` + + ```python + arr = zarr.open_array(path, ...) + arr.attrs['encoding-type'] = 'array' + arr.attrs['encoding-version'] = '0.2.0' + ``` - For string arrays: - ```python - arr = zarr.open_array(path, ...) - arr.attrs['encoding-type'] = 'string-array' - arr.attrs['encoding-version'] = '0.2.0' - ``` + ```python + arr = zarr.open_array(path, ...) + arr.attrs['encoding-type'] = 'string-array' + arr.attrs['encoding-version'] = '0.2.0' + ``` With this fix, `read_lazy()` works on all datasets, the PR's current implementation using `read_lazy()` is correct, and the chunking is already optimal for sparse data. @@ -151,6 +157,7 @@ This sidesteps the crash AND the uns dask-wrapping issue, but it's working aroun ### The PR as-is If Option A is done, the PR's current approach (`read_lazy()` on the whole table) is already correct: + - Chunking works for sparse data - Query API fixes are in place - Validation skipping is correct @@ -179,11 +186,11 @@ data = table.X[:, feature_indices].compute() ## Key Files -| File | Purpose | -|------|---------| -| `src/spatialdata/_io/io_table.py` | Core lazy loading logic | -| `src/spatialdata/_io/io_zarr.py` | Passes lazy parameter through | -| `src/spatialdata/_core/spatialdata.py` | Entry point for `SpatialData.read()` | -| `src/spatialdata/_core/query/relational_query.py` | Query API fixes for Dataset2D obs | -| `src/spatialdata/_utils.py` | Helper for lazy AnnData detection | -| `src/spatialdata/models/models.py` | Validation skip for lazy tables | +| File | Purpose | +| ------------------------------------------------- | ------------------------------------ | +| `src/spatialdata/_io/io_table.py` | Core lazy loading logic | +| `src/spatialdata/_io/io_zarr.py` | Passes lazy parameter through | +| `src/spatialdata/_core/spatialdata.py` | Entry point for `SpatialData.read()` | +| `src/spatialdata/_core/query/relational_query.py` | Query API fixes for Dataset2D obs | +| `src/spatialdata/_utils.py` | Helper for lazy AnnData detection | +| `src/spatialdata/models/models.py` | Validation skip for lazy tables | From 18991af8745d4281af46b8d676504f990c95733e Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Wed, 4 Mar 2026 10:50:18 +0100 Subject: [PATCH 7/8] fix: remove silent fallback for lazy reading, raise on failure instead Removes the try/except fallback that caught generic Exception during lazy table reading. If lazy=True and read_lazy() fails, the error now propagates directly -- silently falling back to eager loading could crash devices on large (40GB+) datasets. 
--- src/spatialdata/_io/io_table.py | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 6a9936bd6..19fb273a2 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from pathlib import Path import numpy as np @@ -26,38 +25,20 @@ def _read_table(store: str | Path, lazy: bool = False) -> AnnData: If True, read the table lazily using ``anndata.experimental.read_lazy``. This keeps large matrices (X, layers) as dask arrays backed by zarr, so they are only loaded into memory on demand. Requires anndata >= 0.12. - If the installed version does not support lazy reading, or if lazy - reading fails (e.g. due to missing encoding metadata on disk), a - warning is raised and the table is read eagerly. Returns ------- The AnnData table, either lazily loaded or in-memory. + + Raises + ------ + ImportError + If ``lazy=True`` but anndata >= 0.12 is not installed. """ if lazy: - try: - from anndata.experimental import read_lazy + from anndata.experimental import read_lazy - table = read_lazy(str(store)) - except ImportError: - warnings.warn( - "Lazy reading of tables requires anndata >= 0.12. " - "Falling back to eager reading. To enable lazy reading, " - "upgrade anndata with: pip install 'anndata>=0.12'", - UserWarning, - stacklevel=2, - ) - table = read_anndata_zarr(str(store)) - except Exception as e: - warnings.warn( - f"Lazy reading failed: {e}. " - "This can happen when the zarr store was written without " - "anndata encoding metadata on obs/var columns. " - "Falling back to eager reading.", - UserWarning, - stacklevel=2, - ) - table = read_anndata_zarr(str(store)) + table = read_lazy(str(store)) else: table = read_anndata_zarr(str(store)) From 9c0764045389f20fa8e29da242b393c7bc0a0d54 Mon Sep 17 00:00:00 2001 From: Tomatokeftes <129113023+Tomatokeftes@users.noreply.github.com> Date: Wed, 4 Mar 2026 10:54:05 +0100 Subject: [PATCH 8/8] chore: remove local dev files from PR and add to .gitignore Remove .claude/settings.local.json and spatialdata_pr_context.md from tracking -- these are local development artifacts not meant for the PR. 
--- .claude/settings.local.json | 14 --- .gitignore | 4 + spatialdata_pr_context.md | 196 ------------------------------------ 3 files changed, 4 insertions(+), 210 deletions(-) delete mode 100644 .claude/settings.local.json delete mode 100644 spatialdata_pr_context.md diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 71aa0256f..000000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(gh pr view 1055 --repo scverse/spatialdata --json title,state,url,headRefName,body)", - "Bash(python -c \":*)", - "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short)", - "Bash(python -m pytest tests/io/test_readwrite.py -x -v --tb=short -k \"lazy\")", - "Bash(python -m pytest tests/io/test_readwrite.py::TestReadWrite::test_io_and_lazy_loading_raster -v --tb=long -k \"sdata_container_format1\")", - "Bash(python -m pytest tests/io/test_readwrite.py -v --tb=short -k \"TestLazyTableLoading\")", - "Bash(python -m pytest tests/core/query/ -v --tb=short)", - "Bash(python -W all -c \":*)" - ] - } -} diff --git a/.gitignore b/.gitignore index 25836a3b1..ac8373fad 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,7 @@ node_modules/ .mypy_cache .ruff_cache uv.lock + +# Claude Code local files +.claude/ +spatialdata_pr_context.md diff --git a/spatialdata_pr_context.md b/spatialdata_pr_context.md deleted file mode 100644 index c0f5ebb19..000000000 --- a/spatialdata_pr_context.md +++ /dev/null @@ -1,196 +0,0 @@ -# Context for SpatialData PR #1055: Lazy Table Loading - -## PR Link - -https://github.com/scverse/spatialdata/pull/1055 - -## Background - -This PR adds `lazy: bool = False` to `SpatialData.read()` and `read_zarr()` so that AnnData tables are loaded lazily via dask, keeping large matrices out of memory. This matters for Mass Spectrometry Imaging (MSI) datasets where tables can have millions of pixels and hundreds of thousands of m/z bins (e.g. 179,389 x 460,517 = ~40GB dense, though stored as sparse CSC). - -### What the PR already does (6 commits): - -- `io_table.py`: Added lazy parameter, uses `anndata.experimental.read_lazy()` when `lazy=True` -- `io_zarr.py`: Passes `lazy` through to `_read_table()` -- `spatialdata.py`: Passes `lazy` through to `read_zarr()`; skips eager validation for lazy tables -- `relational_query.py`: Fixed query operations to handle `Dataset2D` obs from lazy tables -- `_utils.py`: Fixed `_inplace_fix_subset_categorical_obs()` for lazy tables -- `models.py`: Modified validation to skip eager checks for lazy tables -- `tests/io/test_readwrite.py`: Added lazy loading tests - -### Benchmark (from PR): - -100,000 pixels x 100,000 m/z bins, 3,000 peaks/pixel (~296M non-zeros): - -- Memory: 15.4 MB vs 2,270.7 MB (99% savings) -- Load time: 0.13s vs 1.57s (12x faster) - ---- - -## Investigation Results - -We tested the PR against real-world MSI datasets and found three important things. - -### Finding 1: Chunking is already correct -- no fix needed - -The original concern was that `read_lazy()` wraps the entire sparse matrix as a single dask chunk, defeating lazy loading. **This is NOT true** with anndata 0.13. 
- -`read_lazy()` internally calls `read_elem_lazy(elem)` without passing `chunks`, which triggers these defaults in `read_elem_lazy`: - -- CSC sparse: `(n_rows, 1000)` -- 1000-column chunks -- CSR sparse: `(1000, n_cols)` -- 1000-row chunks -- Dense: uses on-disk zarr chunk layout - -Verified on real data (179,389 x 460,517 CSC sparse): - -```python -from anndata.experimental import read_elem_lazy -import zarr - -x_group = zarr.open_group("dataset.zarr/tables/my_table/X", mode="r") -lazy_X = read_elem_lazy(x_group) -print(lazy_X.chunks) -# ((179389,), (1000, 1000, 1000, ... 1000, 517)) -# 461 column chunks -- correct! -``` - -Column slicing only reads the relevant zarr chunks from disk. No rechunking needed. - -### Finding 2: `read_lazy()` crashes on some real datasets -- but it's a writer problem - -`read_lazy()` crashes on some datasets with: - -``` -IORegistryError: No read method registered for IOSpec(encoding_type='', encoding_version='') -from . -Error raised while reading key 'y' of from /obs -``` - -**Root cause**: anndata has two separate IO registries: - -| Registry | Readers for zarr.Array | Purpose | -| --------------------------- | ---------------------- | ----------------------------------------- | -| **Eager** (`_REGISTRY`) | 8 readers | Used by `read_elem()`, `read_zarr()` | -| **Lazy** (`_LAZY_REGISTRY`) | 2 readers | Used by `read_elem_lazy()`, `read_lazy()` | - -The eager registry has a catch-all reader for `IOSpec('', '')` (plain arrays with no encoding metadata). The lazy registry does NOT. So when obs columns are stored without `encoding-type` attributes, eager reading works but lazy reading crashes. - -**The key question was: is this a reader bug or a writer bug?** - -We tested what anndata itself writes: - -```python -# When anndata writes an int32 obs column via write_zarr(): -# encoding-type='array', encoding-version='0.2.0' <-- STAMPED -# -# When anndata writes a string obs column via write_zarr(): -# encoding-type='string-array', encoding-version='0.2.0' <-- STAMPED -# -# When anndata writes a categorical: -# encoding-type='categorical', encoding-version='0.2.0' <-- STAMPED -``` - -Then we checked all real datasets: - -| Dataset | obs encoding metadata | Created by | -| ---------------- | ------------------------------------------- | ------------------------------- | -| Hippocampus.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | -| mouse_brain.zarr | **MISSING** on all non-categorical columns | Thyra streaming COO converter | -| sample_A.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | -| sample_B.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | -| xenium.zarr | `'array'` / `'string-array'` on all columns | Standard anndata `write_zarr()` | - -**Conclusion**: The datasets created through anndata's standard `write_zarr()` have proper encoding metadata and `read_lazy()` would work on them. The datasets created by Thyra's streaming COO converter write raw zarr arrays without the encoding attributes. The fix belongs in the writer. 

### Finding 3: `table.uns` values become dask arrays after `read_lazy()`

All values in `table.uns` (unstructured metadata like mean spectra, peak lists, parameter dicts) become dask arrays:

```python
lazy = read_lazy(zarr_path)
type(lazy.uns['mean_spectra']['global_mean']['mz'])   # dask.array.core.Array
type(lazy.uns['peak_lists']['auto_peaks']['indices']) # dask.array.core.Array
```

`np.array()` on a dask array silently materializes it (no crash), but direct JSON/Pydantic serialization fails. Downstream code needs `.compute()` or `np.asarray()` before serialization.

This is not a bug -- it's how `read_lazy()` works. But it's a gotcha for downstream consumers. If the writer is fixed and `read_lazy()` is used, this needs to be documented or handled.

---

## What Needs to Be Done

### Option A: Fix the writer (recommended)

Fix Thyra's streaming COO converter to write proper anndata encoding attributes on obs columns. When writing a zarr array for an obs column, add:

- For numeric arrays (int32, float64, etc.):

  ```python
  arr = zarr.open_array(path, ...)
  arr.attrs['encoding-type'] = 'array'
  arr.attrs['encoding-version'] = '0.2.0'
  ```

- For string arrays:

  ```python
  arr = zarr.open_array(path, ...)
  arr.attrs['encoding-type'] = 'string-array'
  arr.attrs['encoding-version'] = '0.2.0'
  ```

With this fix, `read_lazy()` works on all datasets, the PR's current implementation using `read_lazy()` is correct, and the chunking is already optimal for sparse data.

You would also want a migration script to stamp the encoding metadata on existing datasets (Hippocampus.zarr, mouse_brain.zarr) so they work with lazy loading too; a sketch of such a script is given at the end of this document.

### Option B: Piecewise loading in SpatialData (workaround)

If the writer can't be fixed (or for backwards compatibility with existing datasets), change `_read_table()` in SpatialData to build the AnnData piecewise:

1. Read obs, var, uns eagerly via `read_elem()` (handles missing encoding metadata; obs/var/uns are small)
2. Read X lazily via `read_elem_lazy()` (proper chunking automatic for sparse)
3. Read layers lazily if they exist
4. Assemble into AnnData

This sidesteps the crash AND the uns dask-wrapping issue, but it is working around data that was not written to the anndata spec. A rough sketch of this piecewise assembly is also given at the end of this document.

### The PR as-is

If Option A is done, the PR's current approach (`read_lazy()` on the whole table) is already correct:

- Chunking works for sparse data
- Query API fixes are in place
- Validation skipping is correct
- The uns dask-wrapping is the only remaining gotcha to document

The main remaining work would be improving test coverage (currently 84% patch coverage, with 7 uncovered lines).

---

## Known Downstream Limitations (not SpatialData bugs)

### dask + scipy.sparse aggregate bug

`.mean(axis=0)` and `.sum(axis=0)` fail with a `keepdims` error when dask wraps scipy.sparse chunks. Workaround: iterate over column chunks manually, as sketched below.
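
As an illustration of that workaround, here is a minimal sketch that computes per-column sums one column block at a time, so only one scipy.sparse chunk is materialized at once. It assumes `table.X` is a dask array with a single row chunk and ~1000-column chunks, as `read_elem_lazy` produces for CSC data; the names are illustrative:

```python
import numpy as np

def column_sums(X):
    """Per-column sums of a dask array whose chunks are scipy.sparse matrices."""
    parts = []
    for j in range(X.numblocks[1]):          # iterate over column blocks
        block = X.blocks[0, j].compute()     # materialize one scipy.sparse chunk (~1000 columns)
        parts.append(np.asarray(block.sum(axis=0)).ravel())
    return np.concatenate(parts)

col_totals = column_sums(table.X)  # table loaded with SpatialData.read(..., lazy=True)
```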

### sklearn/UMAP require materialized arrays

The pattern is "deferred materialization" -- subset features BEFORE calling `.compute()`:

```python
# Materializes only ~200 columns instead of 460,517
data = table.X[:, feature_indices].compute()
```

---

## Key Files

| File                                               | Purpose                              |
| -------------------------------------------------- | ------------------------------------ |
| `src/spatialdata/_io/io_table.py`                  | Core lazy loading logic              |
| `src/spatialdata/_io/io_zarr.py`                   | Passes lazy parameter through        |
| `src/spatialdata/_core/spatialdata.py`             | Entry point for `SpatialData.read()` |
| `src/spatialdata/_core/query/relational_query.py`  | Query API fixes for Dataset2D obs    |
| `src/spatialdata/_utils.py`                        | Helper for lazy AnnData detection    |
| `src/spatialdata/models/models.py`                 | Validation skip for lazy tables      |
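
---

## Appendix: Illustrative Sketches for Options A and B

Both sketches below are rough illustrations of the options described above, not code from the PR; paths, table names, and dtype handling are assumptions.

For Option A, a one-off migration script could stamp the missing encoding metadata onto existing datasets (e.g. Hippocampus.zarr, mouse_brain.zarr) so `read_lazy()` can open them:

```python
import zarr

def stamp_obs_encoding(table_path: str) -> None:
    """Add anndata encoding metadata to obs columns that were written without it."""
    obs = zarr.open_group(f"{table_path}/obs", mode="r+")
    for name, arr in obs.arrays():            # categorical columns are subgroups and are skipped
        if "encoding-type" in arr.attrs:
            continue                           # already spec-compliant
        is_string = arr.dtype.kind in ("U", "S", "O")
        arr.attrs["encoding-type"] = "string-array" if is_string else "array"
        arr.attrs["encoding-version"] = "0.2.0"

stamp_obs_encoding("Hippocampus.zarr/tables/my_table")  # path is illustrative
```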
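
For Option B, `_read_table()` could assemble the AnnData piecewise, reading the small components eagerly and only `X` (and layers) lazily. The import location of `read_elem` varies across anndata versions, and the group layout assumed here is the standard anndata zarr layout:

```python
import zarr
from anndata import AnnData, read_elem                 # eager reader
from anndata.experimental import read_elem_lazy        # lazy reader

def read_table_piecewise(table_path: str) -> AnnData:
    g = zarr.open_group(table_path, mode="r")
    adata = AnnData(
        X=read_elem_lazy(g["X"]),                      # dask-backed, chunked per read_elem_lazy defaults
        obs=read_elem(g["obs"]),                       # eager: tolerates missing encoding metadata
        var=read_elem(g["var"]),
        uns=read_elem(g["uns"]) if "uns" in g else {}, # eager: avoids the uns dask-wrapping gotcha
    )
    if "layers" in g:
        for key in g["layers"].keys():
            adata.layers[key] = read_elem_lazy(g["layers"][key])
    return adata
```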