|
| 1 | +import pandas as pd |
| 2 | +import numpy as np |
| 3 | +import anndata as ad |
| 4 | + |
| 5 | +from functools import reduce |
| 6 | + |
| 7 | +from scself.utils.correlation import ( |
| 8 | + corrcoef, |
| 9 | + correlation_clustering_and_umap |
| 10 | +) |
| 11 | + |
# Container types treated as "one value per dataset" when passed as an
# argument; any other type is handled as a single value shared by all
# datasets.
_ITER_TYPES = (
    tuple,
    list,
    pd.Series,
    np.ndarray,
)
| 13 | + |
def get_combined_correlation_modules(
    adata_list,
    layer='X',
    n_neighbors=10,
    leiden_kwargs=None,
    output_key='gene_module',
    obs_mask=None
):
    """
    Get correlation modules from a list of anndata objects
    by calculating the correlation separately for each object
    and averaging the correlation between genes where the
    objects overlap.

    Each object in adata_list is modified in place:
        .varp[f'{layer}_corrcoef'] gene-gene correlation (computed only
            if that key is not already present; an existing entry is
            reused as-is, regardless of obs_mask)
        .var[output_key] gene module ID
        .varm[f'{layer}_umap'] rows of the combined UMAP for the genes
            of that object

    :param adata_list: List of data objects containing
        expression data
    :type adata_list: list(ad.AnnData)
    :param layer: Layer to calculate correlation and find
        modules from, can provide a list with the same length
        as adata_list, defaults to 'X'
    :type layer: str, optional
    :param n_neighbors: Number of neighbors in kNN, defaults
        to 10
    :type n_neighbors: int, optional
    :param leiden_kwargs: Extra keyword arguments forwarded to
        correlation_clustering_and_umap (described upstream as
        arguments for sc.tl.leiden), defaults to None (no extras)
    :type leiden_kwargs: dict, optional
    :param output_key: Column to add to adata.var with module IDs,
        defaults to 'gene_module'
    :type output_key: str, optional
    :param obs_mask: Boolean mask or slice for observations to consider.
        A single slice or 1-D mask is applied to every object; a list or
        tuple (or a 2-D / object array) supplies one mask per object and
        must have the same length as adata_list. None uses all
        observations.
    :type obs_mask: np.ndarray or slice, optional

    :raises ValueError: If adata_list is empty
    :raises AttributeError: If a per-object argument does not have
        the same length as adata_list

    :return: The object returned by correlation_clustering_and_umap for
        the combined gene x gene correlation, with:
        Gene-gene correlation UMAP in 'X_umap' in .obsm
        Module membership IDs (cast to int) in .obs['leiden']
    :rtype: ad.AnnData
    """

    _n_datasets = len(adata_list)

    if _n_datasets == 0:
        raise ValueError("adata_list must contain at least one object")

    if leiden_kwargs is None:
        leiden_kwargs = {}

    def _check_length(arg, argname):
        if len(arg) != _n_datasets:
            raise AttributeError(
                f"len({argname}) = {len(arg)}; {_n_datasets} is required"
            )
        return arg

    # Make sure all of the arguments are iterable lists of
    # correct length and raise an AttributeError if not
    def _to_iterable(arg, argname):
        if not isinstance(arg, _ITER_TYPES):
            return [arg] * _n_datasets
        return _check_length(arg, argname)

    # A 1-D array is itself one mask, so obs_mask cannot go through
    # _to_iterable; only explicit sequences (or stacked / object arrays)
    # are interpreted as one mask per dataset
    def _masks_to_iterable(mask):
        if mask is None:
            return [None] * _n_datasets

        _per_dataset = isinstance(mask, (tuple, list)) or (
            isinstance(mask, np.ndarray)
            and (mask.ndim > 1 or mask.dtype == object)
        )

        if not _per_dataset:
            return [mask] * _n_datasets

        return list(_check_length(mask, 'obs_mask'))

    layer = _to_iterable(layer, 'layer')
    obs_mask = _masks_to_iterable(obs_mask)

    # Calculate correlation for each dataset
    for adata, layer_i, mask_i in zip(
        adata_list,
        layer,
        obs_mask
    ):

        if f'{layer_i}_corrcoef' not in adata.varp.keys():
            _lref = adata.X if layer_i == 'X' else adata.layers[layer_i]

            adata.varp[f'{layer_i}_corrcoef'] = corrcoef(
                _lref[mask_i, :] if mask_i is not None else _lref
            )

            del _lref

    # Check to see if all the data is already aligned
    # (equality is transitive, so comparing to the first is enough)
    # If not, find the union of all the var_names
    _first_names = adata_list[0].var_names
    _do_reindex = not all(
        _first_names.equals(a.var_names)
        for a in adata_list[1:]
    )

    if _do_reindex:
        # Reduce over the indexes themselves so the accumulator
        # stays a pd.Index for any number of datasets
        _genes = reduce(
            lambda x, y: x.union(y),
            (a.var_names for a in adata_list)
        )
    else:
        _genes = _first_names.copy()

    _n_genes = len(_genes)

    # Create a zeroed correlation matrix for the
    # gene union
    full_correlation = ad.AnnData(
        np.zeros(
            (_n_genes, _n_genes),
            dtype=float
        ),
        var=pd.DataFrame(index=_genes),
        obs=pd.DataFrame(index=_genes)
    )

    # Number of datasets in which each gene pair was observed
    # together; only needed when the datasets are not aligned
    _pair_counts = (
        np.zeros((_n_genes, _n_genes), dtype=np.int32)
        if _do_reindex else None
    )

    # Iterate through all the anndata objects and add
    # each correlation into the full_correlation
    # indexed appropriately for the feature name
    for adata, layer_i in zip(
        adata_list,
        layer
    ):

        if _do_reindex:
            _idx = full_correlation.var_names.get_indexer(adata.var_names)
            _block = np.ix_(_idx, _idx)

            full_correlation.X[_block] += adata.varp[f'{layer_i}_corrcoef']
            _pair_counts[_block] += 1
        else:
            full_correlation.X += adata.varp[f'{layer_i}_corrcoef']

    # Correct for the number of times the gene-gene correlation
    # was calculated
    if _do_reindex:
        # Pairs never observed together have a zero sum; clamping
        # the count to 1 leaves them at zero without dividing by zero
        full_correlation.X /= np.maximum(_pair_counts, 1)
    else:
        full_correlation.X /= _n_datasets

    del _pair_counts

    full_correlation = correlation_clustering_and_umap(
        full_correlation.X,
        n_neighbors=n_neighbors,
        var_names=full_correlation.var_names,
        **leiden_kwargs
    )

    full_correlation.obs['leiden'] = full_correlation.obs['leiden'].astype(int)

    for adata, layer_i in zip(
        adata_list,
        layer
    ):
        _gene_idx = full_correlation.var_names.get_indexer(adata.var_names)

        # Put the gene module memberships in
        # NOTE(review): pandas aligns this assignment on index labels,
        # so it presumably relies on .obs being indexed by gene name
        # - confirm against correlation_clustering_and_umap
        adata.var[output_key] = full_correlation.obs['leiden']

        # Put the partial umap into the separate objects
        # so they can be plotted in the same space
        adata.varm[f'{layer_i}_umap'] = full_correlation.obsm['X_umap'][_gene_idx, :]

    return full_correlation
0 commit comments