Skip to content

Commit 7d7a5c6

Browse files
committed
Correlation-weighted module finding
1 parent fd55680 commit 7d7a5c6

4 files changed

Lines changed: 194 additions & 5 deletions

File tree

scself/_modules/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
get_correlation_modules,
33
get_correlation_submodules
44
)
5+
6+
from .find_weighted_modules import (
7+
get_combined_correlation_modules
8+
)
9+
510
from .score_modules import (
611
score_all_modules,
712
score_all_submodules,
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import pandas as pd
2+
import numpy as np
3+
import anndata as ad
4+
5+
from functools import reduce
6+
7+
from scself.utils.correlation import (
8+
corrcoef,
9+
correlation_clustering_and_umap
10+
)
11+
12+
_ITER_TYPES = (tuple, list, pd.Series, np.ndarray)
13+
14+
def get_combined_correlation_modules(
15+
adata_list,
16+
layer='X',
17+
n_neighbors=10,
18+
leiden_kwargs={},
19+
output_key='gene_module',
20+
obs_mask=None
21+
):
22+
"""
23+
Get correlation modules from a list of anndata objects
24+
by calculating the correlation separately for each object
25+
and averaging the correlation between genes where the
26+
objects overlap.
27+
28+
Adds .varp['X_corrcoef'] with gene-gene correlation
29+
and .var[output_key] with gene module ID
30+
31+
:param adata_list: List of data objects containing
32+
expression data
33+
:type adata_list: list(ad.AnnData)
34+
:param layer: Layer to calculate correlation and find
35+
modules from, can provide a list with the same length
36+
as adata_list, defaults to 'X'
37+
:type layer: str, optional
38+
:param n_neighbors: Number of neighbors in kNN, defaults
39+
to 10
40+
:type n_neighbors: int, optional
41+
:param leiden_kwargs: Keyword arguments to sc.tl.leiden
42+
:type leiden_kwargs: dict, optional
43+
:param output_key: Column to add to adata.var with module IDs,
44+
defaults to 'gene_module'
45+
:type output_key: str, optional
46+
:param obs_mask: Boolean mask or slice for observations to consider
47+
:type obs_mask: np.ndarray or slice, optional
48+
49+
:return: A correlation (gene x gene) anndata object with:
50+
Gene-gene correlation UMAP in 'X_umap' in .varm
51+
Module membership IDs in .obs[output_key]
52+
:rtype: ad.AnnData
53+
"""
54+
55+
_n_datasets = len(adata_list)
56+
57+
if obs_mask is None:
58+
obs_mask = [None] * _n_datasets
59+
60+
# Make sure all of the arguments are iterable lists of
61+
# correct length and raise an AttributeError if not
62+
def _to_iterable(arg, argname):
63+
if not isinstance(arg, _ITER_TYPES):
64+
return [arg] * _n_datasets
65+
elif len(arg) != _n_datasets:
66+
raise AttributeError(
67+
f"len({argname}) = {len(arg)}; {_n_datasets} is required"
68+
)
69+
else:
70+
return arg
71+
72+
layer = _to_iterable(layer, 'layer')
73+
74+
# Calculate correlation for each dataset
75+
for adata, layer_i, mask_i in zip(
76+
adata_list,
77+
layer,
78+
obs_mask
79+
):
80+
81+
if f'{layer_i}_corrcoef' not in adata.varp.keys():
82+
_lref = adata.X if layer_i == 'X' else adata.layers[layer_i]
83+
84+
adata.varp[f'{layer_i}_corrcoef'] = corrcoef(
85+
_lref[mask_i, :] if mask_i is not None else _lref
86+
)
87+
88+
del _lref
89+
90+
# Check to see if all the data is already aligned
91+
# If not, find the union of all the var_names
92+
if all(
93+
all(
94+
adata.var_names.equals(a.var_names)
95+
for a in adata_list
96+
)
97+
for adata in adata_list
98+
):
99+
_genes = adata_list[0].var_names.copy()
100+
_do_reindex=False
101+
else:
102+
_genes = reduce(
103+
lambda x, y: x.var_names.union(y.var_names),
104+
adata_list
105+
)
106+
_do_reindex=True
107+
108+
_n_genes = len(_genes)
109+
110+
# Get the number of times each gene appears
111+
_gene_counts = reduce(
112+
lambda x, y: x + y,
113+
[_genes.isin(c.var_names).astype(int) for c in adata_list]
114+
)
115+
116+
# Create a zeroed correlation matrix for the
117+
# gene union
118+
full_correlation = ad.AnnData(
119+
np.zeros(
120+
(_n_genes, _n_genes),
121+
dtype=float
122+
),
123+
var=pd.DataFrame(index=_genes),
124+
obs=pd.DataFrame(index=_genes)
125+
)
126+
127+
# Iterate through all the anndata object and add
128+
# each correlation into the full_correlation
129+
# indexed appropriately for the feaure name
130+
for adata, layer_i in zip(
131+
adata_list,
132+
layer
133+
):
134+
135+
if _do_reindex:
136+
full_correlation.X[
137+
np.ix_(
138+
full_correlation.obs_names.get_indexer(adata.var_names),
139+
full_correlation.var_names.get_indexer(adata.var_names)
140+
)
141+
] += adata.varp[f'{layer_i}_corrcoef']
142+
else:
143+
full_correlation.X += adata.varp[f'{layer_i}_corrcoef']
144+
145+
# Correct for the number of times the gene-gene correlation
146+
# was calculated
147+
full_correlation.X /= np.minimum(
148+
_gene_counts[:, None],
149+
_gene_counts[None, :]
150+
)
151+
152+
full_correlation = correlation_clustering_and_umap(
153+
full_correlation.X,
154+
n_neighbors=n_neighbors,
155+
var_names=full_correlation.var_names,
156+
**leiden_kwargs
157+
)
158+
159+
full_correlation.obs['leiden'] = full_correlation.obs['leiden'].astype(int)
160+
161+
for adata, layer_i in zip(
162+
adata_list,
163+
layer
164+
):
165+
_gene_idx = full_correlation.var_names.get_indexer(adata.var_names)
166+
167+
# Put the gene module memberships in
168+
adata.var[output_key] = full_correlation.obs['leiden']
169+
170+
# Put the partial umap into the separate objects
171+
# so they can be plotted in the same space
172+
adata.varm[f'{layer_i}_umap'] = full_correlation.obsm['X_umap'][_gene_idx, :]
173+
174+
return full_correlation

scself/utils/correlation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def correlation_clustering_and_umap(
8888
correlations,
8989
n_neighbors=10,
9090
var_names=None,
91+
skip_leiden=None,
9192
**leiden_kwargs
9293
):
9394

@@ -97,9 +98,12 @@ def correlation_clustering_and_umap(
9798
obs=pd.DataFrame(index=var_names) if var_names is not None else None
9899
)
99100

100-
# Special case handling to silently handle when there are too many neighbors for
101-
# the provided data; comes up with submodules a lot
102-
if corr_dist_adata.shape[0] <= n_neighbors:
101+
# Special case handling to silently handle when there are too many neighbors
102+
# for the provided data; comes up with submodules a lot
103+
if skip_leiden is not None:
104+
pass
105+
106+
elif corr_dist_adata.shape[0] <= n_neighbors:
103107

104108
n_neighbors = corr_dist_adata.shape[0] - 2
105109

scself/utils/hierarchical_clustering.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from scipy.spatial.distance import pdist
99

1010

11-
def hclust(data, metric='euclidean', method='ward', return_fcluster=False, **kwargs):
11+
def hclust(
12+
data,
13+
metric='euclidean',
14+
method='ward',
15+
return_fcluster=False,
16+
**fcluster_kwargs
17+
):
1218

1319
# Increase recursion limit so dendrogram doesn't throw a fit
1420
# if needed
@@ -40,7 +46,7 @@ def hclust(data, metric='euclidean', method='ward', return_fcluster=False, **kwa
4046
if return_fcluster:
4147
return _order, fcluster(
4248
_links,
43-
**kwargs
49+
**fcluster_kwargs
4450
)
4551
else:
4652
return _order

0 commit comments

Comments
 (0)