@@ -21,47 +21,68 @@ def get_combined_correlation_modules(
2121 obs_mask = None
2222):
2323 """
24- Get correlation modules from a list of anndata objects
25- by calculating the correlation separately for each object
26- and averaging the correlation between genes where the
27- objects overlap.
24+ Find gene modules by combining correlation patterns across multiple datasets.
2825
29- Adds .varp['X_corrcoef'] with gene-gene correlation
30- and .var[output_key] with gene module ID
26+ This function performs a weighted averaging of gene-gene correlations across
27+ multiple AnnData objects, enabling module discovery that accounts for genes
28+ measured in different subsets of datasets. The workflow consists of:
3129
32- :param adata_list: List of data objects containing
33- expression data
30+ 1. Calculate gene-gene correlation within each dataset independently
31+ 2. Determine the union of genes across all datasets
32+ 3. Average correlations where genes overlap (weighted by dataset availability)
33+ 4. Cluster averaged correlations using Leiden community detection
34+ 5. Assign module IDs back to individual datasets
35+
36+ The key advantage is handling datasets with non-overlapping gene sets while
37+ properly weighting correlations based on the number of datasets in which both
38+ genes appear.
39+
40+ **Modifies input datasets in-place:**
41+ - Adds .varp['{layer}_corrcoef'] with gene-gene correlation matrix
42+ - Adds .var[output_key] with integer gene module assignments
43+ - Adds .varm['{layer}_umap'] with UMAP coordinates in shared space
44+
45+ :param adata_list: List of AnnData objects containing expression data.
46+ Datasets may have different genes but should have comparable expression scales.
3447 :type adata_list: list(ad.AnnData)
35- :param layer: Layer to calculate correlation and find
36- modules from, can provide a list with the same length
37- as adata_list , defaults to 'X'
38- :type layer: str, optional
39- :param n_neighbors: Number of neighbors in kNN, defaults
40- to 10
48+ :param layer: Layer to use for correlation calculation. Can be 'X' or any layer name.
49+ Accepts either a single string (applied to all datasets) or a list with one
50+ layer name per dataset , defaults to 'X'
51+ :type layer: str or list(str) , optional
52+ :param n_neighbors: Number of neighbors for kNN graph construction used in
53+ clustering and UMAP, defaults to 10
4154 :type n_neighbors: int, optional
42- :param leiden_kwargs: Keyword arguments to sc.tl.leiden
55+ :param leiden_kwargs: Additional keyword arguments passed to leiden clustering
56+ (e.g., resolution parameter), defaults to {}
4357 :type leiden_kwargs: dict, optional
44- :param output_key: Column to add to adata .var with module IDs,
58+ :param output_key: Column name to add to .var for storing module IDs,
4559 defaults to 'gene_module'
4660 :type output_key: str, optional
47- :param obs_mask: Boolean mask or slice for observations to consider
48- :type obs_mask: np.ndarray or slice, optional
49-
50- :return: A correlation (gene x gene) anndata object with:
51- Gene-gene correlation UMAP in 'X_umap' in .varm
52- Module membership IDs in .obs[output_key]
61+ :param obs_mask: Boolean mask or slice to subset observations (cells) before
62+ computing correlations. Pass a list with one mask (or None) per dataset;
63+ None (the default) uses all observations in every dataset.
64+ :type obs_mask: np.ndarray, slice, list, or None, optional
65+
66+ :return: An AnnData object representing the full gene-gene correlation matrix with:
67+ - .layers['{layer}_corrcoef']: Averaged correlation matrix (n_genes x n_genes)
68+ - .obs['leiden'] and .var['leiden']: Module assignments (symmetric)
69+ - .obsm['{layer}_umap']: UMAP embedding of gene correlations
5370 :rtype: ad.AnnData
5471 """
5572
5673 _n_datasets = len (adata_list )
5774
75+ # Initialize obs_mask to None for all datasets if not provided
5876 if obs_mask is None :
5977 obs_mask = [None ] * _n_datasets
6078
61- # Make sure all of the arguments are iterable lists of
62- # correct length and raise an AttributeError if not
79+ # Helper function to broadcast scalar arguments to lists
80+ # matching the number of datasets. This allows users to pass
81+ # a single layer name that applies to all datasets, or a list
82+ # with one layer name per dataset.
6383 def _to_iterable (arg , argname ):
6484 if not isinstance (arg , _ITER_TYPES ):
85+ # Broadcast scalar to list
6586 return [arg ] * _n_datasets
6687 elif len (arg ) != _n_datasets :
6788 raise AttributeError (
@@ -72,34 +93,40 @@ def _to_iterable(arg, argname):
7293
7394 layer = _to_iterable (layer , 'layer' )
7495
75- # Calculate correlation for each dataset
96+ # Step 1: Calculate gene-gene correlation for each dataset independently
97+ # Store results in each dataset's .varp to enable caching across calls
7698 for adata , layer_i , mask_i in zip (
7799 adata_list ,
78100 layer ,
79101 obs_mask
80102 ):
81-
103+ # Skip if correlation already computed for this layer
82104 if f'{ layer_i } _corrcoef' not in adata .varp .keys ():
105+ # Get reference to expression data (X or specified layer)
83106 _lref = adata .X if layer_i == 'X' else adata .layers [layer_i ]
84107
108+ # Apply observation mask if provided (subset to specific cells)
109+ # Then compute gene-gene correlation matrix
85110 adata .varp [f'{ layer_i } _corrcoef' ] = corrcoef (
86111 _lref [mask_i , :] if mask_i is not None else _lref
87112 )
88113
89114 del _lref
90115
91- # Check to see if all the data is already aligned
92- # If not, find the union of all the var_names
116+ # Step 2: Determine the gene universe across all datasets
117+ # Check if all datasets share identical gene names (same genes, same order)
93118 if all (
94119 all (
95120 adata .var_names .equals (a .var_names )
96121 for a in adata_list
97122 )
98123 for adata in adata_list
99124 ):
125+ # All datasets aligned - no reindexing needed
100126 _genes = adata_list [0 ].var_names .copy ()
101127 _do_reindex = False
102128 else :
129+ # Datasets have different genes - compute union of all gene names
103130 _genes = reduce (
104131 lambda x , y : x .var_names .union (y .var_names ),
105132 adata_list
@@ -108,75 +135,89 @@ def _to_iterable(arg, argname):
108135
109136 _n_genes = len (_genes )
110137
111- # Get the number of times each gene appears
138+ # Count how many datasets contain each gene
139+ # This is used later for weighted averaging of correlations
112140 _gene_counts = reduce (
113141 lambda x , y : x + y ,
114142 [_genes .isin (c .var_names ).astype (int ) for c in adata_list ]
115143 )
116144
117- # Create a zeroed correlation matrix for the
118- # gene union
145+ # Step 3: Initialize combined correlation matrix
146+ # Create an AnnData object to hold the gene x gene correlation
147+ # Both obs and var represent genes (symmetric gene-gene matrix)
119148 full_correlation = ad .AnnData (
120149 csr_matrix ((_n_genes , _n_genes )),
121150 var = pd .DataFrame (index = _genes ),
122151 obs = pd .DataFrame (index = _genes )
123152 )
124153
154+ # Store the summed correlations in a layer (will average later)
125155 _corr_layer = f'{ layer [0 ]} _corrcoef'
126156 full_correlation .layers [_corr_layer ] = np .zeros (
127157 (_n_genes , _n_genes ),
128158 dtype = float
129159 )
130160
131- # Iterate through all the anndata object and add
132- # each correlation into the full_correlation
133- # indexed appropriately for the feaure name
161+ # Step 4: Accumulate correlation matrices from each dataset
162+ # Sum correlations into the appropriate positions based on gene names
134163 for adata , layer_i in zip (
135164 adata_list ,
136165 layer
137166 ):
138-
167+
139168 if _do_reindex :
169+ # Map each dataset's genes to their positions in the full gene universe
170+ # np.ix_ creates a mesh for 2D indexing (row indices x column indices)
140171 full_correlation .layers [_corr_layer ][
141172 np .ix_ (
142173 full_correlation .obs_names .get_indexer (adata .var_names ),
143174 full_correlation .var_names .get_indexer (adata .var_names )
144175 )
145176 ] += adata .varp [f'{ layer_i } _corrcoef' ]
146177 else :
178+ # All datasets aligned - direct addition
147179 full_correlation .layers [_corr_layer ] += adata .varp [f'{ layer_i } _corrcoef' ]
148-
149- # Correct for the number of times the gene-gene correlation
150- # was calculated
180+
181+ # Step 5: Compute weighted average of correlations
182+ # Divide each summed entry by the smaller of the two genes' dataset counts
183+ # Example: if gene A appears in 3 datasets and gene B in 2 datasets,
184+ # their correlation was summed over at most min(3,2)=2 datasets
151185 full_correlation .layers [_corr_layer ] /= np .minimum (
152- _gene_counts [:, None ],
153- _gene_counts [None , :]
186+ _gene_counts [:, None ], # Row-wise gene counts
187+ _gene_counts [None , :] # Column-wise gene counts
154188 )
155189
190+ # Sanity check: averaged correlations must not exceed 1.0 (upper bound only)
156191 assert full_correlation .layers [_corr_layer ].max () <= 1.0
157-
192+
193+ # Step 6: Cluster genes based on correlation patterns
194+ # Build kNN graph, run Leiden clustering, and compute UMAP embedding
158195 _corr_results = correlation_clustering_and_umap (
159196 full_correlation .layers [_corr_layer ],
160197 n_neighbors = n_neighbors ,
161198 var_names = full_correlation .var_names ,
162199 ** leiden_kwargs
163200 )
164201
202+ # Store clustering results in both obs and var (symmetric since genes x genes)
165203 full_correlation .obs ['leiden' ] = _corr_results .obs ['leiden' ].astype (int ).values
166204 full_correlation .var ['leiden' ] = _corr_results .obs ['leiden' ].astype (int ).values
167205 full_correlation .obsm [f'{ layer_i } _umap' ] = _corr_results .obsm ['X_umap' ]
168206
207+ # Step 7: Propagate module assignments back to individual datasets
208+ # Each dataset gets the module IDs for its own genes
169209 for adata , layer_i in zip (
170210 adata_list ,
171211 layer
172212 ):
213+ # Find positions of this dataset's genes in the full correlation results
173214 _gene_idx = _corr_results .var_names .get_indexer (adata .var_names )
174215
175- # Put the gene module memberships in
216+ # Assign module membership to each gene
176217 adata .var [output_key ] = _corr_results .obs ['leiden' ]
177218
178- # Put the partial umap into the separate objects
179- # so they can be plotted in the same space
219+ # Store the UMAP coordinates for this dataset's genes
220+ # This allows plotting genes from different datasets in the same UMAP space
180221 adata .varm [f'{ layer_i } _umap' ] = _corr_results .obsm ['X_umap' ][_gene_idx , :]
181222
182223 return full_correlation
0 commit comments