From f3e77bfc990ba8fbfced379c0722df6964533b80 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 7 May 2026 11:11:10 +0200 Subject: [PATCH 1/3] basic support for AnnData accessors this should keep scanpy plotting working with MuData objects --- pyproject.toml | 2 +- src/mudata/_core/mudata.py | 19 +++++++++++++++++++ tests/test_obs_var.py | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 309e5d3..55b04b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dev = [ "pre-commit", "twine>=4.0.2", ] -test = [ "coverage>=7.10", "mudata[io]", "pytest" ] +test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 9e9e54a..5ec88d6 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -533,6 +533,25 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None) -> pd.DataFram def __getitem__(self, index) -> AnnData | MuData: if isinstance(index, str): return self._mod[index] + elif type(index).__module__.startswith("anndata.acc") and type(index).__name__ == "AdRef": + try: + return index.acc.get(self, index.idx) + except KeyError as e: + if index.acc.dim in ("obs", "var"): + for modname, mod in self._mod.items(): + try: + index.acc.get(mod, index.idx) + except KeyError: + pass + else: + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." + ) from e + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} or in .{index.acc.dim} of any modalities." + ) from e + else: + raise else: return MuData(self, as_view=True, index=index) diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 3057c50..40a97e4 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -1,8 +1,10 @@ from pathlib import Path +import anndata as ad import numpy as np import pandas as pd import pytest +from packaging.version import Version import mudata as md @@ -145,3 +147,14 @@ def test_names_make_unique(mdata: md.MuData): with pytest.raises(TypeError, match="axis="): getattr(mdata, f"{attr}_names_make_unique")() + + +@pytest.mark.skipif( + Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" +) +def test_accessors(mdata: md.MuData): + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="any modalities"): + mdata[ad.acc.A.var["test"]] + with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] From 55ffb17eec47356770aafc3143efb640751724fb Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 8 May 2026 09:02:18 +0200 Subject: [PATCH 2/3] simplify --- src/mudata/_core/mudata.py | 12 ++---------- tests/test_obs_var.py | 2 +- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 5ec88d6..8c01e03 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -539,19 +539,11 @@ def __getitem__(self, index) -> AnnData | MuData: except KeyError as e: if index.acc.dim in ("obs", "var"): for modname, mod in self._mod.items(): - try: - index.acc.get(mod, index.idx) - except KeyError: - pass - else: + if index in mod: raise KeyError( f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." ) from e - raise KeyError( - f"There is no key {index.idx} in MuData .{index.acc.dim} or in .{index.acc.dim} of any modalities." - ) from e - else: - raise + raise else: return MuData(self, as_view=True, index=index) diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 40a97e4..451148a 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -154,7 +154,7 @@ def test_names_make_unique(mdata: md.MuData): ) def test_accessors(mdata: md.MuData): assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() - with pytest.raises(KeyError, match="any modalities"): + with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] with pytest.raises(KeyError, match="there is one in"): mdata[ad.acc.A.var["assert-bool"]] From a5e25fd0addec0a39b71405de0e0d30d7f500cbb Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 8 May 2026 09:45:11 +0200 Subject: [PATCH 3/3] implement __contains__ --- src/mudata/_core/mudata.py | 8 ++++++++ tests/test_obs_var.py | 1 + tests/test_update.py | 3 +++ 3 files changed, 12 insertions(+) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 8c01e03..123ab95 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -547,6 +547,14 @@ def __getitem__(self, index) -> AnnData | MuData: else: return MuData(self, as_view=True, index=index) + def __contains__(self, key) -> bool: + if isinstance(key, str): + return key in self._mod + elif type(key).__module__.startswith("anndata.acc") and type(key).__name__ == "AdRef": + return AnnData.__contains__.__get__(self)(key) + else: + raise TypeError(f"Unexpected key {key!r}.") + @property def mod(self) -> Mapping[str, AnnData | MuData]: """Dictionary of modalities.""" diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 451148a..4636b4f 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -153,6 +153,7 @@ def test_names_make_unique(mdata: md.MuData): Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" ) def test_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] diff --git a/tests/test_update.py b/tests/test_update.py index 0a32eb0..318ff49 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -146,6 +146,9 @@ def test_update_simple(mdata: MuData, axis: Axis): for mod in mdata.mod.keys(): assert mdata.obsmap[mod].dtype.kind == "u" assert mdata.varmap[mod].dtype.kind == "u" + assert mod in mdata + with pytest.raises(TypeError): + 1 in mdata # noqa: B015 # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())