Skip to content

Commit 82e9c69

Browse files
committed
Improve documentation
1 parent 30ae150 commit 82e9c69

3 files changed

Lines changed: 103 additions & 45 deletions

File tree

scself/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@
2424
get_correlation_submodules,
2525
module_score,
2626
score_all_modules,
27-
score_all_submodules
27+
score_all_submodules,
28+
get_combined_correlation_modules
2829
)

scself/_modules/find_weighted_modules.py

Lines changed: 85 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -21,47 +21,68 @@ def get_combined_correlation_modules(
2121
obs_mask=None
2222
):
2323
"""
24-
Get correlation modules from a list of anndata objects
25-
by calculating the correlation separately for each object
26-
and averaging the correlation between genes where the
27-
objects overlap.
24+
Find gene modules by combining correlation patterns across multiple datasets.
2825
29-
Adds .varp['X_corrcoef'] with gene-gene correlation
30-
and .var[output_key] with gene module ID
26+
This function performs a weighted averaging of gene-gene correlations across
27+
multiple AnnData objects, enabling module discovery that accounts for genes
28+
measured in different subsets of datasets. The workflow consists of:
3129
32-
:param adata_list: List of data objects containing
33-
expression data
30+
1. Calculate gene-gene correlation within each dataset independently
31+
2. Determine the union of genes across all datasets
32+
3. Average correlations where genes overlap (weighted by dataset availability)
33+
4. Cluster averaged correlations using Leiden community detection
34+
5. Assign module IDs back to individual datasets
35+
36+
The key advantage is handling datasets with non-overlapping gene sets while
37+
weighting each gene pair's summed correlation by the smaller of the two
38+
genes' dataset counts (an upper bound on how many datasets contain both).
39+
40+
**Modifies input datasets in-place:**
41+
- Adds .varp['{layer}_corrcoef'] with gene-gene correlation matrix
42+
- Adds .var[output_key] with integer gene module assignments
43+
- Adds .varm['{layer}_umap'] with UMAP coordinates in shared space
44+
45+
:param adata_list: List of AnnData objects containing expression data.
46+
Datasets may have different genes but should have comparable expression scales.
3447
:type adata_list: list(ad.AnnData)
35-
:param layer: Layer to calculate correlation and find
36-
modules from, can provide a list with the same length
37-
as adata_list, defaults to 'X'
38-
:type layer: str, optional
39-
:param n_neighbors: Number of neighbors in kNN, defaults
40-
to 10
48+
:param layer: Layer to use for correlation calculation. Can be 'X' or any layer name.
49+
Accepts either a single string (applied to all datasets) or a list with one
50+
layer name per dataset, defaults to 'X'
51+
:type layer: str or list(str), optional
52+
:param n_neighbors: Number of neighbors for kNN graph construction used in
53+
clustering and UMAP, defaults to 10
4154
:type n_neighbors: int, optional
42-
:param leiden_kwargs: Keyword arguments to sc.tl.leiden
55+
:param leiden_kwargs: Additional keyword arguments passed to leiden clustering
56+
(e.g., resolution parameter), defaults to {}
4357
:type leiden_kwargs: dict, optional
44-
:param output_key: Column to add to adata.var with module IDs,
58+
:param output_key: Column name to add to .var for storing module IDs,
4559
defaults to 'gene_module'
4660
:type output_key: str, optional
47-
:param obs_mask: Boolean mask or slice for observations to consider
48-
:type obs_mask: np.ndarray or slice, optional
49-
50-
:return: A correlation (gene x gene) anndata object with:
51-
Gene-gene correlation UMAP in 'X_umap' in .varm
52-
Module membership IDs in .obs[output_key]
61+
:param obs_mask: Boolean mask or slice to subset observations (cells) before
62+
computing correlations. Can be a single mask (applied to all) or list of masks
63+
(one per dataset). Useful for computing modules from specific cell populations.
64+
:type obs_mask: np.ndarray, slice, list, or None, optional
65+
66+
:return: An AnnData object representing the full gene-gene correlation matrix with:
67+
- .layers['{layer}_corrcoef']: Averaged correlation matrix (n_genes x n_genes)
68+
- .obs['leiden'] and .var['leiden']: Module assignments (symmetric)
69+
- .obsm['{layer}_umap']: UMAP embedding of gene correlations
5370
:rtype: ad.AnnData
5471
"""
5572

