diff --git a/.gitignore b/.gitignore index 25836a3b1..ac8373fad 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,7 @@ node_modules/ .mypy_cache .ruff_cache uv.lock + +# Claude Code local files +.claude/ +spatialdata_pr_context.md diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 8232f74c6..22fa21de8 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -77,8 +77,14 @@ def _filter_table_by_element_names(table: AnnData | None, element_names: str | l return None table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] - table.obs = pd.DataFrame(table.obs) + # Filter first, then materialize obs to avoid shape mismatch with lazy tables table = table[table.obs[region_key].isin(element_names)].copy() + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(table.obs, pd.DataFrame): + table.obs = pd.DataFrame(table.obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + table.obs = table.obs.to_memory() table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() return table @@ -198,8 +204,14 @@ def _filter_table_by_elements( indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() to_keep = to_keep | indices original_table = table - table.obs = pd.DataFrame(table.obs) + # Subset first, then materialize obs to avoid shape mismatch with lazy tables table = table[to_keep, :] + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(table.obs, pd.DataFrame): + table.obs = pd.DataFrame(table.obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + table.obs = table.obs.to_memory() if match_rows: assert instances is not None assert isinstance(instances, np.ndarray) @@ -1068,7 +1080,12 @@ def 
get_values( if origin == "obs": df = obs[value_key_values].copy() if origin == "var": - matched_table.obs = pd.DataFrame(obs) + # Handle lazy tables (Dataset2D) vs eager tables (DataFrame) + if isinstance(obs, pd.DataFrame): + matched_table.obs = pd.DataFrame(obs) + else: + # Lazy AnnData uses Dataset2D which needs to_memory() to convert properly + matched_table.obs = obs.to_memory() if table_layer is None: x = matched_table[:, value_key_values].X else: diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225fe..4db3fdf30 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -52,10 +52,7 @@ if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import ( - SpatialDataContainerFormatType, - SpatialDataFormatType, - ) + from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType class SpatialData: @@ -232,9 +229,7 @@ def get_annotated_regions(table: AnnData) -> list[str]: ------- The annotated regions. """ - from spatialdata.models.models import ( - _get_region_metadata_from_region_key_column, - ) + from spatialdata.models.models import _get_region_metadata_from_region_key_column return _get_region_metadata_from_region_key_column(table) @@ -691,18 +686,14 @@ def _filter_tables( continue # each mode here requires paths or elements, using assert here to avoid mypy errors. 
if by == "cs": - from spatialdata._core.query.relational_query import ( - _filter_table_by_element_names, - ) + from spatialdata._core.query.relational_query import _filter_table_by_element_names assert element_names is not None table = _filter_table_by_element_names(table, element_names) if table is not None and len(table) != 0: tables[table_name] = table elif by == "elements": - from spatialdata._core.query.relational_query import ( - _filter_table_by_elements, - ) + from spatialdata._core.query.relational_query import _filter_table_by_elements assert elements_dict is not None table = _filter_table_by_elements(table, elements_dict=elements_dict) @@ -727,10 +718,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: The method does not allow to rename a coordinate system into an existing one, unless the existing one is also renamed in the same call. """ - from spatialdata.transformations.operations import ( - get_transformation, - set_transformation, - ) + from spatialdata.transformations.operations import get_transformation, set_transformation # check that the rename_dict is valid old_names = self.coordinate_systems @@ -1106,7 +1094,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1211,15 +1199,12 @@ def _write_element( ) root_group, element_type_group, element_group = _get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False - ) - from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, + zarr_path=zarr_container_path, + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) 
+ from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats if parsed_formats is None: @@ -1266,7 +1251,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, ) -> None: """ @@ -1544,7 +1529,10 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) from spatialdata._io._utils import overwrite_channel_names @@ -1595,19 +1583,18 @@ def write_transformations(self, element_name: str | None = None) -> None: ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_raster from spatialdata._io.format import RasterFormats raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] overwrite_coordinate_transformations_raster( - group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + group=element_group, + axes=axes, + transformations=transformations, + raster_format=raster_format, ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + 
from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster overwrite_coordinate_transformations_non_raster( group=element_group, @@ -1826,6 +1813,7 @@ def read( file_path: str | Path | UPath | zarr.Group, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1838,6 +1826,11 @@ def read( The elements to read (images, labels, points, shapes, table). If None, all elements are read. reconsolidate_metadata If the consolidated metadata store got corrupted this can lead to errors when trying to read the data. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. + Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. Returns ------- @@ -1850,7 +1843,7 @@ def read( _write_consolidated_metadata(file_path) - return read_zarr(file_path, selection=selection) + return read_zarr(file_path, selection=selection, lazy=lazy) @property def images(self) -> Images: diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 8cd7b8385..19fb273a2 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -9,18 +9,38 @@ from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format -from spatialdata._io.format import ( - CurrentTablesFormat, - TablesFormats, - TablesFormatV01, - TablesFormatV02, - _parse_version, -) +from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version from spatialdata.models import TableModel, get_table_keys -def _read_table(store: str | Path) -> AnnData: - table = read_anndata_zarr(str(store)) +def _read_table(store: str | Path, lazy: bool = False) -> AnnData: + """ + Read a 
table from a zarr store. + + Parameters + ---------- + store + Path to the zarr store containing the table. + lazy + If True, read the table lazily using ``anndata.experimental.read_lazy``. + This keeps large matrices (X, layers) as dask arrays backed by zarr, + so they are only loaded into memory on demand. Requires anndata >= 0.12. + + Returns + ------- + The AnnData table, either lazily loaded or in-memory. + + Raises + ------ + ImportError + If ``lazy=True`` but anndata >= 0.12 is not installed. + """ + if lazy: + from anndata.experimental import read_lazy + + table = read_lazy(str(store)) + else: + table = read_anndata_zarr(str(store)) f = zarr.open(store, mode="r") version = _parse_version(f, expect_attrs_key=False) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 4c410fab0..0e36b2a04 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -3,6 +3,7 @@ import os import warnings from collections.abc import Callable +from functools import partial from json import JSONDecodeError from pathlib import Path from typing import Any, Literal, cast @@ -17,11 +18,7 @@ from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ( - BadFileHandleMethod, - _resolve_zarr_store, - handle_read_errors, -) +from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -106,10 +103,7 @@ def get_raster_format_for_read( ------- The ome-zarr format to use for reading the raster element. 
""" - from spatialdata._io.format import ( - sdata_zarr_version_to_ome_zarr_format, - sdata_zarr_version_to_raster_format, - ) + from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format if sdata_version == "0.1": group_version = group.metadata.attributes["multiscales"][0]["version"] @@ -126,6 +120,7 @@ def read_zarr( store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, + lazy: bool = False, ) -> SpatialData: """ Read a SpatialData dataset from a zarr store (on-disk or remote). @@ -149,6 +144,12 @@ def read_zarr( object is returned containing only elements that could be read. Failures can only be determined from the warnings. + lazy + If True, read tables lazily using anndata.experimental.read_lazy. + This keeps large tables out of memory until needed. Requires anndata >= 0.12. + Note: Images, labels, and points are always read lazily (using Dask). + This parameter only affects tables, which are normally loaded into memory. + Returns ------- A SpatialData object. @@ -195,7 +196,7 @@ def read_zarr( "labels": (_read_multiscale, "labels", labels), "points": (_read_points, "points", points), "shapes": (_read_shapes, "shapes", shapes), - "tables": (_read_table, "tables", tables), + "tables": (partial(_read_table, lazy=lazy), "tables", tables), } for group_name, ( read_func, diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 609cd0403..876a2c92a 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -225,11 +225,16 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A """ if not hasattr(subset_adata, "obs") or not hasattr(original_adata, "obs"): return - obs = pd.DataFrame(subset_adata.obs) + # Handle lazy tables (Dataset2D vs DataFrame). 
Lazy AnnData uses Dataset2D which needs to_memory() + obs = pd.DataFrame(subset_adata.obs) if isinstance(subset_adata.obs, pd.DataFrame) else subset_adata.obs.to_memory() + original_obs = ( + original_adata.obs if isinstance(original_adata.obs, pd.DataFrame) else original_adata.obs.to_memory() + ) + for column in obs.columns: is_categorical = isinstance(obs[column].dtype, pd.CategoricalDtype) if is_categorical: - c = obs[column].cat.set_categories(original_adata.obs[column].cat.categories) + c = obs[column].cat.set_categories(original_obs[column].cat.categories) obs[column] = c subset_adata.obs = obs diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 2bfcf88ce..4c38c9c8f 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -55,6 +55,25 @@ ATTRS_KEY = "spatialdata_attrs" +def _is_lazy_anndata(adata: AnnData) -> bool: + """Check if an AnnData object is lazily loaded. + + Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var + stored as xarray Dataset2D instead of pandas DataFrame. + + Parameters + ---------- + adata + The AnnData object to check. + + Returns + ------- + True if the AnnData is lazily loaded, False otherwise. + """ + # Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D) + return not isinstance(adata.obs, pd.DataFrame) + + def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None: _validate_mapping_to_coordinate_system_type(transformations) transformations_in_element = _get_transformations(element) @@ -1053,6 +1072,13 @@ def _validate_table_annotation_metadata(cls, data: AnnData) -> None: raise ValueError(f"`{attr[cls.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.") if attr[cls.INSTANCE_KEY] not in data.obs: raise ValueError(f"`{attr[cls.INSTANCE_KEY]}` not found in `adata.obs`. 
Please create the column.") + + # Skip detailed dtype/value validation for lazy-loaded AnnData + # These checks would trigger data loading, defeating the purpose of lazy loading + # NOTE(review): this change does not re-run these checks after the data is loaded; confirm where they get enforced + if _is_lazy_anndata(data): + return + instance_col = data.obs[attr[cls.INSTANCE_KEY]] dtype = instance_col.dtype @@ -1122,6 +1148,10 @@ def validate( if ATTRS_KEY not in data.uns: return data + # Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy) + # Lazy AnnData has xarray-based obs/var, which requires different validation + is_lazy = _is_lazy_anndata(data) + _, region_key, instance_key = get_table_keys(data) if region_key is not None: if region_key not in data.obs: @@ -1129,7 +1159,8 @@ def validate( f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse " f"using TableModel.parse(adata)." ) - if not isinstance(data.obs[region_key].dtype, CategoricalDtype): + # Skip dtype validation for lazy tables (would require loading data) + if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype): raise ValueError( f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." ) @@ -1139,7 +1170,8 @@ def validate( f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse" f" using TableModel.parse(adata)." 
) - if data.obs[instance_key].isnull().values.any(): + # Skip null check for lazy tables (would require loading data) + if not is_lazy and data.obs[instance_key].isnull().values.any(): raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") cls._validate_table_annotation_metadata(data) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index be07d8be8..def2c1477 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -97,7 +97,11 @@ def test_shapes( # add a mixed Polygon + MultiPolygon element shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]]) - shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding) + shapes.write( + tmpdir, + sdata_formats=sdata_container_format, + shapes_geometry_encoding=geometry_encoding, + ) sdata = SpatialData.read(tmpdir) if geometry_encoding == "WKB": @@ -1114,3 +1118,87 @@ def test_sdata_with_nan_in_obs(tmp_path: Path) -> None: else: # After round-trip, NaN in object-dtype column becomes string "nan" on pandas 2 assert r1.iloc[1] == "nan" assert np.isnan(r2.iloc[0]) + + +class TestLazyTableLoading: + """Tests for lazy table loading functionality. + + Lazy loading uses anndata.experimental.read_lazy() to keep large tables + out of memory until needed. This is particularly useful for MSI data + where tables can contain millions of pixels. 
+ """ + + @pytest.fixture + def sdata_with_table(self) -> SpatialData: + """Create a SpatialData object with a simple table for testing.""" + from spatialdata.models import TableModel + + rng = default_rng(42) + table = TableModel.parse( + AnnData( + X=rng.random((100, 50)), + obs=pd.DataFrame( + { + "region": pd.Categorical(["region1"] * 100), + "instance": np.arange(100), + } + ), + ), + region_key="region", + instance_key="instance", + region="region1", + ) + return SpatialData(tables={"test_table": table}) + + def test_lazy_read_basic(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=True reads tables without loading into memory.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=True + try: + sdata_lazy = SpatialData.read(path, lazy=True) + + # Table should be present + assert "test_table" in sdata_lazy.tables + + # Smoke check only: the table exposes X (this does not assert that X is lazy) + # Lazy AnnData from read_lazy uses dask arrays + table = sdata_lazy.tables["test_table"] + assert hasattr(table, "X") + + except ImportError: + # If anndata.experimental.read_lazy is not available, skip + pytest.skip("anndata.experimental.read_lazy not available") + + def test_lazy_false_loads_normally(self, sdata_with_table: SpatialData) -> None: + """Test that lazy=False (default) loads tables into memory normally.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Read with lazy=False (default) + sdata_normal = SpatialData.read(path, lazy=False) + + # Table should be present and loaded normally + assert "test_table" in sdata_normal.tables + table = sdata_normal.tables["test_table"] + + # X should be a numpy array or scipy sparse matrix (in-memory) + import scipy.sparse as sp + + assert isinstance(table.X, np.ndarray | sp.spmatrix) + + def test_read_zarr_lazy_parameter(self, sdata_with_table: SpatialData) -> 
None: + """Test that read_zarr function accepts lazy parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.zarr") + sdata_with_table.write(path) + + # Test read_zarr directly with lazy parameter + try: + sdata = read_zarr(path, lazy=True) + assert "test_table" in sdata.tables + except ImportError: + pytest.skip("anndata.experimental.read_lazy not available")