Skip to content

Commit 2681749

Browse files
committed
add row masking for module finding
1 parent d8b91e3 commit 2681749

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

scself/_modules/find_modules.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ def get_correlation_modules(
1010
layer='X',
1111
n_neighbors=10,
1212
leiden_kwargs={},
13-
output_key='gene_module'
13+
output_key='gene_module',
14+
obs_mask=None
1415
):
1516
"""
1617
Get correlation modules from an anndata object.
@@ -31,6 +32,8 @@ def get_correlation_modules(
3132
:param output_key: Column to add to adata.var with module IDs,
3233
defaults to 'gene_module'
3334
:type output_key: str, optional
35+
:param obs_mask: Boolean mask or slice for observations to consider
36+
:type obs_mask: np.ndarray or slice, optional
3437
3538
:return: The original adata object with:
3639
Gene correlations in 'X_corrcoef' in .varp
@@ -39,9 +42,13 @@ def get_correlation_modules(
3942
:rtype: ad.AnnData
4043
"""
4144

45+
if obs_mask is None:
46+
obs_mask = slice(None)
47+
4248
if f'{layer}_corrcoef' not in adata.varp.keys():
4349
adata.varp[f'{layer}_corrcoef'] = corrcoef(
44-
adata.X if layer == 'X' else adata.layers[layer]
50+
adata.X[obs_mask, :] if layer == 'X' else
51+
adata.layers[layer][obs_mask, :]
4552
)
4653

4754
corr_dist_adata = correlation_clustering_and_umap(
@@ -68,7 +75,8 @@ def get_correlation_submodules(
6875
n_neighbors=10,
6976
leiden_kwargs={},
7077
input_key='gene_module',
71-
output_key='gene_submodule'
78+
output_key='gene_submodule',
79+
obs_mask=None
7280
):
7381
"""
7482
Get correlation submodules iteratively from an anndata object
@@ -91,6 +99,8 @@ def get_correlation_submodules(
9199
:param output_key: Column to add to adata.var with module IDs,
92100
defaults to 'gene_submodule'
93101
:type output_key: str, optional
102+
:param obs_mask: Boolean mask or slice for observations to consider
103+
:type obs_mask: np.ndarray or slice, optional
94104
95105
:return: The original adata object with:
96106
Gene-gene submodule correlation UMAP in 'X_submodule_umap' in .varm
@@ -100,6 +110,9 @@ def get_correlation_submodules(
100110

101111
if input_key not in adata.var.columns:
102112
raise RuntimeError(f"Column {input_key} not present in .var")
113+
114+
if obs_mask is None:
115+
obs_mask = slice(None)
103116

104117
lref = adata.X if layer == 'X' else adata.layers[layer]
105118

@@ -117,7 +130,7 @@ def get_correlation_submodules(
117130
_slice_idx = adata.var[input_key] == cat
118131

119132
_slice_corr_dist_adata = correlation_clustering_and_umap(
120-
corrcoef(lref[:, _slice_idx]),
133+
corrcoef(lref[:, _slice_idx][obs_mask, :]),
121134
n_neighbors=n_neighbors,
122135
var_names=adata.var_names[_slice_idx],
123136
**leiden_kwargs

0 commit comments

Comments
 (0)