5673
_n_datasets = len(adata_list)
5774

75+
# Initialize obs_mask to None for all datasets if not provided
5876
if obs_mask is None:
5977
obs_mask = [None] * _n_datasets
6078

61-
# Make sure all of the arguments are iterable lists of
62-
# correct length and raise an AttributeError if not
79+
# Helper function to broadcast scalar arguments to lists
80+
# matching the number of datasets. This allows users to pass
81+
# a single layer name that applies to all datasets, or a list
82+
# with one layer name per dataset.
6383
def _to_iterable(arg, argname):
6484
if not isinstance(arg, _ITER_TYPES):
85+
# Broadcast scalar to list
6586
return [arg] * _n_datasets
6687
elif len(arg) != _n_datasets:
6788
raise AttributeError(
@@ -72,34 +93,40 @@ def _to_iterable(arg, argname):
7293

7394
layer = _to_iterable(layer, 'layer')
7495

75-
# Calculate correlation for each dataset
96+
# Step 1: Calculate gene-gene correlation for each dataset independently
97+
# Store results in each dataset's .varp to enable caching across calls
7698
for adata, layer_i, mask_i in zip(
7799
adata_list,
78100
layer,
79101
obs_mask
80102
):
81-
103+
# Skip if a correlation is already stored for this layer name (the stored matrix is reused as-is, even if obs_mask differs)
82104
if f'{layer_i}_corrcoef' not in adata.varp.keys():
105+
# Get reference to expression data (X or specified layer)
83106
_lref = adata.X if layer_i == 'X' else adata.layers[layer_i]
84107

108+
# Apply observation mask if provided (subset to specific cells)
109+
# Then compute gene-gene correlation matrix
85110
adata.varp[f'{layer_i}_corrcoef'] = corrcoef(
86111
_lref[mask_i, :] if mask_i is not None else _lref
87112
)
88113

89114
del _lref
90115

91-
# Check to see if all the data is already aligned
92-
# If not, find the union of all the var_names
116+
# Step 2: Determine the gene universe across all datasets
117+
# Check if all datasets share identical gene names (same genes, same order)
93118
if all(
94119
all(
95120
adata.var_names.equals(a.var_names)
96121
for a in adata_list
97122
)
98123
for adata in adata_list
99124
):
125+
# All datasets aligned - no reindexing needed
100126
_genes = adata_list[0].var_names.copy()
101127
_do_reindex=False
102128
else:
129+
# Datasets have different genes - compute union of all gene names
103130
_genes = reduce(
104131
lambda x, y: x.var_names.union(y.var_names),
105132
adata_list
@@ -108,75 +135,89 @@ def _to_iterable(arg, argname):
108135

109136
_n_genes = len(_genes)
110137

111-
# Get the number of times each gene appears
138+
# Count how many datasets contain each gene
139+
# This is used later for weighted averaging of correlations
112140
_gene_counts = reduce(
113141
lambda x, y: x + y,
114142
[_genes.isin(c.var_names).astype(int) for c in adata_list]
115143
)
116144

117-
# Create a zeroed correlation matrix for the
118-
# gene union
145+
# Step 3: Initialize combined correlation matrix
146+
# Create an AnnData object to hold the gene x gene correlation
147+
# Both obs and var represent genes (symmetric gene-gene matrix)
119148
full_correlation = ad.AnnData(
120149
csr_matrix((_n_genes, _n_genes)),
121150
var=pd.DataFrame(index=_genes),
122151
obs=pd.DataFrame(index=_genes)
123152
)
124153

154+
# Store the summed correlations in a layer (will average later)
125155
_corr_layer = f'{layer[0]}_corrcoef'
126156
full_correlation.layers[_corr_layer] = np.zeros(
127157
(_n_genes, _n_genes),
128158
dtype=float
129159
)
130160

131-
# Iterate through all the anndata object and add
132-
# each correlation into the full_correlation
133-
# indexed appropriately for the feaure name
161+
# Step 4: Accumulate correlation matrices from each dataset
162+
# Sum correlations into the appropriate positions based on gene names
134163
for adata, layer_i in zip(
135164
adata_list,
136165
layer
137166
):
138-
167+
139168
if _do_reindex:
169+
# Map each dataset's genes to their positions in the full gene universe
170+
# np.ix_ creates a mesh for 2D indexing (row indices x column indices)
140171
full_correlation.layers[_corr_layer][
141172
np.ix_(
142173
full_correlation.obs_names.get_indexer(adata.var_names),
143174
full_correlation.var_names.get_indexer(adata.var_names)
144175
)
145176
] += adata.varp[f'{layer_i}_corrcoef']
146177
else:
178+
# All datasets aligned - direct addition
147179
full_correlation.layers[_corr_layer] += adata.varp[f'{layer_i}_corrcoef']
148-
149-
# Correct for the number of times the gene-gene correlation
150-
# was calculated
180+
181+
# Step 5: Compute weighted average of correlations
182+
# Divide each summed correlation by the smaller of the two genes' dataset counts
183+
# Example: if gene A appears in 3 datasets and gene B in 2 datasets,
184+
# their correlation can be calculated in at most min(3,2)=2 datasets
151185
full_correlation.layers[_corr_layer] /= np.minimum(
152-
_gene_counts[:, None],
153-
_gene_counts[None, :]
186+
_gene_counts[:, None], # Row-wise gene counts
187+
_gene_counts[None, :] # Column-wise gene counts
154188
)
155189

190+
# Sanity check: averaged correlations must not exceed 1 (only the upper bound is asserted)
156191
assert full_correlation.layers[_corr_layer].max() <= 1.0
157-
192+
193+
# Step 6: Cluster genes based on correlation patterns
194+
# Build kNN graph, run Leiden clustering, and compute UMAP embedding
158195
_corr_results = correlation_clustering_and_umap(
159196
full_correlation.layers[_corr_layer],
160197
n_neighbors=n_neighbors,
161198
var_names=full_correlation.var_names,
162199
**leiden_kwargs
163200
)
164201

202+
# Store clustering results in both obs and var (symmetric since genes x genes)
165203
full_correlation.obs['leiden'] = _corr_results.obs['leiden'].astype(int).values
166204
full_correlation.var['leiden'] = _corr_results.obs['leiden'].astype(int).values
167205
full_correlation.obsm[f'{layer_i}_umap'] = _corr_results.obsm['X_umap']
168206

207+
# Step 7: Propagate module assignments back to individual datasets
208+
# Each dataset gets the module IDs for its own genes
169209
for adata, layer_i in zip(
170210
adata_list,
171211
layer
172212
):
213+
# Find positions of this dataset's genes in the full correlation results
173214
_gene_idx = _corr_results.var_names.get_indexer(adata.var_names)
174215

175-
# Put the gene module memberships in
216+
# Assign module membership to each gene
176217
adata.var[output_key] = _corr_results.obs['leiden']
177218

178-
# Put the partial umap into the separate objects
179-
# so they can be plotted in the same space
219+
# Store the UMAP coordinates for this dataset's genes
220+
# This allows plotting genes from different datasets in the same UMAP space
180221
adata.varm[f'{layer_i}_umap'] = _corr_results.obsm['X_umap'][_gene_idx, :]
181222

182223
return full_correlation

scself/tests/test_weighted_modules.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from scipy import sparse
1111

1212
from scself._modules.find_weighted_modules import get_combined_correlation_modules
13+
from scself._modules.find_modules import get_correlation_modules
1314

1415

1516
@pytest.fixture
@@ -90,6 +91,21 @@ def misaligned_adata_list():
9091
return [adata1, adata2]
9192

9293

94+
def test_fixtures(aligned_adata_list):
95+
96+
result = get_correlation_modules(
97+
aligned_adata_list[0],
98+
n_neighbors=3
99+
)
100+
101+
assert np.array_equal(
102+
result.var['gene_module'].values[0:8],
103+
[0] * 4 + [1] * 4
104+
) or np.array_equal(
105+
result.var['gene_module'].values[0:8],
106+
[1] * 4 + [0] * 4
107+
)
108+
93109
def test_basic_functionality_aligned(aligned_adata_list):
94110
"""Test basic functionality with aligned datasets."""
95111
result = get_combined_correlation_modules(

0 commit comments

Comments
 (0)