diff --git a/pyproject.toml b/pyproject.toml index 309e5d3..55b04b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dev = [ "pre-commit", "twine>=4.0.2", ] -test = [ "coverage>=7.10", "mudata[io]", "pytest" ] +test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 9e9e54a..123ab95 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -533,9 +533,28 @@ def strings_to_categoricals(self, df: pd.DataFrame | None = None) -> pd.DataFram def __getitem__(self, index) -> AnnData | MuData: if isinstance(index, str): return self._mod[index] + elif type(index).__module__.startswith("anndata.acc") and type(index).__name__ == "AdRef": + try: + return index.acc.get(self, index.idx) + except KeyError as e: + if index.acc.dim in ("obs", "var"): + for modname, mod in self._mod.items(): + if index in mod: + raise KeyError( + f"There is no key {index.idx} in MuData .{index.acc.dim} but there is one in {modname} .{index.acc.dim}. Consider running `pull_{index.acc.dim}()` to update global .{index.acc.dim}." + ) from e + raise else: return MuData(self, as_view=True, index=index) + def __contains__(self, key) -> bool: + if isinstance(key, str): + return key in self._mod + elif type(key).__module__.startswith("anndata.acc") and type(key).__name__ == "AdRef": + return AnnData.__contains__.__get__(self)(key) + else: + raise TypeError(f"Unexpected key {key!r}.") + @property def mod(self) -> Mapping[str, AnnData | MuData]: """Dictionary of modalities.""" diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 3057c50..4636b4f 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -1,8 +1,10 @@ from pathlib import Path +import anndata as ad import numpy as np import pandas as pd import pytest +from packaging.version import Version import mudata as md @@ -145,3 +147,15 @@ def test_names_make_unique(mdata: md.MuData): with pytest.raises(TypeError, match="axis="): getattr(mdata, f"{attr}_names_make_unique")() + + +@pytest.mark.skipif( + Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" +) +def test_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="test"): + mdata[ad.acc.A.var["test"]] + with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] diff --git a/tests/test_update.py b/tests/test_update.py index 0a32eb0..318ff49 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -146,6 +146,9 @@ def test_update_simple(mdata: MuData, axis: Axis): for mod in mdata.mod.keys(): assert mdata.obsmap[mod].dtype.kind == "u" assert mdata.varmap[mod].dtype.kind == "u" + assert mod in mdata + with pytest.raises(TypeError): + 1 in mdata # noqa: B015 # names along non-axis are concatenated assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values())