@@ -10,7 +10,8 @@ def get_correlation_modules(
1010 layer = 'X' ,
1111 n_neighbors = 10 ,
1212 leiden_kwargs = {},
13- output_key = 'gene_module'
13+ output_key = 'gene_module' ,
14+ obs_mask = None
1415):
1516 """
1617 Get correlation modules from an anndata object.
@@ -31,6 +32,8 @@ def get_correlation_modules(
3132 :param output_key: Column to add to adata.var with module IDs,
3233 defaults to 'gene_module'
3334 :type output_key: str, optional
35+ :param obs_mask: Boolean mask or slice for observations to consider
36+ :type obs_mask: np.ndarray or slice, optional
3437
3538 :return: The original adata object with:
3639 Gene correlations in 'X_corrcoef' in .varp
@@ -39,9 +42,13 @@ def get_correlation_modules(
3942 :rtype: ad.AnnData
4043 """
4144
45+ if obs_mask is None :
46+ obs_mask = slice (None )
47+
4248 if f'{ layer } _corrcoef' not in adata .varp .keys ():
4349 adata .varp [f'{ layer } _corrcoef' ] = corrcoef (
44- adata .X if layer == 'X' else adata .layers [layer ]
50+ adata .X [obs_mask , :] if layer == 'X' else
51+ adata .layers [layer ][obs_mask , :]
4552 )
4653
4754 corr_dist_adata = correlation_clustering_and_umap (
@@ -68,7 +75,8 @@ def get_correlation_submodules(
6875 n_neighbors = 10 ,
6976 leiden_kwargs = {},
7077 input_key = 'gene_module' ,
71- output_key = 'gene_submodule'
78+ output_key = 'gene_submodule' ,
79+ obs_mask = None
7280):
7381 """
7482 Get correlation submodules iteratively from an anndata object
@@ -91,6 +99,8 @@ def get_correlation_submodules(
9199 :param output_key: Column to add to adata.var with module IDs,
92100 defaults to 'gene_submodule'
93101 :type output_key: str, optional
102+ :param obs_mask: Boolean mask or slice for observations to consider
103+ :type obs_mask: np.ndarray or slice, optional
94104
95105 :return: The original adata object with:
96106 Gene-gene submodule correlation UMAP in 'X_submodule_umap' in .varm
@@ -100,6 +110,9 @@ def get_correlation_submodules(
100110
101111 if input_key not in adata .var .columns :
102112 raise RuntimeError (f"Column { input_key } not present in .var" )
113+
114+ if obs_mask is None :
115+ obs_mask = slice (None )
103116
104117 lref = adata .X if layer == 'X' else adata .layers [layer ]
105118
@@ -117,7 +130,7 @@ def get_correlation_submodules(
117130 _slice_idx = adata .var [input_key ] == cat
118131
119132 _slice_corr_dist_adata = correlation_clustering_and_umap (
120- corrcoef (lref [:, _slice_idx ]),
133+ corrcoef (lref [:, _slice_idx ][ obs_mask , :] ),
121134 n_neighbors = n_neighbors ,
122135 var_names = adata .var_names [_slice_idx ],
123136 ** leiden_kwargs
0 commit comments