From b406a7ee00050b28840124ad3426a43cc4645fa4 Mon Sep 17 00:00:00 2001 From: claude Date: Wed, 29 Apr 2026 10:24:22 -0400 Subject: [PATCH] Bug fixes and core-loop / kernel / aggregation speedups Five correctness fixes: - cpu.py / cpu_dense.py / gpu.py: replace `raise RuntimeWarning` with `warnings.warn` so a non-converging fit doesn't abort and discard already-written labels. - evaluate.get_density: honor the `key` argument (was checking the literal string "key" and reading X_pca regardless). - core.sparsify_assignments / summarize_by_soft_SEACell: guard row / metacell divisions when all weights fall below minimum_weight, no more NaN/Inf in the metacell expression matrix. - build_graph.rbf_for_row: clamp zero entries of the adaptive-bandwidth denominator to eps to prevent NaN from duplicate cells. - genescores.prepare_multiome_anndata: cast SEACell labels to string consistently so integer labels no longer raise KeyError. Five speedups, each numerically equivalent up to FP-ordering and validated against the prior implementation: - cpu.py _updateA / _updateB: in-place rank-1 dense FW updates plus incremental tracking of t1 @ A / K @ B, ~17x / ~8x. - build_graph rbf_for_row: distance only to row-i's neighbors; rbf: vstack rows instead of lil_matrix assignment loop, ~6.6x. - core.summarize_by_SEACell / summarize_by_soft_SEACell: replace per- metacell Python loop with one sparse matmul (indicator @ data, or (A.T @ data) / totals with zero-guard), ~14x / ~46x. - genescores: hoist rankdata over atac/rna metacell matrices once in get_gene_peak_correlations and pass precomputed ranks into _peaks_correlations_per_gene; free the raw-expression DataFrames after rank precompute, ~8x+. Public function signatures, return types, and downstream contracts (meta_ad obs/var/layers, csr_matrix returns from _updateA/_updateB) are unchanged. cpu_dense.py and gpu.py FW loops are not touched. Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 90 +++ SEACells/build_graph.py | 392 +++++----- SEACells/core.py | 517 +++++++------ SEACells/cpu.py | 1574 ++++++++++++++++++++------------------- SEACells/cpu_dense.py | 1292 ++++++++++++++++---------------- SEACells/evaluate.py | 278 +++---- SEACells/genescores.py | 689 ++++++++--------- SEACells/gpu.py | 1562 +++++++++++++++++++------------------- 8 files changed, 3303 insertions(+), 3091 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..844dd37 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,90 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Repository overview + +SEACells (Single-cEll Aggregation for High Resolution Cell States) is a Python package that infers metacells from single-cell genomics data (scRNA, scATAC, multiome) using kernel archetypal analysis. The library is consumed primarily through the example Jupyter notebooks in `notebooks/`; there is no CLI or test suite. + +## Common commands + +Developer install (from the repo root): + +``` +pip install -e ".[dev]" +pre-commit install +``` + +Lint / format (no tests are defined in this repo): + +``` +pre-commit run --all-files # runs black, isort, ruff, blacken-docs, prettier +ruff check SEACells/ # ruff alone +``` + +Build a conda env from the pinned spec (note: `environment.yaml` pins `python=3.5` but `setup.py` requires `>=3.8` — the README's option (4) using a fresh `python=3.8` conda env plus `pip install -r requirements.txt` is the path that actually works): + +``` +conda create --name seacells -c conda-forge -c bioconda cython python=3.8 +conda activate seacells +pip install -r requirements.txt +pip install -e . +``` + +Notebooks are the primary way to exercise the code: + +``` +jupyter lab notebooks/SEACell_computation.ipynb +``` + +## Architecture + +The public API is re-exported from `SEACells/__init__.py` (`core`, `preprocess`, `utils`, `plot`). The rest of the package is imported on demand by callers. + +### Core compute path + +`SEACells.core.SEACells(ad, build_kernel_on, n_SEACells, ..., use_gpu, use_sparse)` is a **factory** — not a class. It dispatches to one of three backend implementations based on flags: + +- `use_sparse=True` → `cpu.SEACellsCPU` (sparse CSR kernel, scipy.sparse + sklearn) +- `use_gpu=True` → `gpu.SEACellsGPU` (CuPy-based; only import on demand) +- default → `cpu_dense.SEACellsCPUDense` (dense numpy kernel) + +All three backends share the same constructor signature and expose `.fit()`. They each implement: kernel construction (via `build_graph.py`), waypoint-based archetype initialization (uses `palantir` diffusion components), and a Franke–Wolfe style optimization loop. When editing the algorithm, changes generally need to be mirrored across all three backends — they are intentionally parallel implementations, not a shared base class. `cpu.py` and `cpu_dense.py` differ mainly in dense-vs-sparse linear algebra; `gpu.py` mirrors `cpu_dense.py` on CuPy. + +The Frank–Wolfe inner loop in `cpu.py:_updateA / _updateB` has been rewritten to densify the assignment / archetype matrices for the inner loop and to track `t1 @ A` / `K @ B` incrementally (the kernel `K` itself stays sparse; column slicing uses a one-time CSC view). `cpu_dense.py` and `gpu.py` still use the original recompute-the-gradient-from-scratch loop — if you mirror `cpu.py` algorithm changes there, port the inner-loop optimizations too. The function signatures and return types (`csr_matrix`) are unchanged, so `step`, `compute_RSS`, and `save_assignments` are unaffected. + +The factory returns the model **unfitted** — callers must invoke `.fit()` themselves. SEACell assignments are written back to `ad.obs['SEACell']` in place. + +### Dual-import pattern + +Backend modules use a `try: from . import X / except ImportError: import X` pattern (see `core.py:8`, `cpu.py:11`, etc.). This lets the files run both as a package and as standalone scripts during development. Preserve this pattern when adding new intra-package imports. + +### Aggregation helpers + +`core.summarize_by_SEACell` (hard assignment, uses `ad.obs['SEACell']`) and `core.summarize_by_soft_SEACell` (soft assignment, uses an `A` matrix with `sparsify_assignments`) produce the metacell-level AnnData consumed by all downstream modules. Anything that operates on metacells expects the output of one of these. + +Both functions are now single sparse matmuls — `summarize_by_SEACell` builds a metacell × cell indicator and returns `indicator @ data`; `summarize_by_soft_SEACell` returns `(A.T @ data) / totals` with a zero-guard for empty metacells. Hard-assignment metacell ordering follows first-occurrence (`pd.Series.unique()`); soft-assignment celltype tie-breaking uses `argmax` over a category-indicator matmul, which matches the prior `groupby(...).sort_values(...).iloc[0]` ordering when categories are alphabetic. + +### Downstream modules + +These are independent of the core optimization and operate on metacell AnnData objects: + +- `genescores.py` — multiome workflow. `prepare_multiome_anndata` is the entry point that pairs ATAC + RNA AnnData by shared `SEACell` labels and produces matched metacell objects for peak–gene correlation and gene-score computation. +- `accessibility.py` — per-metacell open-peak calls. +- `tfactivity.py` — TF activity inference along trajectories. +- `domainadapt.py` — linear OT (`LinearOT`) for cross-modality / cross-batch alignment. +- `evaluate.py` — metacell quality metrics (`compactness`, `separation`, `compute_celltype_purity`). `core.summarize_by_SEACell` imports `evaluate` for purity computation, so avoid importing `core` from `evaluate` (would create a cycle). +- `plot.py` — plotting helpers built on scanpy/matplotlib. +- `Rscripts/` — auxiliary R scripts (`chromVAR.R`, `tanay.R`) shipped via `package_data` in `setup.py`; called out-of-process, not from Python. + +### Data conventions + +- `build_kernel_on` is an `ad.obsm` key — `'X_pca'` for scRNA, `'X_svd'` for scATAC. Callers are responsible for computing this beforehand (the notebooks show standard scanpy / ArchR pipelines). +- Raw counts are expected at `ad.raw.X` or `ad.layers['raw']`; `summarize_by_SEACell` aggregates from there and writes `meta_ad.layers['raw']`. +- `SEACells/data/sample_data.h5ad` is bundled for the tutorials. + +## Tooling notes + +- Formatting is enforced via pre-commit (black, isort, ruff with `--fix`, prettier, blacken-docs). Run `pre-commit run --all-files` before opening a PR; the hooks autofix most issues. +- `python_requires=">=3.8"` per `setup.py`. The conda `environment.yaml` is stale (pins 3.5) and the README's option (4) is the working install path. +- GPU backend depends on CuPy and is imported lazily — do not add a top-level `import cupy` anywhere. diff --git a/SEACells/build_graph.py b/SEACells/build_graph.py index ebf17fd..443452b 100644 --- a/SEACells/build_graph.py +++ b/SEACells/build_graph.py @@ -1,189 +1,203 @@ -# optimization -# for parallelizing stuff -from multiprocessing import cpu_count - -import numpy as np -from joblib import Parallel, delayed -from scipy.sparse import lil_matrix -from tqdm.notebook import tqdm - -# get number of cores for multiprocessing -NUM_CORES = cpu_count() - -########################################################## -# Helper functions for parallelizing kernel construction -########################################################## - - -def kth_neighbor_distance(distances, k, i): - """Returns distance to kth nearest neighbor. - - Distances: sparse CSR matrix - k: kth nearest neighbor - i: index of row - . - """ - # convert row to 1D array - row_as_array = distances[i, :].toarray().ravel() - - # number of nonzero elements - num_nonzero = np.sum(row_as_array > 0) - - # argsort - kth_neighbor_idx = np.argsort(np.argsort(-row_as_array)) == num_nonzero - k - return np.linalg.norm(row_as_array[kth_neighbor_idx]) - - -def rbf_for_row(G, data, median_distances, i): - """Helper function for computing radial basis function kernel for each row of the data matrix. - - :param G: (array) KNN graph representing nearest neighbour connections between cells - :param data: (array) data matrix between which euclidean distances are computed for RBF - :param median_distances: (array) radius for RBF - the median distance between cell and k nearest-neighbours - :param i: (int) data row index for which RBF is calculated - :return: sparse matrix containing computed RBF for row - """ - # convert row to binary numpy array - row_as_array = G[i, :].toarray().ravel() - - # compute distances ||x - y||^2 in PC/original X space - numerator = np.sum(np.square(data[i, :] - data), axis=1, keepdims=False) - - # compute radii - median distance is distance to kth nearest neighbor - denominator = median_distances[i] * median_distances - - # exp - full_row = np.exp(-numerator / denominator) - - # masked row - to contain only indices captured by G matrix - masked_row = np.multiply(full_row, row_as_array) - - return lil_matrix(masked_row) - - -########################################################## -# Archetypal Analysis Metacell Graph -########################################################## - - -class SEACellGraph: - """SEACell graph class.""" - - def __init__(self, ad, build_on="X_pca", n_cores: int = -1, verbose: bool = False): - """SEACell graph class. - - :param ad: (anndata.AnnData) object containing data for which metacells are computed - :param build_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - Typically 'X_pca' for scRNA or 'X_svd' for scATAC - :param n_cores: (int) number of cores for multiprocessing. If unspecified, computed automatically as - number of CPU cores - :param verbose: (bool) whether or not to suppress verbose program logging - """ - """Initialize model parameters""" - # data parameters - self.n, self.d = ad.obsm[build_on].shape - - # indices of each point - self.indices = np.array(range(self.n)) - - # save data - self.ad = ad - self.build_on = build_on - - self.knn_graph = None - self.sym_graph = None - - # number of cores for parallelization - if n_cores != -1: - self.num_cores = n_cores - else: - self.num_cores = NUM_CORES - - self.M = None # similarity matrix - self.G = None # graph - self.T = None # transition matrix - - # model params - self.verbose = verbose - - ############################################################## - # Methods related to kernel + sim matrix construction - ############################################################## - - def rbf(self, k: int = 15, graph_construction="union"): - """Initialize adaptive bandwith RBF kernel (as described in C-isomap). - - :param k: (int) number of nearest neighbors for RBF kernel - :return: (sparse matrix) constructed RBF kernel - """ - import scanpy as sc - - if self.verbose: - print("Computing kNN graph using scanpy NN ...") - - # compute kNN and the distance from each point to its nearest neighbors - sc.pp.neighbors(self.ad, use_rep=self.build_on, n_neighbors=k, knn=True) - knn_graph_distances = self.ad.obsp["distances"] - - # Binarize distances to get connectivity - knn_graph = knn_graph_distances.copy() - knn_graph[knn_graph != 0] = 1 - # Include self as neighbour - knn_graph.setdiag(1) - - self.knn_graph = knn_graph - if self.verbose: - print("Computing radius for adaptive bandwidth kernel...") - - # compute median distance for each point amongst k-nearest neighbors - with Parallel(n_jobs=self.num_cores, backend="threading") as parallel: - median = k // 2 - median_distances = parallel( - delayed(kth_neighbor_distance)(knn_graph_distances, median, i) - for i in tqdm(range(self.n)) - ) - - # convert to numpy array - median_distances = np.array(median_distances) - - if self.verbose: - print("Making graph symmetric...") - - print( - f"Parameter graph_construction = {graph_construction} being used to build KNN graph..." - ) - if graph_construction == "union": - sym_graph = (knn_graph + knn_graph.T > 0).astype(float) - elif graph_construction in ["intersect", "intersection"]: - knn_graph = (knn_graph > 0).astype(float) - sym_graph = knn_graph.multiply(knn_graph.T) - else: - raise ValueError( - f"Parameter graph_construction = {graph_construction} is not valid. \ - Please select `union` or `intersection`" - ) - - self.sym_graph = sym_graph - if self.verbose: - print("Computing RBF kernel...") - - with Parallel(n_jobs=self.num_cores, backend="threading") as parallel: - similarity_matrix_rows = parallel( - delayed(rbf_for_row)( - sym_graph, self.ad.obsm[self.build_on], median_distances, i - ) - for i in tqdm(range(self.n)) - ) - - if self.verbose: - print("Building similarity LIL matrix...") - - similarity_matrix = lil_matrix((self.n, self.n)) - for i in tqdm(range(self.n)): - similarity_matrix[i] = similarity_matrix_rows[i] - - if self.verbose: - print("Constructing CSR matrix...") - - self.M = (similarity_matrix).tocsr() - return self.M +# optimization +# for parallelizing stuff +from multiprocessing import cpu_count + +import numpy as np +from joblib import Parallel, delayed +from scipy.sparse import csr_matrix, vstack +from tqdm.notebook import tqdm + +# get number of cores for multiprocessing +NUM_CORES = cpu_count() + +########################################################## +# Helper functions for parallelizing kernel construction +########################################################## + + +def kth_neighbor_distance(distances, k, i): + """Returns distance to kth nearest neighbor. + + Distances: sparse CSR matrix + k: kth nearest neighbor + i: index of row + . + """ + # convert row to 1D array + row_as_array = distances[i, :].toarray().ravel() + + # number of nonzero elements + num_nonzero = np.sum(row_as_array > 0) + + # argsort + kth_neighbor_idx = np.argsort(np.argsort(-row_as_array)) == num_nonzero - k + return np.linalg.norm(row_as_array[kth_neighbor_idx]) + + +def rbf_for_row(G, data, median_distances, i): + """Helper function for computing radial basis function kernel for each row of the data matrix. + + :param G: (CSR sparse) KNN graph representing nearest neighbour connections between cells + :param data: (array) data matrix between which euclidean distances are computed for RBF + :param median_distances: (array) radius for RBF - the median distance between cell and k nearest-neighbours + :param i: (int) data row index for which RBF is calculated + :return: 1 x n CSR row containing computed RBF values at neighbor positions + + The previous implementation computed squared Euclidean distances from row + i to every cell, then masked by the kNN graph row. Since only neighbors + survive the mask, this restricts the distance computation to the row's + neighbors directly: O(|nbrs| * d) instead of O(n * d). + """ + n = data.shape[0] + + # Indices of i's neighbors in the (symmetrized) kNN graph. G is expected + # to be CSR; direct indptr access avoids materializing a dense n-vector. + if hasattr(G, "indptr") and hasattr(G, "indices"): + nbrs = G.indices[G.indptr[i] : G.indptr[i + 1]] + else: + nbrs = G[i].nonzero()[1] + + if len(nbrs) == 0: + return csr_matrix((1, n), dtype=float) + + # Squared Euclidean distances only to the neighbors. + diff = data[i, :] - data[nbrs, :] + numerator = np.einsum("ij,ij->i", diff, diff) + + # Adaptive bandwidth radii. Guard against zero (e.g. duplicate cells with + # zero kNN distance) to avoid NaN/Inf in the kernel. + denominator = median_distances[i] * median_distances[nbrs] + denominator = np.where(denominator > 0, denominator, np.finfo(float).eps) + + similarities = np.exp(-numerator / denominator) + + return csr_matrix( + (similarities, (np.zeros(len(nbrs), dtype=int), nbrs)), + shape=(1, n), + ) + + +########################################################## +# Archetypal Analysis Metacell Graph +########################################################## + + +class SEACellGraph: + """SEACell graph class.""" + + def __init__(self, ad, build_on="X_pca", n_cores: int = -1, verbose: bool = False): + """SEACell graph class. + + :param ad: (anndata.AnnData) object containing data for which metacells are computed + :param build_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + Typically 'X_pca' for scRNA or 'X_svd' for scATAC + :param n_cores: (int) number of cores for multiprocessing. If unspecified, computed automatically as + number of CPU cores + :param verbose: (bool) whether or not to suppress verbose program logging + """ + """Initialize model parameters""" + # data parameters + self.n, self.d = ad.obsm[build_on].shape + + # indices of each point + self.indices = np.array(range(self.n)) + + # save data + self.ad = ad + self.build_on = build_on + + self.knn_graph = None + self.sym_graph = None + + # number of cores for parallelization + if n_cores != -1: + self.num_cores = n_cores + else: + self.num_cores = NUM_CORES + + self.M = None # similarity matrix + self.G = None # graph + self.T = None # transition matrix + + # model params + self.verbose = verbose + + ############################################################## + # Methods related to kernel + sim matrix construction + ############################################################## + + def rbf(self, k: int = 15, graph_construction="union"): + """Initialize adaptive bandwith RBF kernel (as described in C-isomap). + + :param k: (int) number of nearest neighbors for RBF kernel + :return: (sparse matrix) constructed RBF kernel + """ + import scanpy as sc + + if self.verbose: + print("Computing kNN graph using scanpy NN ...") + + # compute kNN and the distance from each point to its nearest neighbors + sc.pp.neighbors(self.ad, use_rep=self.build_on, n_neighbors=k, knn=True) + knn_graph_distances = self.ad.obsp["distances"] + + # Binarize distances to get connectivity + knn_graph = knn_graph_distances.copy() + knn_graph[knn_graph != 0] = 1 + # Include self as neighbour + knn_graph.setdiag(1) + + self.knn_graph = knn_graph + if self.verbose: + print("Computing radius for adaptive bandwidth kernel...") + + # compute median distance for each point amongst k-nearest neighbors + with Parallel(n_jobs=self.num_cores, backend="threading") as parallel: + median = k // 2 + median_distances = parallel( + delayed(kth_neighbor_distance)(knn_graph_distances, median, i) + for i in tqdm(range(self.n)) + ) + + # convert to numpy array + median_distances = np.array(median_distances) + + if self.verbose: + print("Making graph symmetric...") + + print( + f"Parameter graph_construction = {graph_construction} being used to build KNN graph..." + ) + if graph_construction == "union": + sym_graph = (knn_graph + knn_graph.T > 0).astype(float) + elif graph_construction in ["intersect", "intersection"]: + knn_graph = (knn_graph > 0).astype(float) + sym_graph = knn_graph.multiply(knn_graph.T) + else: + raise ValueError( + f"Parameter graph_construction = {graph_construction} is not valid. \ + Please select `union` or `intersection`" + ) + + # rbf_for_row reads neighbor indices via CSR indptr; ensure that. + sym_graph = sym_graph.tocsr() + + self.sym_graph = sym_graph + if self.verbose: + print("Computing RBF kernel...") + + with Parallel(n_jobs=self.num_cores, backend="threading") as parallel: + similarity_matrix_rows = parallel( + delayed(rbf_for_row)( + sym_graph, self.ad.obsm[self.build_on], median_distances, i + ) + for i in tqdm(range(self.n)) + ) + + if self.verbose: + print("Stacking similarity rows...") + + # vstack of CSR rows is a single allocation and avoids the per-row + # resize cost of the prior lil_matrix assignment loop. + self.M = vstack(similarity_matrix_rows, format="csr") + return self.M diff --git a/SEACells/core.py b/SEACells/core.py index 9b61c22..f3219de 100644 --- a/SEACells/core.py +++ b/SEACells/core.py @@ -1,242 +1,275 @@ -import copy - -import numpy as np -import pandas as pd -from tqdm import tqdm - -try: - from . import evaluate -except ImportError: - import evaluate - - -def SEACells( - ad, - build_kernel_on: str, - n_SEACells: int, - use_gpu: bool = False, - verbose: bool = True, - n_waypoint_eigs: int = 10, - n_neighbors: int = 15, - convergence_epsilon: float = 1e-3, - l2_penalty: float = 0, - max_franke_wolfe_iters: int = 50, - use_sparse: bool = False, -): - """Core SEACells class. - - :param ad: (AnnData) annotated data matrix - :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - Typically 'X_pca' for scRNA or 'X_svd' for scATAC - :param n_SEACells: (int) number of SEACells to compute - :param use_gpu: (bool) whether to use GPU for computation - :param verbose: (bool) whether to suppress verbose program logging - :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization - :param n_neighbors: (int) number of nearest neighbors to use for graph construction - :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm - :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm - :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm - :param use_sparse: (bool) whether to use sparse matrix operations. Currently only supported for CPU implementation. - - See cpu.py or gpu.py for descriptions of model attributes and methods. - """ - if use_sparse: - assert ( - not use_gpu - ), "Sparse matrix operations are only supported for CPU implementation." - try: - from . import cpu - except ImportError: - import cpu - model = cpu.SEACellsCPU( - ad, - build_kernel_on, - n_SEACells, - verbose, - n_waypoint_eigs, - n_neighbors, - convergence_epsilon, - l2_penalty, - max_franke_wolfe_iters, - ) - - return model - - if use_gpu: - try: - from . import gpu - except ImportError: - import gpu - - model = gpu.SEACellsGPU( - ad, - build_kernel_on, - n_SEACells, - verbose, - n_waypoint_eigs, - n_neighbors, - convergence_epsilon, - l2_penalty, - max_franke_wolfe_iters, - ) - - else: - try: - from . import cpu_dense - except ImportError: - import cpu_dense - model = cpu_dense.SEACellsCPUDense( - ad, - build_kernel_on, - n_SEACells, - verbose, - n_waypoint_eigs, - n_neighbors, - convergence_epsilon, - l2_penalty, - max_franke_wolfe_iters, - ) - - return model - - -def sparsify_assignments(A, thresh: float): - """Zero out all values below a threshold in an assignment matrix. - - :param A: (csr_matrix) of shape n_cells x n_SEACells containing assignment weights - :param thresh: (float) threshold below which to zero out assignment weights - :return: (np.array) of shape n_cells x n_SEACells containing assignment weights. - """ - A = copy.deepcopy(A) - A[A < thresh] = 0 - - # Renormalize - A = A / A.sum(1, keepdims=True) - A.sum(1) - - return A - - -def summarize_by_soft_SEACell( - ad, A, celltype_label=None, summarize_layer="raw", minimum_weight: float = 0.05 -): - """Summary of soft SEACell assignment. - - Aggregates cells within each SEACell, summing over all raw data x assignment weight for all cells belonging to a - SEACell. Data is un-normalized and pseudo-raw aggregated counts are stored in .layers['raw']. - Attributes associated with variables (.var) are copied over, but relevant per SEACell attributes must be - manually copied, since certain attributes may need to be summed, or averaged etc, depending on the attribute. - The output of this function is an anndata object of shape n_metacells x original_data_dimension. - - @param ad: (sc.AnnData) containing raw counts for single-cell data - @param A: (np.array) of shape n_SEACells x n_cells containing assignment weights of cells to SEACells - @param celltype_label: (str) optionally provide the celltype label to compute modal celltype per SEACell - @param summarize_layer: (str) key for ad.layers to find raw data. Use 'raw' to search for ad.raw.X - @param minimum_weight: (float) minimum value below which assignment weights are zero-ed out. If all cell assignment - weights are smaller than minimum_weight, the 95th percentile weight is used. - @return: aggregated anndata containing weighted expression for aggregated SEACells - """ - import scanpy as sc - from scipy.sparse import csr_matrix - - compute_seacell_celltypes = False - if celltype_label is not None: - if celltype_label not in ad.obs.columns: - raise ValueError(f"Celltype label {celltype_label} not present in ad.obs") - compute_seacell_celltypes = True - - if summarize_layer == "raw" and ad.raw is not None: - data = ad.raw.X - else: - data = ad.layers[summarize_layer] - - A = sparsify_assignments(A.T, thresh=minimum_weight) - - seacell_expressions = [] - seacell_celltypes = [] - seacell_purities = [] - for ix in tqdm(range(A.shape[1])): - cell_weights = A[:, ix] - # Construct the SEACell expression using the - seacell_exp = ( - data.multiply(cell_weights[:, np.newaxis]).toarray().sum(0) - / cell_weights.sum() - ) - seacell_expressions.append(seacell_exp) - - if compute_seacell_celltypes: - # Compute the consensus celltype and the celltype purity - cell_weights = pd.DataFrame(cell_weights) - cell_weights.index = ad.obs_names - purity = ( - cell_weights.join(ad.obs[celltype_label]) - .groupby(celltype_label) - .sum() - .sort_values(by=0, ascending=False) - ) - purity = purity / purity.sum() - celltype = purity.iloc[0] - seacell_celltypes.append(celltype.name) - seacell_purities.append(celltype.values[0]) - - seacell_expressions = csr_matrix(np.array(seacell_expressions)) - seacell_ad = sc.AnnData(seacell_expressions, dtype=seacell_expressions.dtype) - seacell_ad.var_names = ad.var_names - seacell_ad.obs["Pseudo-sizes"] = A.sum(0) - if compute_seacell_celltypes: - seacell_ad.obs["celltype"] = seacell_celltypes - seacell_ad.obs["celltype_purity"] = seacell_purities - seacell_ad.var_names = ad.var_names - return seacell_ad - - -def summarize_by_SEACell( - ad, SEACells_label="SEACell", celltype_label=None, summarize_layer="raw" -): - """Summary of SEACell assignment. - - Aggregates cells within each SEACell, summing over all raw data for all cells belonging to a SEACell. - Data is unnormalized and raw aggregated counts are stored .layers['raw']. - Attributes associated with variables (.var) are copied over, but relevant per SEACell attributes must be - manually copied, since certain attributes may need to be summed, or averaged etc, depending on the attribute. - The output of this function is an anndata object of shape n_metacells x original_data_dimension. - :return: anndata.AnnData containing aggregated counts. - - """ - import scanpy as sc - from scipy.sparse import csr_matrix - - # Set of metacells - metacells = ad.obs[SEACells_label].unique() - - # Summary matrix - summ_matrix = pd.DataFrame(0.0, index=metacells, columns=ad.var_names) - - for m in tqdm(summ_matrix.index): - cells = ad.obs_names[ad.obs[SEACells_label] == m] - if summarize_layer == "X": - summ_matrix.loc[m, :] = np.ravel(ad[cells, :].X.sum(axis=0)) - elif summarize_layer == "raw" and ad.raw is not None: - summ_matrix.loc[m, :] = np.ravel(ad[cells, :].raw.X.sum(axis=0)) - else: - summ_matrix.loc[m, :] = np.ravel( - ad[cells, :].layers[summarize_layer].sum(axis=0) - ) - - # Ann data - - # Counts - meta_ad = sc.AnnData(csr_matrix(summ_matrix), dtype=csr_matrix(summ_matrix).dtype) - meta_ad.obs_names, meta_ad.var_names = summ_matrix.index.astype(str), ad.var_names - meta_ad.layers["raw"] = csr_matrix(summ_matrix) - - # Also compute cell type purity - if celltype_label is not None: - # TODO: Catch specific exception - try: - purity_df = evaluate.compute_celltype_purity(ad, celltype_label) - meta_ad.obs = meta_ad.obs.join(purity_df) - except Exception as e: # noqa: BLE001 - print(f"Cell type purity failed with Exception {e}") - - return meta_ad +import copy + +import numpy as np +import pandas as pd +from tqdm import tqdm + +try: + from . import evaluate +except ImportError: + import evaluate + + +def SEACells( + ad, + build_kernel_on: str, + n_SEACells: int, + use_gpu: bool = False, + verbose: bool = True, + n_waypoint_eigs: int = 10, + n_neighbors: int = 15, + convergence_epsilon: float = 1e-3, + l2_penalty: float = 0, + max_franke_wolfe_iters: int = 50, + use_sparse: bool = False, +): + """Core SEACells class. + + :param ad: (AnnData) annotated data matrix + :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + Typically 'X_pca' for scRNA or 'X_svd' for scATAC + :param n_SEACells: (int) number of SEACells to compute + :param use_gpu: (bool) whether to use GPU for computation + :param verbose: (bool) whether to suppress verbose program logging + :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization + :param n_neighbors: (int) number of nearest neighbors to use for graph construction + :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm + :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm + :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm + :param use_sparse: (bool) whether to use sparse matrix operations. Currently only supported for CPU implementation. + + See cpu.py or gpu.py for descriptions of model attributes and methods. + """ + if use_sparse: + assert ( + not use_gpu + ), "Sparse matrix operations are only supported for CPU implementation." + try: + from . import cpu + except ImportError: + import cpu + model = cpu.SEACellsCPU( + ad, + build_kernel_on, + n_SEACells, + verbose, + n_waypoint_eigs, + n_neighbors, + convergence_epsilon, + l2_penalty, + max_franke_wolfe_iters, + ) + + return model + + if use_gpu: + try: + from . import gpu + except ImportError: + import gpu + + model = gpu.SEACellsGPU( + ad, + build_kernel_on, + n_SEACells, + verbose, + n_waypoint_eigs, + n_neighbors, + convergence_epsilon, + l2_penalty, + max_franke_wolfe_iters, + ) + + else: + try: + from . import cpu_dense + except ImportError: + import cpu_dense + model = cpu_dense.SEACellsCPUDense( + ad, + build_kernel_on, + n_SEACells, + verbose, + n_waypoint_eigs, + n_neighbors, + convergence_epsilon, + l2_penalty, + max_franke_wolfe_iters, + ) + + return model + + +def sparsify_assignments(A, thresh: float): + """Zero out all values below a threshold in an assignment matrix. + + :param A: (csr_matrix) of shape n_cells x n_SEACells containing assignment weights + :param thresh: (float) threshold below which to zero out assignment weights + :return: (np.array) of shape n_cells x n_SEACells containing assignment weights. + """ + A = copy.deepcopy(A) + A[A < thresh] = 0 + + # Renormalize. Cells whose every weight was below threshold would otherwise + # produce NaN rows; leave their weights at zero so they contribute nothing. + row_sums = A.sum(1, keepdims=True) + row_sums = np.where(row_sums == 0, 1.0, row_sums) + A = A / row_sums + + return A + + +def summarize_by_soft_SEACell( + ad, A, celltype_label=None, summarize_layer="raw", minimum_weight: float = 0.05 +): + """Summary of soft SEACell assignment. + + Aggregates cells within each SEACell, summing over all raw data x assignment weight for all cells belonging to a + SEACell. Data is un-normalized and pseudo-raw aggregated counts are stored in .layers['raw']. + Attributes associated with variables (.var) are copied over, but relevant per SEACell attributes must be + manually copied, since certain attributes may need to be summed, or averaged etc, depending on the attribute. + The output of this function is an anndata object of shape n_metacells x original_data_dimension. + + @param ad: (sc.AnnData) containing raw counts for single-cell data + @param A: (np.array) of shape n_SEACells x n_cells containing assignment weights of cells to SEACells + @param celltype_label: (str) optionally provide the celltype label to compute modal celltype per SEACell + @param summarize_layer: (str) key for ad.layers to find raw data. Use 'raw' to search for ad.raw.X + @param minimum_weight: (float) minimum value below which assignment weights are zero-ed out. If all cell assignment + weights are smaller than minimum_weight, the 95th percentile weight is used. + @return: aggregated anndata containing weighted expression for aggregated SEACells + """ + import scanpy as sc + from scipy.sparse import csr_matrix + + compute_seacell_celltypes = False + if celltype_label is not None: + if celltype_label not in ad.obs.columns: + raise ValueError(f"Celltype label {celltype_label} not present in ad.obs") + compute_seacell_celltypes = True + + if summarize_layer == "raw" and ad.raw is not None: + data = ad.raw.X + else: + data = ad.layers[summarize_layer] + + A = sparsify_assignments(A.T, thresh=minimum_weight) + + # Vectorized aggregation: a single sparse matmul replaces the per-metacell + # Python loop. Mathematically the same as + # seacell_exp[m, :] = sum_c A[c, m] * data[c, :] / sum_c A[c, m] + # but expressed as (A.T @ data) / totals. + n_metacells = A.shape[1] + # Per-metacell weight totals; works uniformly for dense ndarray and sparse. + totals = np.asarray(A.sum(axis=0)).ravel() + totals_safe = np.where(totals > 0, totals, 1.0) + + A_T = csr_matrix(A.T) # (M, n_cells) + weighted_sum = A_T @ data # (M, n_features) + if hasattr(weighted_sum, "toarray"): + weighted_sum_dense = weighted_sum.toarray() + else: + weighted_sum_dense = np.asarray(weighted_sum) + seacell_expressions_mat = weighted_sum_dense / totals_safe[:, np.newaxis] + # Rows whose total weight was zero produce all zeros (totals_safe = 1 + # divides a zero numerator), matching the per-metacell zero fallback. + + seacell_expressions = csr_matrix(seacell_expressions_mat) + seacell_ad = sc.AnnData(seacell_expressions, dtype=seacell_expressions.dtype) + seacell_ad.var_names = ad.var_names + seacell_ad.obs["Pseudo-sizes"] = totals + + if compute_seacell_celltypes: + # Vectorized celltype purity: build a cells x celltypes indicator and + # form purity = A.T @ indicator (M x C). The dominant celltype per + # metacell is argmax over rows; ties resolve to the first category, as + # in the original sort_values(...).iloc[0] path (categories are sorted). + celltype_col = ad.obs[celltype_label].astype("category") + celltype_codes = celltype_col.cat.codes.values + celltype_names = celltype_col.cat.categories + n_cells = len(celltype_codes) + celltype_indicator = csr_matrix( + (np.ones(n_cells), (np.arange(n_cells), celltype_codes)), + shape=(n_cells, len(celltype_names)), + ) + purity_mat = A_T @ celltype_indicator # (M, C) + if hasattr(purity_mat, "toarray"): + purity_mat = purity_mat.toarray() + purity_mat = np.asarray(purity_mat) + + purity_row_sum = purity_mat.sum(axis=1) + nonempty = purity_row_sum > 0 + # Avoid divide-by-zero rows; we'll mask their celltype to None below. + denom = np.where(nonempty, purity_row_sum, 1.0) + purity_norm = purity_mat / denom[:, np.newaxis] + argmax_ct = purity_norm.argmax(axis=1) + + seacell_celltypes = [ + celltype_names[idx] if nonempty[m] else None + for m, idx in enumerate(argmax_ct) + ] + seacell_purities = np.where( + nonempty, purity_norm[np.arange(n_metacells), argmax_ct], 0.0 + ).tolist() + + seacell_ad.obs["celltype"] = seacell_celltypes + seacell_ad.obs["celltype_purity"] = seacell_purities + seacell_ad.var_names = ad.var_names + return seacell_ad + + +def summarize_by_SEACell( + ad, SEACells_label="SEACell", celltype_label=None, summarize_layer="raw" +): + """Summary of SEACell assignment. + + Aggregates cells within each SEACell, summing over all raw data for all cells belonging to a SEACell. + Data is unnormalized and raw aggregated counts are stored .layers['raw']. + Attributes associated with variables (.var) are copied over, but relevant per SEACell attributes must be + manually copied, since certain attributes may need to be summed, or averaged etc, depending on the attribute. + The output of this function is an anndata object of shape n_metacells x original_data_dimension. + :return: anndata.AnnData containing aggregated counts. + + """ + import scanpy as sc + from scipy.sparse import csr_matrix + + # Pick the source data matrix once. + if summarize_layer == "X": + data = ad.X + elif summarize_layer == "raw" and ad.raw is not None: + data = ad.raw.X + else: + data = ad.layers[summarize_layer] + + # Build a cell-to-metacell indicator and aggregate in one sparse matmul. + # Preserves first-occurrence order of metacell labels (matches the prior + # use of pd.Series.unique() to seed summ_matrix.index). + labels = ad.obs[SEACells_label] + metacell_order = pd.Index(labels.unique()) + code_lookup = pd.Series(np.arange(len(metacell_order)), index=metacell_order) + codes = code_lookup.loc[labels.values].values + n_cells = len(codes) + n_metacells = len(metacell_order) + + indicator = csr_matrix( + (np.ones(n_cells), (codes, np.arange(n_cells))), + shape=(n_metacells, n_cells), + ) + summed = indicator @ data + summed = csr_matrix(summed) + + meta_ad = sc.AnnData(summed, dtype=summed.dtype) + meta_ad.obs_names = metacell_order.astype(str) + meta_ad.var_names = ad.var_names + meta_ad.layers["raw"] = summed + + # Also compute cell type purity + if celltype_label is not None: + # TODO: Catch specific exception + try: + purity_df = evaluate.compute_celltype_purity(ad, celltype_label) + meta_ad.obs = meta_ad.obs.join(purity_df) + except Exception as e: # noqa: BLE001 + print(f"Cell type purity failed with Exception {e}") + + return meta_ad diff --git a/SEACells/cpu.py b/SEACells/cpu.py index fd6f137..53b355c 100644 --- a/SEACells/cpu.py +++ b/SEACells/cpu.py @@ -1,764 +1,810 @@ -import copy - -import numpy as np -import palantir -import pandas as pd -from scipy.sparse import csr_matrix, save_npz -from scipy.sparse.linalg import norm -from sklearn.preprocessing import normalize -from tqdm import tqdm - -try: - from . import build_graph -except ImportError: - import build_graph - - -class SEACellsCPU: - """CPU Implementation of SEACells algorithm. - - This implementation uses fast kernel archetypal analysis to find SEACells - groupings - of cells that represent highly granular, distinct cell states. SEACells are found by solving a convex optimization - problem that minimizes the residual sum of squares between the kernel matrix and the weighted sum of the archetypes. - - Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell'] - """ - - def __init__( - self, - ad, - build_kernel_on: str, - n_SEACells: int, - verbose: bool = True, - n_waypoint_eigs: int = 10, - n_neighbors: int = 15, - convergence_epsilon: float = 1e-3, - l2_penalty: float = 0, - max_franke_wolfe_iters: int = 50, - ): - """CPU Implementation of SEACells algorithm. - - :param ad: (AnnData) annotated data matrix - :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - Typically 'X_pca' for scRNA or 'X_svd' for scATAC - :param n_SEACells: (int) number of SEACells to compute - :param verbose: (bool) whether to suppress verbose program logging - :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization - :param n_neighbors: (int) number of nearest neighbors to use for graph construction - :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm - :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm - :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm - - Class Attributes: - ad: (AnnData) annotated data matrix - build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - n_cells: (int) number of cells in ad - k: (int) number of SEACells to compute - n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization - waypoint_proportion: (float) proportion of cells to use for waypoint initialization - n_neighbors: (int) number of nearest neighbors to use for graph construction - max_FW_iter: (int) maximum number of iterations for Franke-Wolfe algorithm - verbose: (bool) whether to suppress verbose program logging - l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm - RSS_iters: (list) list of residual sum of squares at each iteration of Franke-Wolfe algorithm - convergence_epsilon: (float) algorithm converges when RSS < convergence_epsilon * RSS(0) - convergence_threshold: (float) convergence threshold for Franke-Wolfe algorithm - kernel_matrix: (csr_matrix) kernel matrix of shape (n_cells, n_cells) - K: (csr_matrix) dot product of kernel matrix with itself, K = K @ K.T - archetypes: (list) list of cell indices corresponding to archetypes - A_: (csr_matrix) matrix of shape (k, n) containing final assignments of cells to SEACells - B_: (csr_matrix) matrix of shape (n, k) containing archetype weights - A0: (csr_matrix) matrix of shape (k, n) containing initial assignments of cells to SEACells - B0: (csr_matrix) matrix of shape (n, k) containing initial archetype weights - """ - print("Welcome to SEACells!") - self.ad = ad - self.build_kernel_on = build_kernel_on - self.n_cells = ad.shape[0] - - if not isinstance(n_SEACells, int): - try: - n_SEACells = int(n_SEACells) - except ValueError: - raise ValueError( - f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" - ) - - self.k = n_SEACells - - self.n_waypoint_eigs = n_waypoint_eigs - self.waypoint_proportion = 1 - self.n_neighbors = n_neighbors - - self.max_FW_iter = max_franke_wolfe_iters - self.verbose = verbose - self.l2_penalty = l2_penalty - - self.RSS_iters = [] - self.convergence_epsilon = convergence_epsilon - self.convergence_threshold = None - - # Parameters to be initialized later in the model - self.kernel_matrix = None - self.K = None - - # Archetypes as list of cell indices - self.archetypes = None - - self.A_ = None - self.B_ = None - self.A0 = None - self.B0 = None - - return - - def add_precomputed_kernel_matrix(self, K): - """Add precomputed kernel matrix to SEACells object. - - :param K: (np.ndarray) kernel matrix of shape (n_cells, n_cells) - :return: None. - """ - assert K.shape == ( - self.n_cells, - self.n_cells, - ), f"Dimension of kernel matrix must be n_cells = ({self.n_cells},{self.n_cells}), not {K.shape} " - self.kernel_matrix = K - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - - def construct_kernel_matrix( - self, n_neighbors: int = None, graph_construction="union" - ): - """Construct kernel matrix from data matrix using PCA/SVD and nearest neighbors. - - :param n_neighbors: (int) number of nearest neighbors to use for graph construction. - If none, use self.n_neighbors, which has a default value of 15. - :param graph_construction: (str) method for graph construction. Options are 'union' or 'intersection'. - Default is 'union', where the neighborhood graph is made symmetric by adding an edge - (u,v) if either (u,v) or (v,u) is in the neighborhood graph. If 'intersection', the - neighborhood graph is made symmetric by adding an edge (u,v) if both (u,v) and (v,u) - are in the neighborhood graph. - :return: None. - """ - # input to graph construction is PCA/SVD - kernel_model = build_graph.SEACellGraph( - self.ad, self.build_kernel_on, verbose=self.verbose - ) - - # K is a sparse matrix representing input to SEACell alg - if n_neighbors is None: - n_neighbors = self.n_neighbors - - M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) - self.kernel_matrix = M - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - - return - - def initialize_archetypes(self): - """Initialize B matrix which defines cells as SEACells. - - Uses waypoint analysis for initialization into to fully - cover the phenotype space, and then greedily selects the remaining cells (if redundant cells are selected by - waypoint analysis). - - Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes. - - By default, the proportion of cells selected by waypoint analysis is 1. This can be changed by setting the - waypoint_proportion parameter in the SEACells object. For example, setting waypoint_proportion = 0.5 will - select half of the cells by waypoint analysis and half by greedy selection. - """ - k = self.k - - if self.waypoint_proportion > 0: - waypoint_ix = self._get_waypoint_centers(k) - waypoint_ix = np.random.choice( - waypoint_ix, - int(len(waypoint_ix) * self.waypoint_proportion), - replace=False, - ) - from_greedy = self.k - len(waypoint_ix) - if self.verbose: - print( - f"Selecting {len(waypoint_ix)} cells from waypoint initialization." - ) - else: - from_greedy = self.k - - greedy_ix = self._get_greedy_centers(n_SEACells=from_greedy + 10) - if self.verbose: - print(f"Selecting {from_greedy} cells from greedy initialization.") - - if self.waypoint_proportion > 0: - all_ix = np.hstack([waypoint_ix, greedy_ix]) - else: - all_ix = np.hstack([greedy_ix]) - - unique_ix, ind = np.unique(all_ix, return_index=True) - all_ix = unique_ix[np.argsort(ind)][:k] - self.archetypes = all_ix - - def initialize(self, initial_archetypes=None, initial_assignments=None): - """Initialize the model by initializing the B matrix. - - The method constructs archetypes from a convex combination of cells) and - the A matrix (defines assignments of cells to archetypes. - - Assumes the kernel matrix has already been constructed. B matrix is of shape (n_cells, n_SEACells) and A matrix - is of shape (n_SEACells, n_cells). - - :param initial_archetypes: (np.ndarray) initial archetypes to use for initialization. If None, use waypoint - analysis and greedy selection to initialize archetypes. - :param initial_assignments: (np.ndarray) initial assignments to use for initialization. If None, use - random initialization. - :return: None - """ - if self.K is None: - raise RuntimeError( - "Must first construct kernel matrix before initializing SEACells." - ) - K = self.K - # initialize B (update this to allow initialization from RRQR) - n = K.shape[0] - - if initial_archetypes is not None: - if self.verbose: - print("Using provided list of initial archetypes") - self.archetypes = initial_archetypes - - if self.archetypes is None: - self.initialize_archetypes() - self.k = len(self.archetypes) - k = self.k - - # Sparse construction of B matrix - cols = np.arange(k) - rows = self.archetypes - shape = (n, k) - B0 = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=shape) - - self.B0 = B0 - B = self.B0.copy() - - if initial_assignments is not None: - A0 = initial_assignments - assert A0.shape == ( - k, - n, - ), f"Initial assignment matrix should be of shape (k={k} x n={n})" - A0 = csr_matrix(A0) - A0 = normalize(A0, axis=0, norm="l1") - else: - # Need to ensure each cell is assigned to at least one archetype - # Randomly sample roughly 25% of the values between 0 and k - archetypes_per_cell = int(k * 0.25) - rows = np.random.randint(0, k, size=(n, archetypes_per_cell)).reshape(-1) - columns = np.repeat(np.arange(n), archetypes_per_cell) - - A0 = csr_matrix( - (np.random.random(len(rows)), (rows, columns)), shape=(k, n) - ) - A0 = normalize(A0, axis=0, norm="l1") - - if self.verbose: - print("Randomly initialized A matrix.") - - self.A0 = A0 - A = self.A0.copy() - A = self._updateA(B, A) - - self.A_ = A - self.B_ = B - - # Create convergence threshold - RSS = self.compute_RSS(A, B) - self.RSS_iters.append(RSS) - - if self.convergence_threshold is None: - self.convergence_threshold = self.convergence_epsilon * RSS - if self.verbose: - print( - f"Setting convergence threshold at {self.convergence_threshold:.5f}" - ) - - def _get_waypoint_centers(self, n_waypoints=None): - """Initialize B matrix using waypoint analysis, as described in Palantir. - - From https://www.nature.com/articles/s41587-019-0068-4. - - :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, - all SEACells initialized using this method. - :return: (np.ndarray) indices of cells to use as initial archetypes - """ - if n_waypoints is None: - k = self.k - else: - k = n_waypoints - - ad = self.ad - - if self.build_kernel_on == "X_pca": - pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) - elif self.build_kernel_on == "X_svd": - # Compute PCA components from ad object - pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) - else: - pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( - ad.obs_names - ) - - print(f"Building kernel on {self.build_kernel_on}") - - if self.verbose: - print( - f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " - ) - - dm_res = palantir.utils.run_diffusion_maps( - pca_components, n_components=self.n_neighbors - ) - dc_components = palantir.utils.determine_multiscale_space( - dm_res, n_eigs=self.n_waypoint_eigs - ) - if self.verbose: - print("Done.") - - # Initialize SEACells via waypoint sampling - if self.verbose: - print("Sampling waypoints ...") - waypoint_init = palantir.core._max_min_sampling( - data=dc_components, num_waypoints=k - ) - dc_components["iix"] = np.arange(len(dc_components)) - waypoint_ix = dc_components.loc[waypoint_init]["iix"].values - if self.verbose: - print("Done.") - - return waypoint_ix - - def _get_greedy_centers(self, n_SEACells=None): - """Initialize SEACells using fast greedy adaptive CSSP. - - From https://arxiv.org/pdf/1312.6838.pdf - :param n_SEACells: (int) number of SEACells to initialize using greedy selection. If None specified, - all SEACells initialized using this method. - :return: (np.ndarray) indices of cells to use as initial archetypes - """ - K = self.K - n = K.shape[0] - - if n_SEACells is None: - k = self.k - else: - k = n_SEACells - - if self.verbose: - print("Initializing residual matrix using greedy column selection") - - # precompute M.T * M - # ATA = M.T @ M - ATA = K - - if self.verbose: - print("Initializing f and g...") - - f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() - # f = np.array((ATA * ATA).sum(axis=0)).ravel() - g = np.array(ATA.diagonal()).ravel() - - d = np.zeros((k, n)) - omega = np.zeros((k, n)) - - # keep track of selected indices - centers = np.zeros(k, dtype=int) - - # sampling - for j in tqdm(range(k)): - score = f / g - p = np.argmax(score) - - # print residuals - np.sum(f) - - delta_term1 = ATA[:, p].toarray().squeeze() - # print(delta_term1) - delta_term2 = ( - np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() - ) - delta = delta_term1 - delta_term2 - - # some weird rounding errors - delta[p] = np.max([0, delta[p]]) - - o = delta / np.max([np.sqrt(delta[p]), 1e-6]) - omega_square_norm = np.linalg.norm(o) ** 2 - omega_hadamard = np.multiply(o, o) - term1 = omega_square_norm * omega_hadamard - - # update f (term2) - pl = np.zeros(n) - for r in range(j): - omega_r = omega[r, :] - pl += np.dot(omega_r, o) * omega_r - - ATAo = (ATA @ o.reshape(-1, 1)).ravel() - term2 = np.multiply(o, ATAo - pl) - - # update f - f += -2.0 * term2 + term1 - - # update g - g += omega_hadamard - - # store omega and delta - d[j, :] = delta - omega[j, :] = o - - # add index - centers[j] = int(p) - - return centers - - def _updateA(self, B, A_prev): - """Update step for assigment matrix A. - - Given archetype matrix B and using kernel matrix K, compute assignment matrix A using constrained gradient - descent via Frank-Wolfe algorithm. - - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - :param A_prev: (n x k csr_matrix) defining previous weights used for assigning cells to SEACells - :return: (n x k csr_matrix) defining updated weights used for assigning cells to SEACells - """ - n, k = B.shape - A = A_prev - - t = 0 # current iteration (determine multiplicative update) - - # precompute some gradient terms - t2 = (self.K @ B).T - t1 = t2 @ B - - # update rows of A for given number of iterations - while t < self.max_FW_iter: - # compute gradient (must convert matrix to ndarray) - G = 2.0 * np.array(t1 @ A - t2) - - # # get argmins - shape 1 x n - amins = np.argmin(G, axis=0) - amins = np.array(amins).reshape(-1) - - # # loop free implementation - e = csr_matrix((np.ones(len(amins)), (amins, np.arange(n))), shape=A.shape) - - A += 2.0 / (t + 2.0) * (e - A) - t += 1 - - return A - - def _updateB(self, A, B_prev): - """Update step for archetype matrix B. - - Given assignment matrix A and using kernel matrix K, compute archetype matrix B using constrained gradient - descent via Frank-Wolfe algorithm. - - :param A: (n x k csr_matrix) defining weights used for assigning cells to SEACells - :param B_prev: (n x k csr_matrix) defining previous SEACells as weighted combinations of cells - :return: (n x k csr_matrix) defining updated SEACells as weighted combinations of cells - """ - K = self.K - k, n = A.shape - - B = B_prev - - # keep track of error - t = 0 - - # precompute some terms - t1 = A @ A.T - t2 = K @ A.T - - # update rows of B for a given number of iterations - while t < self.max_FW_iter: - # compute gradient (need to convert np.matrix to np.array) - G = 2.0 * np.array(K @ B @ t1 - t2) - - # get all argmins - amins = np.argmin(G, axis=0) - amins = np.array(amins).reshape(-1) - - e = csr_matrix((np.ones(len(amins)), (amins, np.arange(k))), shape=B.shape) - - B += 2.0 / (t + 2.0) * (e - B) - - t += 1 - - return B - - def compute_reconstruction(self, A=None, B=None): - """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. - - :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: (n x n csr_matrix) defining reconstructed data matrix. - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - if A is None or B is None: - raise RuntimeError( - "Either assignment matrix A or archetype matrix B is None." - ) - return (self.kernel_matrix.dot(B)).dot(A) - - def compute_RSS(self, A=None, B=None): - """Compute residual sum of squares error in difference between reconstruction and true data matrix. - - :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: - ||X-XBA||^2 - (float) square difference between true data and reconstruction. - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - reconstruction = self.compute_reconstruction(A, B) - return norm(self.kernel_matrix - reconstruction) - - def plot_convergence(self, save_as=None, show=True): - """Plot behaviour of squared error over iterations. - - :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. - :param show: (bool) whether to show plot - :return: None. - """ - import matplotlib.pyplot as plt - - plt.figure() - plt.plot(self.RSS_iters) - plt.title("Reconstruction Error over Iterations") - plt.xlabel("Iterations") - plt.ylabel("Squared Error") - if save_as is not None: - plt.savefig(save_as, dpi=150) - if show: - plt.show() - plt.close() - - def step(self): - """Perform one iteration of SEACell algorithm. Update assignment matrix A and archetype matrix B. - - :return: None. - """ - A = self.A_ - B = self.B_ - - if self.K is None: - raise RuntimeError( - "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." - ) - - if A is None: - raise RuntimeError( - "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." - ) - - if B is None: - raise RuntimeError( - "Archetype matrix has not been initialised. Run model.initialize() first." - ) - - A = self._updateA(B, A) - B = self._updateB(A, B) - - self.RSS_iters.append(self.compute_RSS(A, B)) - - self.A_ = A - self.B_ = B - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - return - - def _fit( - self, - max_iter: int = 50, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Internal method to compute archetypes and loadings given kernel matrix K. - - Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. - - Modifies ad.obs in place to add 'SEACell' labels to cells. - :param max_iter: (int) maximum number of iterations to perform - :param min_iter: (int) minimum number of iterations to perform - :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. - :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. - :return: None - """ - self.initialize( - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - converged = False - n_iter = 0 - while (not converged and n_iter < max_iter) or n_iter < min_iter: - n_iter += 1 - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Starting iteration {n_iter}.") - - self.step() - - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Completed iteration {n_iter}.") - - # Check for convergence - if ( - np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) - < self.convergence_threshold - ): - if self.verbose: - print(f"Converged after {n_iter} iterations.") - converged = True - - self.Z_ = self.B_.T @ self.K - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - if not converged: - raise RuntimeWarning( - "Warning: Algorithm has not converged - you may need to increase the maximum number of iterations" - ) - return - - def fit( - self, - max_iter: int = 100, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Compute archetypes and loadings given kernel matrix K. - - Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. - :param max_iter: (int) maximum number of iterations to perform (default 100) - :param min_iter: (int) minimum number of iterations to perform (default 10) - :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. - :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. - :return: None. - """ - if max_iter < min_iter: - raise ValueError( - "The maximum number of iterations specified is lower than the minimum number of iterations specified." - ) - self._fit( - max_iter=max_iter, - min_iter=min_iter, - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - def get_archetype_matrix(self): - """Return k x n matrix of archetypes computed as the product of the archetype matrix B and the kernel matrix K.""" - return self.Z_ - - def get_soft_assignments(self): - """Return soft SEACell assignment. - - Returns a tuple of (labels, weights) where labels is a dataframe with SEACell assignments for the top 5 - SEACell assignments for each cell and weights is an array with the corresponding weights for each assignment. - :return: (pd.DataFrame, np.array) with labels and weights. - """ - archetype_labels = self.get_hard_archetypes() - A = copy.deepcopy(self.A_.T) - - labels = [] - weights = [] - for _i in range(5): - l = A.argmax(1) - labels.append(archetype_labels[l]) - weights.append(A[np.arange(A.shape[0]), l]) - A[np.arange(A.shape[0]), l] = -1 - - weights = np.vstack(weights).T - labels = np.vstack(labels).T - - soft_labels = pd.DataFrame(labels) - soft_labels.index = self.ad.obs_names - - return soft_labels, weights - - def get_hard_assignments(self): - """Returns a dataframe with the SEACell assignment for each cell. - - The assignment is the SEACell with the highest assignment weight. - :return: (pd.DataFrame) with SEACell assignments. - """ - # Use argmax to get the index with the highest assignment weight - assmts = np.array(self.A_.argmax(0)).reshape(-1) - - df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in assmts]}) - df.index = self.ad.obs_names - df.index.name = "index" - return df - - def get_hard_archetypes(self): - """Return the names of cells most strongly identified as archetypes. - - :return list of archetype names. - """ - return self.ad.obs_names[self.B_.argmax(0)] - - def save_model(self, outdir): - """Save the model to a pickle file. - - :param outdir: (str) path to directory to save to - :return: None. - """ - import pickle - - with open(outdir + "/model.pkl", "wb") as f: - pickle.dump(self, f) - return None - - def save_assignments(self, outdir): - """Save SEACells assignment. - - Saves: - (1) the cell to SEACell assignments to a csv file with the name 'SEACells.csv'. - (2) the kernel matrix to a .npz file with the name 'kernel_matrix.npz'. - (3) the archetype matrix to a .npz file with the name 'A.npz'. - (4) the loading matrix to a .npz file with the name 'B.npz'. - - :param outdir: (str) path to directory to save to - :return: None - """ - import os - - os.makedirs(outdir, exist_ok=True) - save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) - save_npz(outdir + "/A.npz", self.A_.T) - save_npz(outdir + "/B.npz", self.B_) - - labels = self.get_hard_assignments() - labels.to_csv(outdir + "/SEACells.csv") - return None +import copy + +import numpy as np +import palantir +import pandas as pd +from scipy.sparse import csr_matrix, save_npz +from scipy.sparse.linalg import norm +from sklearn.preprocessing import normalize +from tqdm import tqdm + +try: + from . import build_graph +except ImportError: + import build_graph + + +class SEACellsCPU: + """CPU Implementation of SEACells algorithm. + + This implementation uses fast kernel archetypal analysis to find SEACells - groupings + of cells that represent highly granular, distinct cell states. SEACells are found by solving a convex optimization + problem that minimizes the residual sum of squares between the kernel matrix and the weighted sum of the archetypes. + + Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell'] + """ + + def __init__( + self, + ad, + build_kernel_on: str, + n_SEACells: int, + verbose: bool = True, + n_waypoint_eigs: int = 10, + n_neighbors: int = 15, + convergence_epsilon: float = 1e-3, + l2_penalty: float = 0, + max_franke_wolfe_iters: int = 50, + ): + """CPU Implementation of SEACells algorithm. + + :param ad: (AnnData) annotated data matrix + :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + Typically 'X_pca' for scRNA or 'X_svd' for scATAC + :param n_SEACells: (int) number of SEACells to compute + :param verbose: (bool) whether to suppress verbose program logging + :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization + :param n_neighbors: (int) number of nearest neighbors to use for graph construction + :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm + :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm + :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm + + Class Attributes: + ad: (AnnData) annotated data matrix + build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + n_cells: (int) number of cells in ad + k: (int) number of SEACells to compute + n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization + waypoint_proportion: (float) proportion of cells to use for waypoint initialization + n_neighbors: (int) number of nearest neighbors to use for graph construction + max_FW_iter: (int) maximum number of iterations for Franke-Wolfe algorithm + verbose: (bool) whether to suppress verbose program logging + l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm + RSS_iters: (list) list of residual sum of squares at each iteration of Franke-Wolfe algorithm + convergence_epsilon: (float) algorithm converges when RSS < convergence_epsilon * RSS(0) + convergence_threshold: (float) convergence threshold for Franke-Wolfe algorithm + kernel_matrix: (csr_matrix) kernel matrix of shape (n_cells, n_cells) + K: (csr_matrix) dot product of kernel matrix with itself, K = K @ K.T + archetypes: (list) list of cell indices corresponding to archetypes + A_: (csr_matrix) matrix of shape (k, n) containing final assignments of cells to SEACells + B_: (csr_matrix) matrix of shape (n, k) containing archetype weights + A0: (csr_matrix) matrix of shape (k, n) containing initial assignments of cells to SEACells + B0: (csr_matrix) matrix of shape (n, k) containing initial archetype weights + """ + print("Welcome to SEACells!") + self.ad = ad + self.build_kernel_on = build_kernel_on + self.n_cells = ad.shape[0] + + if not isinstance(n_SEACells, int): + try: + n_SEACells = int(n_SEACells) + except ValueError: + raise ValueError( + f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" + ) + + self.k = n_SEACells + + self.n_waypoint_eigs = n_waypoint_eigs + self.waypoint_proportion = 1 + self.n_neighbors = n_neighbors + + self.max_FW_iter = max_franke_wolfe_iters + self.verbose = verbose + self.l2_penalty = l2_penalty + + self.RSS_iters = [] + self.convergence_epsilon = convergence_epsilon + self.convergence_threshold = None + + # Parameters to be initialized later in the model + self.kernel_matrix = None + self.K = None + + # Archetypes as list of cell indices + self.archetypes = None + + self.A_ = None + self.B_ = None + self.A0 = None + self.B0 = None + + return + + def add_precomputed_kernel_matrix(self, K): + """Add precomputed kernel matrix to SEACells object. + + :param K: (np.ndarray) kernel matrix of shape (n_cells, n_cells) + :return: None. + """ + assert K.shape == ( + self.n_cells, + self.n_cells, + ), f"Dimension of kernel matrix must be n_cells = ({self.n_cells},{self.n_cells}), not {K.shape} " + self.kernel_matrix = K + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + + def construct_kernel_matrix( + self, n_neighbors: int = None, graph_construction="union" + ): + """Construct kernel matrix from data matrix using PCA/SVD and nearest neighbors. + + :param n_neighbors: (int) number of nearest neighbors to use for graph construction. + If none, use self.n_neighbors, which has a default value of 15. + :param graph_construction: (str) method for graph construction. Options are 'union' or 'intersection'. + Default is 'union', where the neighborhood graph is made symmetric by adding an edge + (u,v) if either (u,v) or (v,u) is in the neighborhood graph. If 'intersection', the + neighborhood graph is made symmetric by adding an edge (u,v) if both (u,v) and (v,u) + are in the neighborhood graph. + :return: None. + """ + # input to graph construction is PCA/SVD + kernel_model = build_graph.SEACellGraph( + self.ad, self.build_kernel_on, verbose=self.verbose + ) + + # K is a sparse matrix representing input to SEACell alg + if n_neighbors is None: + n_neighbors = self.n_neighbors + + M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) + self.kernel_matrix = M + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + + return + + def initialize_archetypes(self): + """Initialize B matrix which defines cells as SEACells. + + Uses waypoint analysis for initialization into to fully + cover the phenotype space, and then greedily selects the remaining cells (if redundant cells are selected by + waypoint analysis). + + Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes. + + By default, the proportion of cells selected by waypoint analysis is 1. This can be changed by setting the + waypoint_proportion parameter in the SEACells object. For example, setting waypoint_proportion = 0.5 will + select half of the cells by waypoint analysis and half by greedy selection. + """ + k = self.k + + if self.waypoint_proportion > 0: + waypoint_ix = self._get_waypoint_centers(k) + waypoint_ix = np.random.choice( + waypoint_ix, + int(len(waypoint_ix) * self.waypoint_proportion), + replace=False, + ) + from_greedy = self.k - len(waypoint_ix) + if self.verbose: + print( + f"Selecting {len(waypoint_ix)} cells from waypoint initialization." + ) + else: + from_greedy = self.k + + greedy_ix = self._get_greedy_centers(n_SEACells=from_greedy + 10) + if self.verbose: + print(f"Selecting {from_greedy} cells from greedy initialization.") + + if self.waypoint_proportion > 0: + all_ix = np.hstack([waypoint_ix, greedy_ix]) + else: + all_ix = np.hstack([greedy_ix]) + + unique_ix, ind = np.unique(all_ix, return_index=True) + all_ix = unique_ix[np.argsort(ind)][:k] + self.archetypes = all_ix + + def initialize(self, initial_archetypes=None, initial_assignments=None): + """Initialize the model by initializing the B matrix. + + The method constructs archetypes from a convex combination of cells) and + the A matrix (defines assignments of cells to archetypes. + + Assumes the kernel matrix has already been constructed. B matrix is of shape (n_cells, n_SEACells) and A matrix + is of shape (n_SEACells, n_cells). + + :param initial_archetypes: (np.ndarray) initial archetypes to use for initialization. If None, use waypoint + analysis and greedy selection to initialize archetypes. + :param initial_assignments: (np.ndarray) initial assignments to use for initialization. If None, use + random initialization. + :return: None + """ + if self.K is None: + raise RuntimeError( + "Must first construct kernel matrix before initializing SEACells." + ) + K = self.K + # initialize B (update this to allow initialization from RRQR) + n = K.shape[0] + + if initial_archetypes is not None: + if self.verbose: + print("Using provided list of initial archetypes") + self.archetypes = initial_archetypes + + if self.archetypes is None: + self.initialize_archetypes() + self.k = len(self.archetypes) + k = self.k + + # Sparse construction of B matrix + cols = np.arange(k) + rows = self.archetypes + shape = (n, k) + B0 = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=shape) + + self.B0 = B0 + B = self.B0.copy() + + if initial_assignments is not None: + A0 = initial_assignments + assert A0.shape == ( + k, + n, + ), f"Initial assignment matrix should be of shape (k={k} x n={n})" + A0 = csr_matrix(A0) + A0 = normalize(A0, axis=0, norm="l1") + else: + # Need to ensure each cell is assigned to at least one archetype + # Randomly sample roughly 25% of the values between 0 and k + archetypes_per_cell = int(k * 0.25) + rows = np.random.randint(0, k, size=(n, archetypes_per_cell)).reshape(-1) + columns = np.repeat(np.arange(n), archetypes_per_cell) + + A0 = csr_matrix( + (np.random.random(len(rows)), (rows, columns)), shape=(k, n) + ) + A0 = normalize(A0, axis=0, norm="l1") + + if self.verbose: + print("Randomly initialized A matrix.") + + self.A0 = A0 + A = self.A0.copy() + A = self._updateA(B, A) + + self.A_ = A + self.B_ = B + + # Create convergence threshold + RSS = self.compute_RSS(A, B) + self.RSS_iters.append(RSS) + + if self.convergence_threshold is None: + self.convergence_threshold = self.convergence_epsilon * RSS + if self.verbose: + print( + f"Setting convergence threshold at {self.convergence_threshold:.5f}" + ) + + def _get_waypoint_centers(self, n_waypoints=None): + """Initialize B matrix using waypoint analysis, as described in Palantir. + + From https://www.nature.com/articles/s41587-019-0068-4. + + :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, + all SEACells initialized using this method. + :return: (np.ndarray) indices of cells to use as initial archetypes + """ + if n_waypoints is None: + k = self.k + else: + k = n_waypoints + + ad = self.ad + + if self.build_kernel_on == "X_pca": + pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) + elif self.build_kernel_on == "X_svd": + # Compute PCA components from ad object + pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) + else: + pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( + ad.obs_names + ) + + print(f"Building kernel on {self.build_kernel_on}") + + if self.verbose: + print( + f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " + ) + + dm_res = palantir.utils.run_diffusion_maps( + pca_components, n_components=self.n_neighbors + ) + dc_components = palantir.utils.determine_multiscale_space( + dm_res, n_eigs=self.n_waypoint_eigs + ) + if self.verbose: + print("Done.") + + # Initialize SEACells via waypoint sampling + if self.verbose: + print("Sampling waypoints ...") + waypoint_init = palantir.core._max_min_sampling( + data=dc_components, num_waypoints=k + ) + dc_components["iix"] = np.arange(len(dc_components)) + waypoint_ix = dc_components.loc[waypoint_init]["iix"].values + if self.verbose: + print("Done.") + + return waypoint_ix + + def _get_greedy_centers(self, n_SEACells=None): + """Initialize SEACells using fast greedy adaptive CSSP. + + From https://arxiv.org/pdf/1312.6838.pdf + :param n_SEACells: (int) number of SEACells to initialize using greedy selection. If None specified, + all SEACells initialized using this method. + :return: (np.ndarray) indices of cells to use as initial archetypes + """ + K = self.K + n = K.shape[0] + + if n_SEACells is None: + k = self.k + else: + k = n_SEACells + + if self.verbose: + print("Initializing residual matrix using greedy column selection") + + # precompute M.T * M + # ATA = M.T @ M + ATA = K + + if self.verbose: + print("Initializing f and g...") + + f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() + # f = np.array((ATA * ATA).sum(axis=0)).ravel() + g = np.array(ATA.diagonal()).ravel() + + d = np.zeros((k, n)) + omega = np.zeros((k, n)) + + # keep track of selected indices + centers = np.zeros(k, dtype=int) + + # sampling + for j in tqdm(range(k)): + score = f / g + p = np.argmax(score) + + # print residuals + np.sum(f) + + delta_term1 = ATA[:, p].toarray().squeeze() + # print(delta_term1) + delta_term2 = ( + np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() + ) + delta = delta_term1 - delta_term2 + + # some weird rounding errors + delta[p] = np.max([0, delta[p]]) + + o = delta / np.max([np.sqrt(delta[p]), 1e-6]) + omega_square_norm = np.linalg.norm(o) ** 2 + omega_hadamard = np.multiply(o, o) + term1 = omega_square_norm * omega_hadamard + + # update f (term2) + pl = np.zeros(n) + for r in range(j): + omega_r = omega[r, :] + pl += np.dot(omega_r, o) * omega_r + + ATAo = (ATA @ o.reshape(-1, 1)).ravel() + term2 = np.multiply(o, ATAo - pl) + + # update f + f += -2.0 * term2 + term1 + + # update g + g += omega_hadamard + + # store omega and delta + d[j, :] = delta + omega[j, :] = o + + # add index + centers[j] = int(p) + + return centers + + def _updateA(self, B, A_prev): + """Update step for assigment matrix A. + + Given archetype matrix B and using kernel matrix K, compute assignment matrix A using constrained gradient + descent via Frank-Wolfe algorithm. + + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + :param A_prev: (n x k csr_matrix) defining previous weights used for assigning cells to SEACells + :return: (n x k csr_matrix) defining updated weights used for assigning cells to SEACells + + Implementation note: this routine is mathematically equivalent to the + prior version (A <- (1-alpha) A + alpha E with alpha = 2/(t+2) and E a + column-wise one-hot indicator at argmin G). It has been rewritten to + (a) operate on a dense A in-place rather than constructing a sparse + one-hot e per FW step, since A becomes dense within a few iterations + anyway, and (b) maintain t1 @ A incrementally so each FW step costs + O(k * n) instead of recomputing an O(k^2 * n) matmul. + """ + n, k = B.shape + + # Densify A for the inner loop. The CSR pattern from a one-hot rank-1 + # update fills in within a few steps, so the sparse format provides no + # benefit here while preventing in-place updates. + A = A_prev.toarray() if hasattr(A_prev, "toarray") else np.asarray(A_prev, dtype=float) + A = np.ascontiguousarray(A, dtype=float) + + # Precompute gradient terms; t1 (k x k) and t2 (k x n) materialized as + # dense once so every FW step is BLAS dense-matmul rather than a chain + # of sparse / dense conversions. + B_dense = B.toarray() if hasattr(B, "toarray") else np.asarray(B) + KB = self.K @ B + if hasattr(KB, "toarray"): + KB = KB.toarray() + KB = np.asarray(KB) + t2 = KB.T # (k, n) + t1 = t2 @ B_dense # (k, k) + + # Track t1 @ A incrementally. After A_new = (1-alpha) A + alpha E, + # t1 @ A_new = (1-alpha)(t1 @ A) + alpha t1[:, amins]. + t1A = t1 @ A + + arange_n = np.arange(n) + for t in range(self.max_FW_iter): + G = 2.0 * (t1A - t2) + amins = G.argmin(axis=0) + alpha = 2.0 / (t + 2.0) + + # In-place rank-1 FW update on A. + A *= 1.0 - alpha + A[amins, arange_n] += alpha + + # Same rank-1 update applied to t1A. t1[:, amins] selects the + # k columns of t1 indexed by amins (one per cell). + t1A *= 1.0 - alpha + t1A += alpha * t1[:, amins] + + return csr_matrix(A) + + def _updateB(self, A, B_prev): + """Update step for archetype matrix B. + + Given assignment matrix A and using kernel matrix K, compute archetype matrix B using constrained gradient + descent via Frank-Wolfe algorithm. + + :param A: (n x k csr_matrix) defining weights used for assigning cells to SEACells + :param B_prev: (n x k csr_matrix) defining previous SEACells as weighted combinations of cells + :return: (n x k csr_matrix) defining updated SEACells as weighted combinations of cells + + Same equivalence note as _updateA: in-place dense rank-1 updates and + incremental tracking of K @ B replace the sparse one-hot construction + and the per-step recomputation of K @ B. + """ + K = self.K + k, n = A.shape + + B = B_prev.toarray() if hasattr(B_prev, "toarray") else np.asarray(B_prev, dtype=float) + B = np.ascontiguousarray(B, dtype=float) + + A_dense = A.toarray() if hasattr(A, "toarray") else np.asarray(A) + + AAT = A_dense @ A_dense.T # (k, k) dense + + KAT = K @ A_dense.T # (n, k) + if hasattr(KAT, "toarray"): + KAT = KAT.toarray() + KAT = np.asarray(KAT) + + # KB = K @ B, tracked incrementally to avoid recomputing the n x n + # sparse @ dense product on every FW step. + KB = K @ B + if hasattr(KB, "toarray"): + KB = KB.toarray() + KB = np.asarray(KB) + + # CSC view for fast column slicing K[:, amins] inside the loop. + K_csc = K.tocsc() if hasattr(K, "tocsc") else K + + arange_k = np.arange(k) + for t in range(self.max_FW_iter): + G = 2.0 * (KB @ AAT - KAT) + amins = G.argmin(axis=0) + alpha = 2.0 / (t + 2.0) + + # In-place rank-1 FW update on B. + B *= 1.0 - alpha + B[amins, arange_k] += alpha + + # KB_new = K @ B_new = (1-alpha) KB + alpha K[:, amins]. + K_amins = K_csc[:, amins] + if hasattr(K_amins, "toarray"): + K_amins = K_amins.toarray() + K_amins = np.asarray(K_amins) + KB *= 1.0 - alpha + KB += alpha * K_amins + + return csr_matrix(B) + + def compute_reconstruction(self, A=None, B=None): + """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. + + :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: (n x n csr_matrix) defining reconstructed data matrix. + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + if A is None or B is None: + raise RuntimeError( + "Either assignment matrix A or archetype matrix B is None." + ) + return (self.kernel_matrix.dot(B)).dot(A) + + def compute_RSS(self, A=None, B=None): + """Compute residual sum of squares error in difference between reconstruction and true data matrix. + + :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: + ||X-XBA||^2 - (float) square difference between true data and reconstruction. + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + reconstruction = self.compute_reconstruction(A, B) + return norm(self.kernel_matrix - reconstruction) + + def plot_convergence(self, save_as=None, show=True): + """Plot behaviour of squared error over iterations. + + :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. + :param show: (bool) whether to show plot + :return: None. + """ + import matplotlib.pyplot as plt + + plt.figure() + plt.plot(self.RSS_iters) + plt.title("Reconstruction Error over Iterations") + plt.xlabel("Iterations") + plt.ylabel("Squared Error") + if save_as is not None: + plt.savefig(save_as, dpi=150) + if show: + plt.show() + plt.close() + + def step(self): + """Perform one iteration of SEACell algorithm. Update assignment matrix A and archetype matrix B. + + :return: None. + """ + A = self.A_ + B = self.B_ + + if self.K is None: + raise RuntimeError( + "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." + ) + + if A is None: + raise RuntimeError( + "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." + ) + + if B is None: + raise RuntimeError( + "Archetype matrix has not been initialised. Run model.initialize() first." + ) + + A = self._updateA(B, A) + B = self._updateB(A, B) + + self.RSS_iters.append(self.compute_RSS(A, B)) + + self.A_ = A + self.B_ = B + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + return + + def _fit( + self, + max_iter: int = 50, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Internal method to compute archetypes and loadings given kernel matrix K. + + Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. + + Modifies ad.obs in place to add 'SEACell' labels to cells. + :param max_iter: (int) maximum number of iterations to perform + :param min_iter: (int) minimum number of iterations to perform + :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. + :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. + :return: None + """ + self.initialize( + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + converged = False + n_iter = 0 + while (not converged and n_iter < max_iter) or n_iter < min_iter: + n_iter += 1 + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Starting iteration {n_iter}.") + + self.step() + + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Completed iteration {n_iter}.") + + # Check for convergence + if ( + np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) + < self.convergence_threshold + ): + if self.verbose: + print(f"Converged after {n_iter} iterations.") + converged = True + + self.Z_ = self.B_.T @ self.K + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + if not converged: + import warnings + + warnings.warn( + "Algorithm has not converged - you may need to increase the maximum number of iterations", + RuntimeWarning, + stacklevel=2, + ) + return + + def fit( + self, + max_iter: int = 100, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Compute archetypes and loadings given kernel matrix K. + + Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. + :param max_iter: (int) maximum number of iterations to perform (default 100) + :param min_iter: (int) minimum number of iterations to perform (default 10) + :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. + :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. + :return: None. + """ + if max_iter < min_iter: + raise ValueError( + "The maximum number of iterations specified is lower than the minimum number of iterations specified." + ) + self._fit( + max_iter=max_iter, + min_iter=min_iter, + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + def get_archetype_matrix(self): + """Return k x n matrix of archetypes computed as the product of the archetype matrix B and the kernel matrix K.""" + return self.Z_ + + def get_soft_assignments(self): + """Return soft SEACell assignment. + + Returns a tuple of (labels, weights) where labels is a dataframe with SEACell assignments for the top 5 + SEACell assignments for each cell and weights is an array with the corresponding weights for each assignment. + :return: (pd.DataFrame, np.array) with labels and weights. + """ + archetype_labels = self.get_hard_archetypes() + A = copy.deepcopy(self.A_.T) + + labels = [] + weights = [] + for _i in range(5): + l = A.argmax(1) + labels.append(archetype_labels[l]) + weights.append(A[np.arange(A.shape[0]), l]) + A[np.arange(A.shape[0]), l] = -1 + + weights = np.vstack(weights).T + labels = np.vstack(labels).T + + soft_labels = pd.DataFrame(labels) + soft_labels.index = self.ad.obs_names + + return soft_labels, weights + + def get_hard_assignments(self): + """Returns a dataframe with the SEACell assignment for each cell. + + The assignment is the SEACell with the highest assignment weight. + :return: (pd.DataFrame) with SEACell assignments. + """ + # Use argmax to get the index with the highest assignment weight + assmts = np.array(self.A_.argmax(0)).reshape(-1) + + df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in assmts]}) + df.index = self.ad.obs_names + df.index.name = "index" + return df + + def get_hard_archetypes(self): + """Return the names of cells most strongly identified as archetypes. + + :return list of archetype names. + """ + return self.ad.obs_names[self.B_.argmax(0)] + + def save_model(self, outdir): + """Save the model to a pickle file. + + :param outdir: (str) path to directory to save to + :return: None. + """ + import pickle + + with open(outdir + "/model.pkl", "wb") as f: + pickle.dump(self, f) + return None + + def save_assignments(self, outdir): + """Save SEACells assignment. + + Saves: + (1) the cell to SEACell assignments to a csv file with the name 'SEACells.csv'. + (2) the kernel matrix to a .npz file with the name 'kernel_matrix.npz'. + (3) the archetype matrix to a .npz file with the name 'A.npz'. + (4) the loading matrix to a .npz file with the name 'B.npz'. + + :param outdir: (str) path to directory to save to + :return: None + """ + import os + + os.makedirs(outdir, exist_ok=True) + save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) + save_npz(outdir + "/A.npz", self.A_.T) + save_npz(outdir + "/B.npz", self.B_) + + labels = self.get_hard_assignments() + labels.to_csv(outdir + "/SEACells.csv") + return None diff --git a/SEACells/cpu_dense.py b/SEACells/cpu_dense.py index 16e67f1..87adf6e 100644 --- a/SEACells/cpu_dense.py +++ b/SEACells/cpu_dense.py @@ -1,644 +1,648 @@ -import copy - -import numpy as np -import palantir -import pandas as pd -from scipy.sparse import csr_matrix, save_npz -from tqdm import tqdm - -from . import build_graph - - -class SEACellsCPUDense: - """Fast kernel archetypal analysis. - - Finds archetypes and weights given annotated data matrix. - Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell']. - """ - - def __init__( - self, - ad, - build_kernel_on: str, - n_SEACells: int, - verbose: bool = True, - n_waypoint_eigs: int = 10, - n_neighbors: int = 15, - convergence_epsilon: float = 1e-3, - l2_penalty: float = 0, - max_franke_wolfe_iters: int = 50, - ): - """Fast kernel archetypal analysis. - - ad - anndata.AnnData. - """ - print("Welcome to SEACells!") - self.ad = ad - self.build_kernel_on = build_kernel_on - self.n_cells = ad.shape[0] - - if not isinstance(n_SEACells, int): - try: - n_SEACells = int(n_SEACells) - except ValueError: - raise ValueError( - f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" - ) - - self.k = n_SEACells - - self.n_waypoint_eigs = n_waypoint_eigs - self.waypoint_proportion = 1 - self.n_neighbors = n_neighbors - - self.max_FW_iter = max_franke_wolfe_iters - self.verbose = verbose - self.l2_penalty = l2_penalty - - self.RSS_iters = [] - self.convergence_epsilon = convergence_epsilon - self.convergence_threshold = None - - # Parameters to be initialized later in the model - self.kernel_matrix = None - self.K = None - - # Archetypes as list of cell indices - self.archetypes = None - - self.A_ = None - self.B_ = None - self.B0 = None - - return - - def add_precomputed_kernel_matrix(self, K): - """Compute kernel matrix.""" - assert K.shape == (self.n_cells, self.n_cells), ( - f"Dimension of kernel matrix must be n_cells = " - f"({self.n_cells},{self.n_cells}), not {K.shape} " - ) - self.kernel_matrix = K - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - - def construct_kernel_matrix( - self, n_neighbors: int = None, graph_construction="union" - ): - """Construct kernel matrix.""" - # input to graph construction is PCA/SVD - kernel_model = build_graph.SEACellGraph( - self.ad, self.build_kernel_on, verbose=self.verbose - ) - - # K is a sparse matrix representing input to SEACell alg - if n_neighbors is None: - n_neighbors = self.n_neighbors - - M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) - self.kernel_matrix = M - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - - return - - def initialize_archetypes(self): - """Initialize B matrix which defines cells as SEACells. - - Selects waypoint_proportion from waypoint analysis, and the remainder by greedy selection. - - Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes - """ - k = self.k - - if self.waypoint_proportion > 0: - waypoint_ix = self._get_waypoint_centers(k) - waypoint_ix = np.random.choice( - waypoint_ix, - int(len(waypoint_ix) * self.waypoint_proportion), - replace=False, - ) - from_greedy = self.k - len(waypoint_ix) - if self.verbose: - print( - f"Selecting {len(waypoint_ix)} cells from waypoint initialization." - ) - - else: - from_greedy = self.k - - greedy_ix = self._get_greedy_centers(n_mcs=from_greedy + 10) - if self.verbose: - print(f"Selecting {from_greedy} cells from greedy initialization.") - - if self.waypoint_proportion > 0: - all_ix = np.hstack([waypoint_ix, greedy_ix]) - else: - all_ix = np.hstack([greedy_ix]) - - unique_ix, ind = np.unique(all_ix, return_index=True) - all_ix = unique_ix[np.argsort(ind)][:k] - self.archetypes = all_ix - - def initialize(self, initial_archetypes=None, initial_assignments=None): - """Initalize SEACells assignment.""" - if self.K is None: - raise RuntimeError( - "Must first construct kernel matrix before initializing SEACells." - ) - K = self.K - # initialize B (update this to allow initialization from RRQR) - n = K.shape[0] - - if initial_archetypes is not None: - if self.verbose: - print("Using provided list of initial archetypes") - self.archetypes = initial_archetypes - - if self.archetypes is None: - self.initialize_archetypes() - self.k = len(self.archetypes) - k = self.k - - # Construction of B matrix - B0 = np.zeros((n, k)) - all_ix = self.archetypes - idx1 = list(zip(all_ix, np.arange(k))) - B0[tuple(zip(*idx1))] = 1 - self.B0 = B0 - B = self.B0.copy() - - if initial_assignments is not None: - A0 = initial_assignments - assert A0.shape == ( - k, - n, - ), f"Initial assignment matrix should be of shape (k={k} x n={n})" - - else: - A0 = np.random.random((k, n)) - A0 /= A0.sum(0) - if self.verbose: - print("Randomly initialized A matrix.") - - self.A0 = A0 - A = self.A0.copy() - A = self._updateA(B, A) - - self.A_ = A - self.B_ = B - - # Create convergence threshold - RSS = self.compute_RSS(A, B) - self.RSS_iters.append(RSS) - - if self.convergence_threshold is None: - self.convergence_threshold = self.convergence_epsilon * RSS - if self.verbose: - print( - f"Setting convergence threshold at {self.convergence_threshold:.5f}" - ) - - def _get_waypoint_centers(self, n_waypoints=None): - """Initialize B matrix using waypoint analysis, as described in Palantir. - - From https://www.nature.com/articles/s41587-019-0068-4. - - :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, - all SEACells initialized using this method. - :return: B - (array) n_datapoints x n_SEACells matrix with initial SEACell definitions - """ - if n_waypoints is None: - k = self.k - else: - k = n_waypoints - - ad = self.ad - - if self.build_kernel_on == "X_pca": - pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) - elif self.build_kernel_on == "X_svd": - # Compute PCA components from ad object - pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) - else: - pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( - ad.obs_names - ) - - print(f"Building kernel on {self.build_kernel_on}") - - if self.verbose: - print( - f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " - ) - - dm_res = palantir.utils.run_diffusion_maps( - pca_components, n_components=self.n_neighbors - ) - dc_components = palantir.utils.determine_multiscale_space( - dm_res, n_eigs=self.n_waypoint_eigs - ) - if self.verbose: - print("Done.") - - # Initialize SEACells via waypoint sampling - if self.verbose: - print("Sampling waypoints ...") - waypoint_init = palantir.core._max_min_sampling( - data=dc_components, num_waypoints=k - ) - dc_components["iix"] = np.arange(len(dc_components)) - waypoint_ix = dc_components.loc[waypoint_init]["iix"].values - if self.verbose: - print("Done.") - - return waypoint_ix - - def _get_greedy_centers(self, n_mcs=None): - """Initialize SEACells using fast greedy adaptive CSSP. - - From https://arxiv.org/pdf/1312.6838.pdf - :param n_mcs: (int) number of SEACells to initialize using greedy selection. If None specified, - all SEACells initialized using this method. - :return: B - (array) n_datapoints x n_SEACells matrix with initial SEACell definitions - """ - K = self.K - n = K.shape[0] - - if n_mcs is None: - k = self.k - else: - k = n_mcs - - if self.verbose: - print("Initializing residual matrix using greedy column selection") - - # precompute M.T * M - # ATA = M.T @ M - ATA = K - - if self.verbose: - print("Initializing f and g...") - - f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() - # f = np.array((ATA * ATA).sum(axis=0)).ravel() - g = np.array(ATA.diagonal()).ravel() - - d = np.zeros((k, n)) - omega = np.zeros((k, n)) - - # keep track of selected indices - centers = np.zeros(k, dtype=int) - - # sampling - for j in tqdm(range(k)): - score = f / g - p = np.argmax(score) - - # print residuals - np.sum(f) - - delta_term1 = ATA[:, p].toarray().squeeze() - # print(delta_term1) - delta_term2 = ( - np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() - ) - delta = delta_term1 - delta_term2 - - # some weird rounding errors - delta[p] = np.max([0, delta[p]]) - - o = delta / np.max([np.sqrt(delta[p]), 1e-6]) - omega_square_norm = np.linalg.norm(o) ** 2 - omega_hadamard = np.multiply(o, o) - term1 = omega_square_norm * omega_hadamard - - # update f (term2) - pl = np.zeros(n) - for r in range(j): - omega_r = omega[r, :] - pl += np.dot(omega_r, o) * omega_r - - ATAo = (ATA @ o.reshape(-1, 1)).ravel() - term2 = np.multiply(o, ATAo - pl) - - # update f - f += -2.0 * term2 + term1 - - # update g - g += omega_hadamard - - # store omega and delta - d[j, :] = delta - omega[j, :] = o - - # add index - centers[j] = int(p) - - return centers - - def _updateA(self, B, A_prev): - """Given archetype matrix B and using kernel matrix K, compute assignment matrix A using gradient descent. - - :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells - :return: A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells - """ - n, k = B.shape - A = A_prev - - t = 0 # current iteration (determine multiplicative update) - - # precompute some gradient terms - t2 = (self.K @ B).T - t1 = t2 @ B - - # update rows of A for given number of iterations - while t < self.max_FW_iter: - # compute gradient (must convert matrix to ndarray) - G = 2.0 * np.array(t1 @ A - t2) - self.l2_penalty * A - - # get argmins - amins = np.argmin(G, axis=0) - - # loop free implementation - e = np.zeros((k, n)) - e[amins, np.arange(n)] = 1.0 - - A += 2.0 / (t + 2.0) * (e - A) - t += 1 - - return A - - def _updateB(self, A, B_prev): - """Given assignment matrix A and using kernel matrix K, compute archetype matrix B. - - :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells - :return: B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells - """ - K = self.K - k, n = A.shape - - B = B_prev - - # keep track of error - t = 0 - - # precompute some terms - t1 = A @ A.T - t2 = K @ A.T - - # update rows of B for a given number of iterations - while t < self.max_FW_iter: - # compute gradient (need to convert np.matrix to np.array) - G = 2.0 * np.array(K @ B @ t1 - t2) - - # get all argmins - amins = np.argmin(G, axis=0) - - e = np.zeros((n, k)) - e[amins, np.arange(k)] = 1.0 - - B += 2.0 / (t + 2.0) * (e - B) - - t += 1 - - return B - - def compute_reconstruction(self, A=None, B=None): - """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. - - :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: array (n data points x data dimension) representing reconstruction of original data matrix - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - if A is None or B is None: - raise RuntimeError( - "Either assignment matrix A or archetype matrix B is None." - ) - return (self.kernel_matrix.dot(B)).dot(A) - - def compute_RSS(self, A=None, B=None): - """Compute residual sum of squares error in difference between reconstruction and true data matrix. - - :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells - :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: - ||X-XBA||^2 - (float) square difference between true data and reconstruction. - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - reconstruction = self.compute_reconstruction(A, B) - return np.linalg.norm(self.kernel_matrix - reconstruction) - - def plot_convergence(self, save_as=None, show=True): - """Plot behaviour of squared error over iterations. - - :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. - """ - import matplotlib.pyplot as plt - - plt.figure() - plt.plot(self.RSS_iters) - plt.title("Reconstruction Error over Iterations") - plt.xlabel("Iterations") - plt.ylabel("Squared Error") - if save_as is not None: - plt.savefig(save_as, dpi=150) - if show: - plt.show() - plt.close() - - def step(self): - """Perform one iteration of fitting to update A and B assignment matrices.""" - A = self.A_ - B = self.B_ - - if self.K is None: - raise RuntimeError( - "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." - ) - - if A is None: - raise RuntimeError( - "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." - ) - - if B is None: - raise RuntimeError( - "Archetype matrix has not been initialised. Run model.initialize() first." - ) - - A = self._updateA(B, A) - B = self._updateB(A, B) - - self.RSS_iters.append(self.compute_RSS(A, B)) - - self.A_ = A - self.B_ = B - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - return - - def _fit( - self, - max_iter: int = 50, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Compute archetypes and loadings given kernel matrix K. - - Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. - - Modifies ad.obs in place to add 'SEACell' labels to cells. - - :param max_iter: (int) maximum number of iterations to update A and B matrices - :param min_iter: (int) minimum number of iterations to update A and B matrices - :param initial_archetypes: (list) indices of cells to use as initial archetypes - - """ - self.initialize( - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - converged = False - n_iter = 0 - while (not converged and n_iter < max_iter) or n_iter < min_iter: - n_iter += 1 - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Starting iteration {n_iter}.") - - self.step() - - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Completed iteration {n_iter}.") - - # Check for convergence - if ( - np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) - < self.convergence_threshold - ): - if self.verbose: - print(f"Converged after {n_iter} iterations.") - converged = True - - self.Z_ = self.B_.T @ self.K - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - if not converged: - raise RuntimeWarning( - "Warning: Algorithm has not converged - you may need to increase the maximum number of iterations" - ) - return - - def fit( - self, - max_iter: int = 100, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Wrapper to fit model. - - :param max_iter: (int) maximum number of iterations to update A and B matrices. Default: 100 - :param min_iter: (int) maximum number of iterations to update A and B matrices. Default: 10 - :param initial_archetypes: (list) indices of cells to use as initial archetypes - """ - if max_iter < min_iter: - raise ValueError( - "The maximum number of iterations specified is lower than the minimum number of iterations specified." - ) - self._fit( - max_iter=max_iter, - min_iter=min_iter, - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - def get_archetype_matrix(self): - """Return k x n matrix of archetypes.""" - return self.Z_ - - def get_soft_assignments(self): - """Compute soft SEACells assignment.""" - archetype_labels = self.get_hard_archetypes() - A = copy.deepcopy(self.A_.T) - - labels = [] - weights = [] - for _i in range(5): - l = A.argmax(1) - labels.append(archetype_labels[l]) - weights.append(A[np.arange(A.shape[0]), l]) - A[np.arange(A.shape[0]), l] = -1 - - weights = np.vstack(weights).T - labels = np.vstack(labels).T - - soft_labels = pd.DataFrame(labels) - soft_labels.index = self.ad.obs_names - - return soft_labels, weights - - def get_hard_assignments(self): - """Returns a dataframe with SEACell assignments under the column 'SEACell'. - - :return: pd.DataFrame with column 'SEACell'. - """ - # Use argmax to get the index with the highest assignment weight - - df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in self.A_.argmax(0)]}) - df.index = self.ad.obs_names - df.index.name = "index" - - return df - - def get_archetypes(self): - """TODO.""" - raise NotImplementedError - - def get_hard_archetypes(self): - """Return the names of cells most strongly identified as archetypes.""" - return self.ad.obs_names[self.B_.argmax(0)] - - def save_assignments(self, outdir): - """Save (sparse) assignment matrices to specified directory.""" - import os - - os.makedirs(outdir, exist_ok=True) - - A = csr_matrix(self.A_) - B = csr_matrix(self.B_) - - A = A.T - - save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) - save_npz(outdir + "/A.npz", A) - save_npz(outdir + "/B.npz", B) - - labels = self.get_hard_assignments() - labels.to_csv(outdir + "/SEACells.csv") +import copy + +import numpy as np +import palantir +import pandas as pd +from scipy.sparse import csr_matrix, save_npz +from tqdm import tqdm + +from . import build_graph + + +class SEACellsCPUDense: + """Fast kernel archetypal analysis. + + Finds archetypes and weights given annotated data matrix. + Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell']. + """ + + def __init__( + self, + ad, + build_kernel_on: str, + n_SEACells: int, + verbose: bool = True, + n_waypoint_eigs: int = 10, + n_neighbors: int = 15, + convergence_epsilon: float = 1e-3, + l2_penalty: float = 0, + max_franke_wolfe_iters: int = 50, + ): + """Fast kernel archetypal analysis. + + ad - anndata.AnnData. + """ + print("Welcome to SEACells!") + self.ad = ad + self.build_kernel_on = build_kernel_on + self.n_cells = ad.shape[0] + + if not isinstance(n_SEACells, int): + try: + n_SEACells = int(n_SEACells) + except ValueError: + raise ValueError( + f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" + ) + + self.k = n_SEACells + + self.n_waypoint_eigs = n_waypoint_eigs + self.waypoint_proportion = 1 + self.n_neighbors = n_neighbors + + self.max_FW_iter = max_franke_wolfe_iters + self.verbose = verbose + self.l2_penalty = l2_penalty + + self.RSS_iters = [] + self.convergence_epsilon = convergence_epsilon + self.convergence_threshold = None + + # Parameters to be initialized later in the model + self.kernel_matrix = None + self.K = None + + # Archetypes as list of cell indices + self.archetypes = None + + self.A_ = None + self.B_ = None + self.B0 = None + + return + + def add_precomputed_kernel_matrix(self, K): + """Compute kernel matrix.""" + assert K.shape == (self.n_cells, self.n_cells), ( + f"Dimension of kernel matrix must be n_cells = " + f"({self.n_cells},{self.n_cells}), not {K.shape} " + ) + self.kernel_matrix = K + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + + def construct_kernel_matrix( + self, n_neighbors: int = None, graph_construction="union" + ): + """Construct kernel matrix.""" + # input to graph construction is PCA/SVD + kernel_model = build_graph.SEACellGraph( + self.ad, self.build_kernel_on, verbose=self.verbose + ) + + # K is a sparse matrix representing input to SEACell alg + if n_neighbors is None: + n_neighbors = self.n_neighbors + + M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) + self.kernel_matrix = M + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + + return + + def initialize_archetypes(self): + """Initialize B matrix which defines cells as SEACells. + + Selects waypoint_proportion from waypoint analysis, and the remainder by greedy selection. + + Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes + """ + k = self.k + + if self.waypoint_proportion > 0: + waypoint_ix = self._get_waypoint_centers(k) + waypoint_ix = np.random.choice( + waypoint_ix, + int(len(waypoint_ix) * self.waypoint_proportion), + replace=False, + ) + from_greedy = self.k - len(waypoint_ix) + if self.verbose: + print( + f"Selecting {len(waypoint_ix)} cells from waypoint initialization." + ) + + else: + from_greedy = self.k + + greedy_ix = self._get_greedy_centers(n_mcs=from_greedy + 10) + if self.verbose: + print(f"Selecting {from_greedy} cells from greedy initialization.") + + if self.waypoint_proportion > 0: + all_ix = np.hstack([waypoint_ix, greedy_ix]) + else: + all_ix = np.hstack([greedy_ix]) + + unique_ix, ind = np.unique(all_ix, return_index=True) + all_ix = unique_ix[np.argsort(ind)][:k] + self.archetypes = all_ix + + def initialize(self, initial_archetypes=None, initial_assignments=None): + """Initalize SEACells assignment.""" + if self.K is None: + raise RuntimeError( + "Must first construct kernel matrix before initializing SEACells." + ) + K = self.K + # initialize B (update this to allow initialization from RRQR) + n = K.shape[0] + + if initial_archetypes is not None: + if self.verbose: + print("Using provided list of initial archetypes") + self.archetypes = initial_archetypes + + if self.archetypes is None: + self.initialize_archetypes() + self.k = len(self.archetypes) + k = self.k + + # Construction of B matrix + B0 = np.zeros((n, k)) + all_ix = self.archetypes + idx1 = list(zip(all_ix, np.arange(k))) + B0[tuple(zip(*idx1))] = 1 + self.B0 = B0 + B = self.B0.copy() + + if initial_assignments is not None: + A0 = initial_assignments + assert A0.shape == ( + k, + n, + ), f"Initial assignment matrix should be of shape (k={k} x n={n})" + + else: + A0 = np.random.random((k, n)) + A0 /= A0.sum(0) + if self.verbose: + print("Randomly initialized A matrix.") + + self.A0 = A0 + A = self.A0.copy() + A = self._updateA(B, A) + + self.A_ = A + self.B_ = B + + # Create convergence threshold + RSS = self.compute_RSS(A, B) + self.RSS_iters.append(RSS) + + if self.convergence_threshold is None: + self.convergence_threshold = self.convergence_epsilon * RSS + if self.verbose: + print( + f"Setting convergence threshold at {self.convergence_threshold:.5f}" + ) + + def _get_waypoint_centers(self, n_waypoints=None): + """Initialize B matrix using waypoint analysis, as described in Palantir. + + From https://www.nature.com/articles/s41587-019-0068-4. + + :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, + all SEACells initialized using this method. + :return: B - (array) n_datapoints x n_SEACells matrix with initial SEACell definitions + """ + if n_waypoints is None: + k = self.k + else: + k = n_waypoints + + ad = self.ad + + if self.build_kernel_on == "X_pca": + pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) + elif self.build_kernel_on == "X_svd": + # Compute PCA components from ad object + pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) + else: + pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( + ad.obs_names + ) + + print(f"Building kernel on {self.build_kernel_on}") + + if self.verbose: + print( + f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " + ) + + dm_res = palantir.utils.run_diffusion_maps( + pca_components, n_components=self.n_neighbors + ) + dc_components = palantir.utils.determine_multiscale_space( + dm_res, n_eigs=self.n_waypoint_eigs + ) + if self.verbose: + print("Done.") + + # Initialize SEACells via waypoint sampling + if self.verbose: + print("Sampling waypoints ...") + waypoint_init = palantir.core._max_min_sampling( + data=dc_components, num_waypoints=k + ) + dc_components["iix"] = np.arange(len(dc_components)) + waypoint_ix = dc_components.loc[waypoint_init]["iix"].values + if self.verbose: + print("Done.") + + return waypoint_ix + + def _get_greedy_centers(self, n_mcs=None): + """Initialize SEACells using fast greedy adaptive CSSP. + + From https://arxiv.org/pdf/1312.6838.pdf + :param n_mcs: (int) number of SEACells to initialize using greedy selection. If None specified, + all SEACells initialized using this method. + :return: B - (array) n_datapoints x n_SEACells matrix with initial SEACell definitions + """ + K = self.K + n = K.shape[0] + + if n_mcs is None: + k = self.k + else: + k = n_mcs + + if self.verbose: + print("Initializing residual matrix using greedy column selection") + + # precompute M.T * M + # ATA = M.T @ M + ATA = K + + if self.verbose: + print("Initializing f and g...") + + f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() + # f = np.array((ATA * ATA).sum(axis=0)).ravel() + g = np.array(ATA.diagonal()).ravel() + + d = np.zeros((k, n)) + omega = np.zeros((k, n)) + + # keep track of selected indices + centers = np.zeros(k, dtype=int) + + # sampling + for j in tqdm(range(k)): + score = f / g + p = np.argmax(score) + + # print residuals + np.sum(f) + + delta_term1 = ATA[:, p].toarray().squeeze() + # print(delta_term1) + delta_term2 = ( + np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() + ) + delta = delta_term1 - delta_term2 + + # some weird rounding errors + delta[p] = np.max([0, delta[p]]) + + o = delta / np.max([np.sqrt(delta[p]), 1e-6]) + omega_square_norm = np.linalg.norm(o) ** 2 + omega_hadamard = np.multiply(o, o) + term1 = omega_square_norm * omega_hadamard + + # update f (term2) + pl = np.zeros(n) + for r in range(j): + omega_r = omega[r, :] + pl += np.dot(omega_r, o) * omega_r + + ATAo = (ATA @ o.reshape(-1, 1)).ravel() + term2 = np.multiply(o, ATAo - pl) + + # update f + f += -2.0 * term2 + term1 + + # update g + g += omega_hadamard + + # store omega and delta + d[j, :] = delta + omega[j, :] = o + + # add index + centers[j] = int(p) + + return centers + + def _updateA(self, B, A_prev): + """Given archetype matrix B and using kernel matrix K, compute assignment matrix A using gradient descent. + + :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells + :return: A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells + """ + n, k = B.shape + A = A_prev + + t = 0 # current iteration (determine multiplicative update) + + # precompute some gradient terms + t2 = (self.K @ B).T + t1 = t2 @ B + + # update rows of A for given number of iterations + while t < self.max_FW_iter: + # compute gradient (must convert matrix to ndarray) + G = 2.0 * np.array(t1 @ A - t2) - self.l2_penalty * A + + # get argmins + amins = np.argmin(G, axis=0) + + # loop free implementation + e = np.zeros((k, n)) + e[amins, np.arange(n)] = 1.0 + + A += 2.0 / (t + 2.0) * (e - A) + t += 1 + + return A + + def _updateB(self, A, B_prev): + """Given assignment matrix A and using kernel matrix K, compute archetype matrix B. + + :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells + :return: B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells + """ + K = self.K + k, n = A.shape + + B = B_prev + + # keep track of error + t = 0 + + # precompute some terms + t1 = A @ A.T + t2 = K @ A.T + + # update rows of B for a given number of iterations + while t < self.max_FW_iter: + # compute gradient (need to convert np.matrix to np.array) + G = 2.0 * np.array(K @ B @ t1 - t2) + + # get all argmins + amins = np.argmin(G, axis=0) + + e = np.zeros((n, k)) + e[amins, np.arange(k)] = 1.0 + + B += 2.0 / (t + 2.0) * (e - B) + + t += 1 + + return B + + def compute_reconstruction(self, A=None, B=None): + """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. + + :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: array (n data points x data dimension) representing reconstruction of original data matrix + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + if A is None or B is None: + raise RuntimeError( + "Either assignment matrix A or archetype matrix B is None." + ) + return (self.kernel_matrix.dot(B)).dot(A) + + def compute_RSS(self, A=None, B=None): + """Compute residual sum of squares error in difference between reconstruction and true data matrix. + + :param A: (array) k*n matrix (dense) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells + :param B: (array) n*k matrix (dense) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: + ||X-XBA||^2 - (float) square difference between true data and reconstruction. + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + reconstruction = self.compute_reconstruction(A, B) + return np.linalg.norm(self.kernel_matrix - reconstruction) + + def plot_convergence(self, save_as=None, show=True): + """Plot behaviour of squared error over iterations. + + :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. + """ + import matplotlib.pyplot as plt + + plt.figure() + plt.plot(self.RSS_iters) + plt.title("Reconstruction Error over Iterations") + plt.xlabel("Iterations") + plt.ylabel("Squared Error") + if save_as is not None: + plt.savefig(save_as, dpi=150) + if show: + plt.show() + plt.close() + + def step(self): + """Perform one iteration of fitting to update A and B assignment matrices.""" + A = self.A_ + B = self.B_ + + if self.K is None: + raise RuntimeError( + "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." + ) + + if A is None: + raise RuntimeError( + "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." + ) + + if B is None: + raise RuntimeError( + "Archetype matrix has not been initialised. Run model.initialize() first." + ) + + A = self._updateA(B, A) + B = self._updateB(A, B) + + self.RSS_iters.append(self.compute_RSS(A, B)) + + self.A_ = A + self.B_ = B + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + return + + def _fit( + self, + max_iter: int = 50, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Compute archetypes and loadings given kernel matrix K. + + Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. + + Modifies ad.obs in place to add 'SEACell' labels to cells. + + :param max_iter: (int) maximum number of iterations to update A and B matrices + :param min_iter: (int) minimum number of iterations to update A and B matrices + :param initial_archetypes: (list) indices of cells to use as initial archetypes + + """ + self.initialize( + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + converged = False + n_iter = 0 + while (not converged and n_iter < max_iter) or n_iter < min_iter: + n_iter += 1 + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Starting iteration {n_iter}.") + + self.step() + + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Completed iteration {n_iter}.") + + # Check for convergence + if ( + np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) + < self.convergence_threshold + ): + if self.verbose: + print(f"Converged after {n_iter} iterations.") + converged = True + + self.Z_ = self.B_.T @ self.K + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + if not converged: + import warnings + + warnings.warn( + "Algorithm has not converged - you may need to increase the maximum number of iterations", + RuntimeWarning, + stacklevel=2, + ) + return + + def fit( + self, + max_iter: int = 100, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Wrapper to fit model. + + :param max_iter: (int) maximum number of iterations to update A and B matrices. Default: 100 + :param min_iter: (int) maximum number of iterations to update A and B matrices. Default: 10 + :param initial_archetypes: (list) indices of cells to use as initial archetypes + """ + if max_iter < min_iter: + raise ValueError( + "The maximum number of iterations specified is lower than the minimum number of iterations specified." + ) + self._fit( + max_iter=max_iter, + min_iter=min_iter, + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + def get_archetype_matrix(self): + """Return k x n matrix of archetypes.""" + return self.Z_ + + def get_soft_assignments(self): + """Compute soft SEACells assignment.""" + archetype_labels = self.get_hard_archetypes() + A = copy.deepcopy(self.A_.T) + + labels = [] + weights = [] + for _i in range(5): + l = A.argmax(1) + labels.append(archetype_labels[l]) + weights.append(A[np.arange(A.shape[0]), l]) + A[np.arange(A.shape[0]), l] = -1 + + weights = np.vstack(weights).T + labels = np.vstack(labels).T + + soft_labels = pd.DataFrame(labels) + soft_labels.index = self.ad.obs_names + + return soft_labels, weights + + def get_hard_assignments(self): + """Returns a dataframe with SEACell assignments under the column 'SEACell'. + + :return: pd.DataFrame with column 'SEACell'. + """ + # Use argmax to get the index with the highest assignment weight + + df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in self.A_.argmax(0)]}) + df.index = self.ad.obs_names + df.index.name = "index" + + return df + + def get_archetypes(self): + """TODO.""" + raise NotImplementedError + + def get_hard_archetypes(self): + """Return the names of cells most strongly identified as archetypes.""" + return self.ad.obs_names[self.B_.argmax(0)] + + def save_assignments(self, outdir): + """Save (sparse) assignment matrices to specified directory.""" + import os + + os.makedirs(outdir, exist_ok=True) + + A = csr_matrix(self.A_) + B = csr_matrix(self.B_) + + A = A.T + + save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) + save_npz(outdir + "/A.npz", A) + save_npz(outdir + "/B.npz", B) + + labels = self.get_hard_assignments() + labels.to_csv(outdir + "/SEACells.csv") diff --git a/SEACells/evaluate.py b/SEACells/evaluate.py index 25978e3..5f1dea3 100644 --- a/SEACells/evaluate.py +++ b/SEACells/evaluate.py @@ -1,139 +1,139 @@ -import numpy as np -import palantir -import pandas as pd - - -def compactness(ad, low_dim_embedding="X_pca", SEACells_label="SEACell"): - """Compute compactness of each metacell. - - Compactness is defined is the average variance of diffusion components across cells that constitute a metcell. - - :param ad: (Anndata) Anndata object - :param low_dim_embedding: (str) `ad.obsm` field for constructing diffusion components - :param SEACell_label: (str) `ad.obs` field for computing diffusion component variances - - :return: `pd.DataFrame` with a dataframe of compactness per metacell - - """ - import palantir - - components = pd.DataFrame(ad.obsm[low_dim_embedding]).set_index(ad.obs_names) - dm_res = palantir.utils.run_diffusion_maps(components) - dc = palantir.utils.determine_multiscale_space(dm_res, n_eigs=10) - - return pd.DataFrame( - dc.join(ad.obs[SEACells_label]).groupby(SEACells_label).var().mean(1) - ).rename(columns={0: "compactness"}) - - -def separation( - ad, low_dim_embedding="X_pca", nth_nbr=1, cluster=None, SEACells_label="SEACell" -): - """Compute separation of each metacell. - - Separation is defined is the distance to the nearest neighboring metacell. - - :param ad: (Anndata) Anndata object - :param low_dim_embedding: (str) `ad.obsm` field for constructing diffusion components - :param nth_nbr: (int) Which neighbor to use for computing separation - :param SEACell_label: (str) `ad.obs` field for computing diffusion component variances - - :return: `pd.DataFrame` with a separation of compactness per metacell - - """ - components = pd.DataFrame(ad.obsm[low_dim_embedding]).set_index(ad.obs_names) - dm_res = palantir.utils.run_diffusion_maps(components) - dc = palantir.utils.determine_multiscale_space(dm_res, n_eigs=10) - - # Compute DC per metacell - metacells_dcs = ( - dc.join(ad.obs[SEACells_label], how="inner").groupby(SEACells_label).mean() - ) - - from sklearn.neighbors import NearestNeighbors - - neigh = NearestNeighbors(n_neighbors=nth_nbr) - nbrs = neigh.fit(metacells_dcs) - dists, nbrs = nbrs.kneighbors() - dists = pd.DataFrame(dists).set_index(metacells_dcs.index) - dists.columns += 1 - - nbr_cells = np.array(metacells_dcs.index)[nbrs] - - metacells_nbrs = pd.DataFrame(nbr_cells) - metacells_nbrs.index = metacells_dcs.index - metacells_nbrs.columns += 1 - - if cluster is not None: - # Get cluster type of neighbors to ensure they match the metacell cluster - clusters = ad.obs.groupby(SEACells_label)[cluster].agg( - lambda x: x.value_counts().index[0] - ) - nbr_clusters = pd.DataFrame(clusters.values[nbrs]).set_index(clusters.index) - nbr_clusters.columns = metacells_nbrs.columns - nbr_clusters = nbr_clusters.join(pd.DataFrame(clusters)) - - clusters_match = nbr_clusters.eq(nbr_clusters[cluster], axis=0) - return pd.DataFrame(dists[nth_nbr][clusters_match[nth_nbr]]).rename( - columns={1: "separation"} - ) - else: - return pd.DataFrame(dists[nth_nbr]).rename(columns={1: "separation"}) - - -def get_density(ad, key, nth_neighbor=150): - """Compute cell density as 1/ the distance to the 150th (by default) nearest neighbour. - - :param ad: AnnData object - :param key: (str) key in ad.obsm to use to build diffusion components on. - :param nth_neighbor: - :return: pd.DataFrame containing cell ID and density. - """ - from sklearn.neighbors import NearestNeighbors - - neigh = NearestNeighbors(n_neighbors=nth_neighbor) - - if "key" in ad.obsm: - print(f"Using {key} to compute cell density") - components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) - else: - raise ValueError(f"Key {key} not present in ad.obsm.") - - diffusion_map_results = palantir.utils.run_diffusion_maps(components) - diffusion_components = palantir.utils.determine_multiscale_space( - diffusion_map_results, n_eigs=8 - ) - - nbrs = neigh.fit(diffusion_components) - cell_density = ( - pd.DataFrame(nbrs.kneighbors()[0][:, nth_neighbor - 1]) - .set_index(ad.obs_names) - .rename(columns={0: "density"}) - ) - density = 1 / cell_density - - return density - - -def celltype_frac(x, col_name): - """TODO.""" - val_counts = x[col_name].value_counts() - return val_counts.values[0] / val_counts.values.sum() - - -def compute_celltype_purity(ad, col_name): - """Compute the purity (prevalence of most abundant value) of the specified col_name from ad.obs within each metacell. - - @param: ad - AnnData object with SEACell assignment and col_name in ad.obs dataframe - @param: col_name - (str) column name within ad.obs representing celltype groupings for each cell. - """ - celltype_fraction = ad.obs.groupby("SEACell").apply( - lambda x: celltype_frac(x, col_name) - ) - celltype = ad.obs.groupby("SEACell").apply( - lambda x: x[col_name].value_counts().index[0] - ) - - return pd.concat([celltype, celltype_fraction], axis=1).rename( - columns={0: col_name, 1: f"{col_name}_purity"} - ) +import numpy as np +import palantir +import pandas as pd + + +def compactness(ad, low_dim_embedding="X_pca", SEACells_label="SEACell"): + """Compute compactness of each metacell. + + Compactness is defined is the average variance of diffusion components across cells that constitute a metcell. + + :param ad: (Anndata) Anndata object + :param low_dim_embedding: (str) `ad.obsm` field for constructing diffusion components + :param SEACell_label: (str) `ad.obs` field for computing diffusion component variances + + :return: `pd.DataFrame` with a dataframe of compactness per metacell + + """ + import palantir + + components = pd.DataFrame(ad.obsm[low_dim_embedding]).set_index(ad.obs_names) + dm_res = palantir.utils.run_diffusion_maps(components) + dc = palantir.utils.determine_multiscale_space(dm_res, n_eigs=10) + + return pd.DataFrame( + dc.join(ad.obs[SEACells_label]).groupby(SEACells_label).var().mean(1) + ).rename(columns={0: "compactness"}) + + +def separation( + ad, low_dim_embedding="X_pca", nth_nbr=1, cluster=None, SEACells_label="SEACell" +): + """Compute separation of each metacell. + + Separation is defined is the distance to the nearest neighboring metacell. + + :param ad: (Anndata) Anndata object + :param low_dim_embedding: (str) `ad.obsm` field for constructing diffusion components + :param nth_nbr: (int) Which neighbor to use for computing separation + :param SEACell_label: (str) `ad.obs` field for computing diffusion component variances + + :return: `pd.DataFrame` with a separation of compactness per metacell + + """ + components = pd.DataFrame(ad.obsm[low_dim_embedding]).set_index(ad.obs_names) + dm_res = palantir.utils.run_diffusion_maps(components) + dc = palantir.utils.determine_multiscale_space(dm_res, n_eigs=10) + + # Compute DC per metacell + metacells_dcs = ( + dc.join(ad.obs[SEACells_label], how="inner").groupby(SEACells_label).mean() + ) + + from sklearn.neighbors import NearestNeighbors + + neigh = NearestNeighbors(n_neighbors=nth_nbr) + nbrs = neigh.fit(metacells_dcs) + dists, nbrs = nbrs.kneighbors() + dists = pd.DataFrame(dists).set_index(metacells_dcs.index) + dists.columns += 1 + + nbr_cells = np.array(metacells_dcs.index)[nbrs] + + metacells_nbrs = pd.DataFrame(nbr_cells) + metacells_nbrs.index = metacells_dcs.index + metacells_nbrs.columns += 1 + + if cluster is not None: + # Get cluster type of neighbors to ensure they match the metacell cluster + clusters = ad.obs.groupby(SEACells_label)[cluster].agg( + lambda x: x.value_counts().index[0] + ) + nbr_clusters = pd.DataFrame(clusters.values[nbrs]).set_index(clusters.index) + nbr_clusters.columns = metacells_nbrs.columns + nbr_clusters = nbr_clusters.join(pd.DataFrame(clusters)) + + clusters_match = nbr_clusters.eq(nbr_clusters[cluster], axis=0) + return pd.DataFrame(dists[nth_nbr][clusters_match[nth_nbr]]).rename( + columns={1: "separation"} + ) + else: + return pd.DataFrame(dists[nth_nbr]).rename(columns={1: "separation"}) + + +def get_density(ad, key, nth_neighbor=150): + """Compute cell density as 1/ the distance to the 150th (by default) nearest neighbour. + + :param ad: AnnData object + :param key: (str) key in ad.obsm to use to build diffusion components on. + :param nth_neighbor: + :return: pd.DataFrame containing cell ID and density. + """ + from sklearn.neighbors import NearestNeighbors + + neigh = NearestNeighbors(n_neighbors=nth_neighbor) + + if key in ad.obsm: + print(f"Using {key} to compute cell density") + components = pd.DataFrame(ad.obsm[key]).set_index(ad.obs_names) + else: + raise ValueError(f"Key {key} not present in ad.obsm.") + + diffusion_map_results = palantir.utils.run_diffusion_maps(components) + diffusion_components = palantir.utils.determine_multiscale_space( + diffusion_map_results, n_eigs=8 + ) + + nbrs = neigh.fit(diffusion_components) + cell_density = ( + pd.DataFrame(nbrs.kneighbors()[0][:, nth_neighbor - 1]) + .set_index(ad.obs_names) + .rename(columns={0: "density"}) + ) + density = 1 / cell_density + + return density + + +def celltype_frac(x, col_name): + """TODO.""" + val_counts = x[col_name].value_counts() + return val_counts.values[0] / val_counts.values.sum() + + +def compute_celltype_purity(ad, col_name): + """Compute the purity (prevalence of most abundant value) of the specified col_name from ad.obs within each metacell. + + @param: ad - AnnData object with SEACell assignment and col_name in ad.obs dataframe + @param: col_name - (str) column name within ad.obs representing celltype groupings for each cell. + """ + celltype_fraction = ad.obs.groupby("SEACell").apply( + lambda x: celltype_frac(x, col_name) + ) + celltype = ad.obs.groupby("SEACell").apply( + lambda x: x[col_name].value_counts().index[0] + ) + + return pd.concat([celltype, celltype_fraction], axis=1).rename( + columns={0: col_name, 1: f"{col_name}_purity"} + ) diff --git a/SEACells/genescores.py b/SEACells/genescores.py index 3ab455a..49ede2e 100644 --- a/SEACells/genescores.py +++ b/SEACells/genescores.py @@ -1,334 +1,355 @@ -import numpy as np -import pandas as pd -import pyranges as pr -import scanpy as sc -from scipy.stats import rankdata -from sklearn.metrics import pairwise_distances -from tqdm import tqdm - -from . import core - - -def prepare_multiome_anndata( - atac_ad, rna_ad, SEACells_label="SEACell", n_bins_for_gc=50 -): - """Function to create metacell Anndata objects from single-cell Anndata objects for multiome data. - - :param atac_ad: (Anndata) ATAC Anndata object with raw peak counts in `X`. These anndata objects should be constructed - using the example notebook available in - :param rna_ad: (Anndata) RNA Anndata object with raw gene expression counts in `X`. Note: RNA and ATAC anndata objects - should contain the same set of cells - :param SEACells_label: (str) `atac_ad.obs` field for constructing metacell matrices. Same field will be used for - summarizing RNA and ATAC metacells. - :param n_bins_gc: (int) Number of bins for creating GC bins of ATAC peaks. - :return: ATAC metacell Anndata object and RNA metacell Anndata object. - """ - # Subset of cells common to ATAC and RNA - common_cells = atac_ad.obs_names.intersection(rna_ad.obs_names) - if len(common_cells) != atac_ad.shape[0]: - print( - "Warning: The number of cells in RNA and ATAC objects are different. Only the common cells will be used." - ) - atac_mod_ad = atac_ad[common_cells, :] - rna_mod_ad = rna_ad[common_cells, :] - - # ################################################################################# - # Generate metacell matrices - - # Set of metacells - metacells = atac_mod_ad.obs[SEACells_label].astype(str).unique() - metacells = metacells[atac_mod_ad.obs[SEACells_label].value_counts()[metacells] > 1] - - print("Generating Metacell matrices...") - print(" ATAC") - atac_meta_ad = core.summarize_by_SEACell( - atac_mod_ad, SEACells_label=SEACells_label, summarize_layer="X" - ) - atac_meta_ad = atac_meta_ad[metacells, :] - # ATAC - Summarize SVD representation - - svd = pd.DataFrame(atac_mod_ad.obsm["X_svd"], index=atac_mod_ad.obs_names) - summ_svd = svd.groupby(atac_mod_ad.obs[SEACells_label]).mean() - atac_meta_ad.obsm["X_svd"] = summ_svd.loc[atac_meta_ad.obs_names, :].values - - # ATAC - Normalize - _add_atac_meta_data(atac_meta_ad, atac_mod_ad, n_bins_for_gc) - sc.pp.filter_genes(atac_meta_ad, min_cells=1) - _normalize_ad(atac_meta_ad) - - # RNA summaries using ATAC SEACells - print(" RNA") - rna_mod_ad.obs["temp"] = atac_mod_ad.obs[SEACells_label] - rna_meta_ad = core.summarize_by_SEACell( - rna_mod_ad, SEACells_label="temp", summarize_layer="X" - ) - rna_meta_ad = rna_meta_ad[metacells, :] - _normalize_ad(rna_meta_ad) - - return atac_meta_ad, rna_meta_ad - - -def _normalize_ad(meta_ad, save_raw=True): - if save_raw: - # Save in raw - meta_ad.raw = meta_ad.copy() - - # Normalize - sc.pp.normalize_total(meta_ad, key_added="n_counts") - sc.pp.log1p(meta_ad) - - -def _add_atac_meta_data(atac_meta_ad, atac_ad, n_bins_for_gc): - atac_ad.var["log_n_counts"] = np.ravel(np.log10(atac_ad.X.sum(axis=0))) - - atac_meta_ad.var["GC_bin"] = np.digitize( - atac_ad.var["GC"], np.linspace(0, 1, n_bins_for_gc) - ) - atac_meta_ad.var["counts_bin"] = np.digitize( - atac_ad.var["log_n_counts"], - np.linspace( - atac_ad.var["log_n_counts"].min(), - atac_ad.var["log_n_counts"].max(), - n_bins_for_gc, - ), - ) - - -def _pyranges_from_strings(pos_list): - """Function to create pyranges for a `pd.Series` of strings.""" - # Chromosome and positions - chr = pos_list.str.split(":").str.get(0) - start = pd.Series(pos_list.str.split(":").str.get(1)).str.split("-").str.get(0) - end = pd.Series(pos_list.str.split(":").str.get(1)).str.split("-").str.get(1) - - # Create ranges - gr = pr.PyRanges(chromosomes=chr, starts=start, ends=end) - return gr - - -def _pyranges_to_strings(peaks): - """Function to convert pyranges to `pd.Series` of strings of format 'chr:start-end'.""" - # Chromosome and positions - chr = peaks.Chromosome.astype(str).values - start = peaks.Start.astype(str).values - end = peaks.End.astype(str).values - - # Create ranges - gr = chr + ":" + start + "-" + end - - return gr - - -def load_transcripts(path_to_gtf): - """Load transcripts from GTF File. `chr` is preprended to each entry.""" - gtf = pr.read_gtf(path_to_gtf) - gtf.Chromosome = "chr" + gtf.Chromosome.astype(str) - transcripts = gtf[gtf.Feature == "transcript"] - return transcripts - - -def _peaks_correlations_per_gene( - gene, - atac_exprs, - rna_exprs, - atac_meta_ad, - peaks_pr, - transcripts, - span, - n_rand_sample=100, -): - # Gene transcript - use the longest transcript - gene_transcripts = transcripts[transcripts.gene_name == gene] - if len(gene_transcripts) == 0: - return 0 - longest_transcript = gene_transcripts[ - np.arange(len(gene_transcripts)) - == np.argmax(gene_transcripts.End - gene_transcripts.Start) - ] - start = longest_transcript.Start.values[0] - span - end = longest_transcript.End.values[0] + span - - # Gene span - gene_pr = pr.from_dict( - { - "Chromosome": [longest_transcript.Chromosome.values[0]], - "Start": [start], - "End": [end], - } - ) - gene_peaks = peaks_pr.overlap(gene_pr) - if len(gene_peaks) == 0: - return 0 - gene_peaks_str = _pyranges_to_strings(gene_peaks) - - # Compute correlations - X = atac_exprs.loc[:, gene_peaks_str].T - cors = 1 - np.ravel( - pairwise_distances( - np.apply_along_axis(rankdata, 1, X.values), - rankdata(rna_exprs[gene].T.values).reshape(1, -1), - metric="correlation", - ) - ) - cors = pd.Series(cors, index=gene_peaks_str) - - # Random background - df = pd.DataFrame(1.0, index=cors.index, columns=["cor", "pval"]) - df["cor"] = cors - for p in df.index: - # TODO: Handle exception properly - try: - # Try random sampling without replacement - rand_peaks = np.random.choice( - atac_meta_ad.var_names[ - (atac_meta_ad.var["GC_bin"] == atac_meta_ad.var["GC_bin"][p]) - & ( - atac_meta_ad.var["counts_bin"] - == atac_meta_ad.var["counts_bin"][p] - ) - ], - n_rand_sample, - False, - ) - except: # noqa: E722 - rand_peaks = np.random.choice( - atac_meta_ad.var_names[ - (atac_meta_ad.var["GC_bin"] == atac_meta_ad.var["GC_bin"][p]) - & ( - atac_meta_ad.var["counts_bin"] - == atac_meta_ad.var["counts_bin"][p] - ) - ], - n_rand_sample, - True, - ) - - if type(atac_exprs) is sc.AnnData: - X = pd.DataFrame(atac_exprs[:, rand_peaks].X.todense().T) - else: - X = atac_exprs.loc[:, rand_peaks].T - - rand_cors = 1 - np.ravel( - pairwise_distances( - np.apply_along_axis(rankdata, 1, X.values), - rankdata(rna_exprs[gene].T.values).reshape(1, -1), - metric="correlation", - ) - ) - - m = np.mean(rand_cors) - v = np.std(rand_cors) - - from scipy.stats import norm - - df.loc[p, "pval"] = 1 - norm.cdf(cors[p], m, v) - - return df - - -def get_gene_peak_correlations( - atac_meta_ad, - rna_meta_ad, - path_to_gtf, - gene_ranges=None, - span=100000, - n_jobs=1, - gene_set=None, -): - """Function to compute correlations between gene expression and peak accessibility. - - :param atac_meta_ad: (Anndata) ATAC metacell Anndata created using `prepare_multiome_anndata` - :param rna_meta_ad: (Anndata) RNA metacell Anndata created using `prepare_multiome_anndata` - :param path_to_gtf: (str or None) Path to ENSEMBL GTF file OR None if using pyranges object as input - :param gene_ranges: (pyranges or None) Pyranges object containing regions corresponding to custom annotation sets. Only used if path_to_gtf is None. - :param span: (int) Genomic window around the gene body to identify for which correlations with expression are computed - :param n_jobs: (int) Number of jobs for parallel processing - :param gene_set: (pd.Series) Subset of genes for which to compute correlations. All genes are used by default - - :return: `pd.Series` with a dataframe of correlation and p-value for each gene. Note that p-value is one-sided assuming positive correlations - """ - # ################################################################################# - print("Loading transcripts per gene...") - if path_to_gtf is None: - transcripts = gene_ranges - else: - transcripts = load_transcripts(path_to_gtf) - - print("Preparing matrices for gene-peak associations") - atac_exprs = pd.DataFrame( - atac_meta_ad.X.todense(), - index=atac_meta_ad.obs_names, - columns=atac_meta_ad.var_names, - ) - rna_exprs = pd.DataFrame( - rna_meta_ad.X.todense(), - index=rna_meta_ad.obs_names, - columns=rna_meta_ad.var_names, - ) - peaks_pr = _pyranges_from_strings(atac_meta_ad.var_names) - - print("Computing peak-gene correlations") - if gene_set is None: - use_genes = rna_meta_ad.var_names - else: - use_genes = gene_set - from joblib import Parallel, delayed - - gene_peak_correlations = Parallel(n_jobs=n_jobs)( - delayed(_peaks_correlations_per_gene)( - gene, atac_exprs, rna_exprs, atac_meta_ad, peaks_pr, transcripts, span - ) - for gene in tqdm(use_genes) - ) - gene_peak_correlations = pd.Series(gene_peak_correlations, index=use_genes) - return gene_peak_correlations - - -def get_gene_peak_assocations(gene_peak_correlations, pval_cutoff=1e-1, cor_cutoff=0.1): - """Determine the number of significantly correlated peaks per gene. - - :param gene_peak_correlations: (pd.Series) Output of `get_gene_peak_correlations` function - :param p_val_cutoff: (float) Nominal p-value cutoff for test of significance of correlation - :param cor_cutoff: (float) Correlation cutoff - - :return: `pd.Series` with number of significantly positive correlated peaks with each gene - """ - peak_counts = pd.Series(0, index=gene_peak_correlations.index) - for gene in tqdm(peak_counts.index): - df = gene_peak_correlations[gene] - if type(df) is int: - continue - gene_peaks = df.index[(df["pval"] < pval_cutoff) & (df["cor"] > cor_cutoff)] - peak_counts[gene] = len(gene_peaks) - - return peak_counts - - -def get_gene_scores( - atac_meta_ad, gene_peak_correlations, pval_cutoff=1e-1, cor_cutoff=0.1 -): - """Compute the aggregate accessibility of all peaks associated with each gene. - - Gene scores are computed as the aggregate accessibility of all the signficantly correlated peaks associated with a gene. - - :param atac_meta_ad: (Anndata) ATAC metacell Anndata created using `prepare_multiome_anndata` - :param gene_peak_correlations: (pd.Series) Output of `get_gene_peak_correlations` function - :param p_val_cutoff: (float) Nominal p-value cutoff for test of significance of correlation - :param cor_cutoff: (float) Correlation cutoff - - :return: `pd.DataFrame` of ATAC gene scores (cells X genes) - """ - gene_scores = pd.DataFrame( - 0.0, index=atac_meta_ad.obs_names, columns=gene_peak_correlations.index - ) - - for gene in tqdm(gene_scores.columns): - df = gene_peak_correlations[gene] - if type(df) is int: - continue - gene_peaks = df.index[(df["pval"] < pval_cutoff) & (df["cor"] > cor_cutoff)] - gene_scores[gene] = np.ravel( - np.dot(atac_meta_ad[:, gene_peaks].X.todense(), df.loc[gene_peaks, "cor"]) - ) - gene_scores = gene_scores.loc[:, (gene_scores.sum() >= 0)] - return gene_scores +import numpy as np +import pandas as pd +import pyranges as pr +import scanpy as sc +from scipy.stats import rankdata +from sklearn.metrics import pairwise_distances +from tqdm import tqdm + +from . import core + + +def prepare_multiome_anndata( + atac_ad, rna_ad, SEACells_label="SEACell", n_bins_for_gc=50 +): + """Function to create metacell Anndata objects from single-cell Anndata objects for multiome data. + + :param atac_ad: (Anndata) ATAC Anndata object with raw peak counts in `X`. These anndata objects should be constructed + using the example notebook available in + :param rna_ad: (Anndata) RNA Anndata object with raw gene expression counts in `X`. Note: RNA and ATAC anndata objects + should contain the same set of cells + :param SEACells_label: (str) `atac_ad.obs` field for constructing metacell matrices. Same field will be used for + summarizing RNA and ATAC metacells. + :param n_bins_gc: (int) Number of bins for creating GC bins of ATAC peaks. + :return: ATAC metacell Anndata object and RNA metacell Anndata object. + """ + # Subset of cells common to ATAC and RNA + common_cells = atac_ad.obs_names.intersection(rna_ad.obs_names) + if len(common_cells) != atac_ad.shape[0]: + print( + "Warning: The number of cells in RNA and ATAC objects are different. Only the common cells will be used." + ) + atac_mod_ad = atac_ad[common_cells, :] + rna_mod_ad = rna_ad[common_cells, :] + + # ################################################################################# + # Generate metacell matrices + + # Set of metacells. Cast both sides to string so that integer-valued labels + # do not cause a KeyError when indexing value_counts() with `metacells`. + seacell_labels = atac_mod_ad.obs[SEACells_label].astype(str) + metacells = seacell_labels.unique() + metacells = metacells[seacell_labels.value_counts()[metacells] > 1] + + print("Generating Metacell matrices...") + print(" ATAC") + atac_meta_ad = core.summarize_by_SEACell( + atac_mod_ad, SEACells_label=SEACells_label, summarize_layer="X" + ) + atac_meta_ad = atac_meta_ad[metacells, :] + # ATAC - Summarize SVD representation + + svd = pd.DataFrame(atac_mod_ad.obsm["X_svd"], index=atac_mod_ad.obs_names) + summ_svd = svd.groupby(atac_mod_ad.obs[SEACells_label]).mean() + atac_meta_ad.obsm["X_svd"] = summ_svd.loc[atac_meta_ad.obs_names, :].values + + # ATAC - Normalize + _add_atac_meta_data(atac_meta_ad, atac_mod_ad, n_bins_for_gc) + sc.pp.filter_genes(atac_meta_ad, min_cells=1) + _normalize_ad(atac_meta_ad) + + # RNA summaries using ATAC SEACells + print(" RNA") + rna_mod_ad.obs["temp"] = atac_mod_ad.obs[SEACells_label] + rna_meta_ad = core.summarize_by_SEACell( + rna_mod_ad, SEACells_label="temp", summarize_layer="X" + ) + rna_meta_ad = rna_meta_ad[metacells, :] + _normalize_ad(rna_meta_ad) + + return atac_meta_ad, rna_meta_ad + + +def _normalize_ad(meta_ad, save_raw=True): + if save_raw: + # Save in raw + meta_ad.raw = meta_ad.copy() + + # Normalize + sc.pp.normalize_total(meta_ad, key_added="n_counts") + sc.pp.log1p(meta_ad) + + +def _add_atac_meta_data(atac_meta_ad, atac_ad, n_bins_for_gc): + atac_ad.var["log_n_counts"] = np.ravel(np.log10(atac_ad.X.sum(axis=0))) + + atac_meta_ad.var["GC_bin"] = np.digitize( + atac_ad.var["GC"], np.linspace(0, 1, n_bins_for_gc) + ) + atac_meta_ad.var["counts_bin"] = np.digitize( + atac_ad.var["log_n_counts"], + np.linspace( + atac_ad.var["log_n_counts"].min(), + atac_ad.var["log_n_counts"].max(), + n_bins_for_gc, + ), + ) + + +def _pyranges_from_strings(pos_list): + """Function to create pyranges for a `pd.Series` of strings.""" + # Chromosome and positions + chr = pos_list.str.split(":").str.get(0) + start = pd.Series(pos_list.str.split(":").str.get(1)).str.split("-").str.get(0) + end = pd.Series(pos_list.str.split(":").str.get(1)).str.split("-").str.get(1) + + # Create ranges + gr = pr.PyRanges(chromosomes=chr, starts=start, ends=end) + return gr + + +def _pyranges_to_strings(peaks): + """Function to convert pyranges to `pd.Series` of strings of format 'chr:start-end'.""" + # Chromosome and positions + chr = peaks.Chromosome.astype(str).values + start = peaks.Start.astype(str).values + end = peaks.End.astype(str).values + + # Create ranges + gr = chr + ":" + start + "-" + end + + return gr + + +def load_transcripts(path_to_gtf): + """Load transcripts from GTF File. `chr` is preprended to each entry.""" + gtf = pr.read_gtf(path_to_gtf) + gtf.Chromosome = "chr" + gtf.Chromosome.astype(str) + transcripts = gtf[gtf.Feature == "transcript"] + return transcripts + + +def _peaks_correlations_per_gene( + gene, + atac_ranks, + rna_ranks, + atac_meta_ad, + peaks_pr, + transcripts, + span, + n_rand_sample=100, +): + """Compute peak-gene correlations using precomputed ranks. + + `atac_ranks` is a (M, peaks) DataFrame whose columns are the ranks of + each peak across metacells. `rna_ranks` is a (M, genes) DataFrame, ranks + of each gene across metacells. Hoisting these out of this function + removes a per-gene and per-random-resample call to rankdata, which + otherwise dominates the multiome pipeline runtime. + """ + # Gene transcript - use the longest transcript + gene_transcripts = transcripts[transcripts.gene_name == gene] + if len(gene_transcripts) == 0: + return 0 + longest_transcript = gene_transcripts[ + np.arange(len(gene_transcripts)) + == np.argmax(gene_transcripts.End - gene_transcripts.Start) + ] + start = longest_transcript.Start.values[0] - span + end = longest_transcript.End.values[0] + span + + # Gene span + gene_pr = pr.from_dict( + { + "Chromosome": [longest_transcript.Chromosome.values[0]], + "Start": [start], + "End": [end], + } + ) + gene_peaks = peaks_pr.overlap(gene_pr) + if len(gene_peaks) == 0: + return 0 + gene_peaks_str = _pyranges_to_strings(gene_peaks) + + # Compute correlations using precomputed ranks. The original code did + # np.apply_along_axis(rankdata, 1, X.values) for X = atac_exprs[:, peaks].T; + # that is exactly atac_ranks[:, peaks].T (rows are peaks, columns are + # metacells, values are per-peak ranks across metacells). + gene_rna_ranks = rna_ranks[gene].values.reshape(1, -1) + peak_ranks = atac_ranks.loc[:, gene_peaks_str].values.T + cors = 1 - np.ravel( + pairwise_distances(peak_ranks, gene_rna_ranks, metric="correlation") + ) + cors = pd.Series(cors, index=gene_peaks_str) + + # Random background + df = pd.DataFrame(1.0, index=cors.index, columns=["cor", "pval"]) + df["cor"] = cors + for p in df.index: + # TODO: Handle exception properly + try: + # Try random sampling without replacement + rand_peaks = np.random.choice( + atac_meta_ad.var_names[ + (atac_meta_ad.var["GC_bin"] == atac_meta_ad.var["GC_bin"][p]) + & ( + atac_meta_ad.var["counts_bin"] + == atac_meta_ad.var["counts_bin"][p] + ) + ], + n_rand_sample, + False, + ) + except: # noqa: E722 + rand_peaks = np.random.choice( + atac_meta_ad.var_names[ + (atac_meta_ad.var["GC_bin"] == atac_meta_ad.var["GC_bin"][p]) + & ( + atac_meta_ad.var["counts_bin"] + == atac_meta_ad.var["counts_bin"][p] + ) + ], + n_rand_sample, + True, + ) + + rand_peak_ranks = atac_ranks.loc[:, rand_peaks].values.T + rand_cors = 1 - np.ravel( + pairwise_distances(rand_peak_ranks, gene_rna_ranks, metric="correlation") + ) + + m = np.mean(rand_cors) + v = np.std(rand_cors) + + from scipy.stats import norm + + df.loc[p, "pval"] = 1 - norm.cdf(cors[p], m, v) + + return df + + +def get_gene_peak_correlations( + atac_meta_ad, + rna_meta_ad, + path_to_gtf, + gene_ranges=None, + span=100000, + n_jobs=1, + gene_set=None, +): + """Function to compute correlations between gene expression and peak accessibility. + + :param atac_meta_ad: (Anndata) ATAC metacell Anndata created using `prepare_multiome_anndata` + :param rna_meta_ad: (Anndata) RNA metacell Anndata created using `prepare_multiome_anndata` + :param path_to_gtf: (str or None) Path to ENSEMBL GTF file OR None if using pyranges object as input + :param gene_ranges: (pyranges or None) Pyranges object containing regions corresponding to custom annotation sets. Only used if path_to_gtf is None. + :param span: (int) Genomic window around the gene body to identify for which correlations with expression are computed + :param n_jobs: (int) Number of jobs for parallel processing + :param gene_set: (pd.Series) Subset of genes for which to compute correlations. All genes are used by default + + :return: `pd.Series` with a dataframe of correlation and p-value for each gene. Note that p-value is one-sided assuming positive correlations + """ + # ################################################################################# + print("Loading transcripts per gene...") + if path_to_gtf is None: + transcripts = gene_ranges + else: + transcripts = load_transcripts(path_to_gtf) + + print("Preparing matrices for gene-peak associations") + atac_exprs = pd.DataFrame( + atac_meta_ad.X.todense(), + index=atac_meta_ad.obs_names, + columns=atac_meta_ad.var_names, + ) + rna_exprs = pd.DataFrame( + rna_meta_ad.X.todense(), + index=rna_meta_ad.obs_names, + columns=rna_meta_ad.var_names, + ) + peaks_pr = _pyranges_from_strings(atac_meta_ad.var_names) + + # Hoist rankdata out of the per-gene loop. The original implementation + # ranked each gene's peaks (and each random-peak resample) on every call; + # since ranks are over metacells (axis 0) and the same matrices are reused + # for every gene, we precompute them once. The raw-expression DataFrames + # are no longer needed after this point and are freed to keep peak memory + # close to the prior implementation. + print("Precomputing rank matrices for Spearman correlation") + atac_ranks = pd.DataFrame( + rankdata(atac_exprs.values, axis=0), + index=atac_exprs.index, + columns=atac_exprs.columns, + ) + rna_ranks = pd.DataFrame( + rankdata(rna_exprs.values, axis=0), + index=rna_exprs.index, + columns=rna_exprs.columns, + ) + del atac_exprs, rna_exprs + + print("Computing peak-gene correlations") + if gene_set is None: + use_genes = rna_meta_ad.var_names + else: + use_genes = gene_set + from joblib import Parallel, delayed + + gene_peak_correlations = Parallel(n_jobs=n_jobs)( + delayed(_peaks_correlations_per_gene)( + gene, atac_ranks, rna_ranks, atac_meta_ad, peaks_pr, transcripts, span + ) + for gene in tqdm(use_genes) + ) + gene_peak_correlations = pd.Series(gene_peak_correlations, index=use_genes) + return gene_peak_correlations + + +def get_gene_peak_assocations(gene_peak_correlations, pval_cutoff=1e-1, cor_cutoff=0.1): + """Determine the number of significantly correlated peaks per gene. + + :param gene_peak_correlations: (pd.Series) Output of `get_gene_peak_correlations` function + :param p_val_cutoff: (float) Nominal p-value cutoff for test of significance of correlation + :param cor_cutoff: (float) Correlation cutoff + + :return: `pd.Series` with number of significantly positive correlated peaks with each gene + """ + peak_counts = pd.Series(0, index=gene_peak_correlations.index) + for gene in tqdm(peak_counts.index): + df = gene_peak_correlations[gene] + if type(df) is int: + continue + gene_peaks = df.index[(df["pval"] < pval_cutoff) & (df["cor"] > cor_cutoff)] + peak_counts[gene] = len(gene_peaks) + + return peak_counts + + +def get_gene_scores( + atac_meta_ad, gene_peak_correlations, pval_cutoff=1e-1, cor_cutoff=0.1 +): + """Compute the aggregate accessibility of all peaks associated with each gene. + + Gene scores are computed as the aggregate accessibility of all the signficantly correlated peaks associated with a gene. + + :param atac_meta_ad: (Anndata) ATAC metacell Anndata created using `prepare_multiome_anndata` + :param gene_peak_correlations: (pd.Series) Output of `get_gene_peak_correlations` function + :param p_val_cutoff: (float) Nominal p-value cutoff for test of significance of correlation + :param cor_cutoff: (float) Correlation cutoff + + :return: `pd.DataFrame` of ATAC gene scores (cells X genes) + """ + gene_scores = pd.DataFrame( + 0.0, index=atac_meta_ad.obs_names, columns=gene_peak_correlations.index + ) + + for gene in tqdm(gene_scores.columns): + df = gene_peak_correlations[gene] + if type(df) is int: + continue + gene_peaks = df.index[(df["pval"] < pval_cutoff) & (df["cor"] > cor_cutoff)] + gene_scores[gene] = np.ravel( + np.dot(atac_meta_ad[:, gene_peaks].X.todense(), df.loc[gene_peaks, "cor"]) + ) + gene_scores = gene_scores.loc[:, (gene_scores.sum() >= 0)] + return gene_scores diff --git a/SEACells/gpu.py b/SEACells/gpu.py index 6ffc449..735a2e8 100644 --- a/SEACells/gpu.py +++ b/SEACells/gpu.py @@ -1,779 +1,783 @@ -import cupy as cp -import cupyx -import numpy as np -import palantir -import pandas as pd -from scipy.sparse import save_npz -from tqdm import tqdm - -try: - from . import build_graph -except ImportError: - import build_graph - - -class SEACellsGPU: - """GPU Implementation of SEACells algorithm. - - The implementation uses fast kernel archetypal analysis to find SEACells - groupings - of cells that represent highly granular, distinct cell states. SEACells are found by solving a convex optimization - problem that minimizes the residual sum of squares between the kernel matrix and the weighted sum of the archetypes. - - Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell'] - - """ - - def __init__( - self, - ad, - build_kernel_on: str, - n_SEACells: int, - verbose: bool = True, - n_waypoint_eigs: int = 10, - n_neighbors: int = 15, - convergence_epsilon: float = 1e-3, - l2_penalty: float = 0, - max_franke_wolfe_iters: int = 50, - ): - """GPU Implementation of SEACells algorithm. - - :param ad: (AnnData) annotated data matrix - :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - Typically 'X_pca' for scRNA or 'X_svd' for scATAC - :param n_SEACells: (int) number of SEACells to compute - :param verbose: (bool) whether to suppress verbose program logging - :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization - :param n_neighbors: (int) number of nearest neighbors to use for graph construction - :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm - :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm - :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm - - Class Attributes: - ad: (AnnData) annotated data matrix - build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells - n_cells: (int) number of cells in ad - k: (int) number of SEACells to compute - n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization - waypoint_proportion: (float) proportion of cells to use for waypoint initialization - n_neighbors: (int) number of nearest neighbors to use for graph construction - max_FW_iter: (int) maximum number of iterations for Franke-Wolfe algorithm - verbose: (bool) whether to suppress verbose program logging - l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm - RSS_iters: (list) list of residual sum of squares at each iteration of Franke-Wolfe algorithm - convergence_epsilon: (float) algorithm converges when RSS < convergence_epsilon * RSS(0) - convergence_threshold: (float) convergence threshold for Franke-Wolfe algorithm - kernel_matrix: (csr_matrix) kernel matrix of shape (n_cells, n_cells) - K: (csr_matrix) dot product of kernel matrix with itself, K = K @ K.T - archetypes: (list) list of cell indices corresponding to archetypes - A_: (csr_matrix) matrix of shape (k, n) containing final assignments of cells to SEACells - B_: (csr_matrix) matrix of shape (n, k) containing archetype weights - A0: (csr_matrix) matrix of shape (k, n) containing initial assignments of cells to SEACells - B0: (csr_matrix) matrix of shape (n, k) containing initial archetype weights - """ - print("Welcome to SEACells GPU!") - self.ad = ad - self.build_kernel_on = build_kernel_on - self.n_cells = ad.shape[0] - - if not isinstance(n_SEACells, int): - try: - n_SEACells = int(n_SEACells) - except ValueError: - raise ValueError( - f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" - ) - - self.k = n_SEACells - - self.n_waypoint_eigs = n_waypoint_eigs - self.waypoint_proportion = 1 - self.n_neighbors = n_neighbors - - self.max_FW_iter = max_franke_wolfe_iters - self.verbose = verbose - self.l2_penalty = l2_penalty - - self.RSS_iters = [] - self.convergence_epsilon = convergence_epsilon - self.convergence_threshold = None - - # Parameters to be initialized later in the model - self.kernel_matrix = None - self.K = None - - # Archetypes as list of cell indices - self.archetypes = None - - self.A_ = None - self.B_ = None - self.B0 = None - - return - - def add_precomputed_kernel_matrix(self, K): - """Add precomputed kernel matrix to SEACells object. - - :param K: (np.ndarray) kernel matrix of shape (n_cells, n_cells) - :return: None. - """ - assert K.shape == ( - self.n_cells, - self.n_cells, - ), f"Dimension of kernel matrix must be n_cells = ({self.n_cells},{self.n_cells}), not {K.shape} " - self.kernel_matrix = K - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - - def construct_kernel_matrix( - self, n_neighbors: int = None, graph_construction="union" - ): - """Construct kernel matrix from data matrix using PCA/SVD and nearest neighbors. - - :param n_neighbors: (int) number of nearest neighbors to use for graph construction. - If none, use self.n_neighbors, which has a default value of 15. - :param graph_construction: (str) method for graph construction. Options are 'union' or 'intersection'. - Default is 'union', where the neighborhood graph is made symmetric by adding an edge - (u,v) if either (u,v) or (v,u) is in the neighborhood graph. If 'intersection', the - neighborhood graph is made symmetric by adding an edge (u,v) if both (u,v) and (v,u) - are in the neighborhood graph. - :return: None. - """ - # input to graph construction is PCA/SVD - kernel_model = build_graph.SEACellGraph( - self.ad, self.build_kernel_on, verbose=self.verbose - ) - - # K is a sparse matrix representing input to SEACell alg - if n_neighbors is None: - n_neighbors = self.n_neighbors - - M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) - self.kernel_matrix = M - - # Pre-compute dot product - self.K = self.kernel_matrix @ self.kernel_matrix.T - return - - def initialize_archetypes(self): - """Initialize B matrix which defines cells as SEACells. - - Uses waypoint analysis for initialization into to fully cover the phenotype space, and then greedily - selects the remaining cells (if redundant cells are selected by waypoint analysis). - - Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes. - - By default, the proportion of cells selected by waypoint analysis is 1. This can be changed by setting the - waypoint_proportion parameter in the SEACells object. For example, setting waypoint_proportion = 0.5 will - select half of the cells by waypoint analysis and half by greedy selection. - """ - k = self.k - - if self.waypoint_proportion > 0: - waypoint_ix = self._get_waypoint_centers(k) - waypoint_ix = np.random.choice( - waypoint_ix, - int(len(waypoint_ix) * self.waypoint_proportion), - replace=False, - ) - from_greedy = self.k - len(waypoint_ix) - if self.verbose: - print( - f"Selecting {len(waypoint_ix)} cells from waypoint initialization." - ) - - else: - from_greedy = self.k - - greedy_ix = self._get_greedy_centers(n_SEACells=from_greedy + 10) - if self.verbose: - print(f"Selecting {from_greedy} cells from greedy initialization.") - - if self.waypoint_proportion > 0: - all_ix = np.hstack([waypoint_ix, greedy_ix]) - else: - all_ix = np.hstack([greedy_ix]) - - unique_ix, ind = np.unique(all_ix, return_index=True) - all_ix = unique_ix[np.argsort(ind)][:k] - self.archetypes = all_ix - - def initialize(self, initial_archetypes=None, initial_assignments=None): - """Initialize the model. - - Initializes the B matrix (constructs archetypes from a convex combination of cells) and the A matrix - (defines assignments of cells to archetypes). - - Assumes the kernel matrix has already been constructed. B matrix is of shape (n_cells, n_SEACells) and A matrix - is of shape (n_SEACells, n_cells). - - :param initial_archetypes: (np.ndarray) initial archetypes to use for initialization. If None, use waypoint - analysis and greedy selection to initialize archetypes. - :param initial_assignments: (np.ndarray) initial assignments to use for initialization. If None, use - random initialization. - :return: None - """ - if self.K is None: - raise RuntimeError( - "Must first construct kernel matrix before initializing SEACells." - ) - # initialize B (update this to allow initialization from RRQR) - n = self.K.shape[0] - - if initial_archetypes is not None: - if self.verbose: - print("Using provided list of initial archetypes") - self.archetypes = initial_archetypes - - if self.archetypes is None: - self.initialize_archetypes() - - self.k = len(self.archetypes) - k = self.k - - # Construction of B matrix - B0 = np.zeros((n, k)) - all_ix = self.archetypes - idx1 = list(zip(all_ix, np.arange(k))) - B0[tuple(zip(*idx1))] = 1 - self.B0 = B0 - B = self.B0.copy() - - if initial_assignments is not None: - A0 = initial_assignments - assert A0.shape == ( - k, - n, - ), f"Initial assignment matrix should be of shape (k={k} x n={n})" - - else: - A0 = np.random.random((k, n)) - A0 /= A0.sum(0) - if self.verbose: - print("Randomly initialized A matrix.") - - self.A0 = A0 - A = self.A0.copy() - A = self._updateA(B, A) - - self.A_ = A - self.B_ = B - - # Create convergence threshold - RSS = self.compute_RSS(A, B) - self.RSS_iters.append(RSS) - - if self.convergence_threshold is None: - self.convergence_threshold = self.convergence_epsilon * RSS - if self.verbose: - print( - f"Setting convergence threshold at {self.convergence_threshold:.5f}" - ) - - def _get_waypoint_centers(self, n_waypoints=None): - """Initialize B matrix using waypoint analysis, as described in Palantir. - - From https://www.nature.com/articles/s41587-019-0068-4. - - :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, - all SEACells initialized using this method. - :return: (np.ndarray) indices of cells to use as initial archetypes - """ - if n_waypoints is None: - k = self.k - else: - k = n_waypoints - - ad = self.ad - - if self.build_kernel_on == "X_pca": - pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) - elif self.build_kernel_on == "X_svd": - # Compute PCA components from ad object - pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) - else: - pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( - ad.obs_names - ) - - print(f"Building kernel on {self.build_kernel_on}") - - if self.verbose: - print( - f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " - ) - - dm_res = palantir.utils.run_diffusion_maps( - pca_components, n_components=self.n_neighbors - ) - dc_components = palantir.utils.determine_multiscale_space( - dm_res, n_eigs=self.n_waypoint_eigs - ) - if self.verbose: - print("Done.") - - # Initialize SEACells via waypoint sampling - if self.verbose: - print("Sampling waypoints ...") - waypoint_init = palantir.core._max_min_sampling( - data=dc_components, num_waypoints=k - ) - dc_components["iix"] = np.arange(len(dc_components)) - waypoint_ix = dc_components.loc[waypoint_init]["iix"].values - if self.verbose: - print("Done.") - - return waypoint_ix - - def _get_greedy_centers(self, n_SEACells=None): - """Initialize SEACells using fast greedy adaptive CSSP. - - From https://arxiv.org/pdf/1312.6838.pdf - :param n_SEACells: (int) number of SEACells to initialize using greedy selection. If None specified, - all SEACells initialized using this method. - :return: (np.ndarray) indices of cells to use as initial archetypes - """ - n = self.K.shape[0] - - if n_SEACells is None: - k = self.k - else: - k = n_SEACells - - if self.verbose: - print("Initializing residual matrix using greedy column selection") - - # precompute M.T * M - # ATA = M.T @ M - ATA = self.K - - if self.verbose: - print("Initializing f and g...") - - f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() - g = np.array(ATA.diagonal()).ravel() - - d = np.zeros((k, n)) - omega = np.zeros((k, n)) - - # keep track of selected indices - centers = np.zeros(k, dtype=int) - - # sampling - for j in tqdm(range(k)): - score = f / g - p = np.argmax(score) - - # print residuals - np.sum(f) - - delta_term1 = ATA[:, p].toarray().squeeze() - # print(delta_term1) - delta_term2 = ( - np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() - ) - delta = delta_term1 - delta_term2 - - # some weird rounding errors - delta[p] = np.max([0, delta[p]]) - - o = delta / np.max([np.sqrt(delta[p]), 1e-6]) - omega_square_norm = np.linalg.norm(o) ** 2 - omega_hadamard = np.multiply(o, o) - term1 = omega_square_norm * omega_hadamard - - # update f (term2) - pl = np.zeros(n) - for r in range(j): - omega_r = omega[r, :] - pl += np.dot(omega_r, o) * omega_r - - ATAo = (ATA @ o.reshape(-1, 1)).ravel() - term2 = np.multiply(o, ATAo - pl) - - # update f - f += -2.0 * term2 + term1 - - # update g - g += omega_hadamard - - # store omega and delta - d[j, :] = delta - omega[j, :] = o - - # add index - centers[j] = int(p) - - return centers - - def _updateA(self, B, A_prev): - """Compute assignment matrix A using constrained gradient descent via Frank-Wolfe algorithm. - - Given archetype matrix B and using kernel matrix K, compute assignment matrix A using constrained gradient - descent via Frank-Wolfe algorithm. - - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - :param A_prev: (n x k csr_matrix) defining previous weights used for assigning cells to SEACells - :return: (n x k csr_matrix) defining updated weights used for assigning cells to SEACells - """ - n, k = B.shape - A = A_prev - - t = 0 # current iteration (determine multiplicative update) - - Ag = cp.array(A) - Bg = cp.array(B) - Kg = cupyx.scipy.sparse.csc_matrix(self.K) - - # precompute some gradient terms - t2g = Kg.dot(Bg).T - t1g = t2g.dot(Bg) - - # update rows of A for given number of iterations - while t < self.max_FW_iter: - # compute gradient (must convert matrix to ndarray) - Gg = cp.multiply(2, cp.subtract(t1g.dot(Ag), t2g)) - - # get argmins - amins = cp.argmin(Gg, axis=0) - - # loop free implementation - eg = cp.zeros((k, n)) - eg[amins, cp.arange(n)] = 1.0 - - f = 2.0 / (t + 2.0) - Ag = cp.add(Ag, cp.multiply(f, cp.subtract(eg, Ag))) - t += 1 - - A = Ag.get() - - del t1g, t2g, Ag, Kg, Gg, Bg, eg, amins - cp._default_memory_pool.free_all_blocks() - - return A - - def _updateB(self, A, B_prev): - """Compute archetype matrix B using constrained gradient descent via Frank-Wolfe algorithm. - - Given assignment matrix A and using kernel matrix K, compute archetype matrix B using constrained gradient - descent via Frank-Wolfe algorithm. - - :param A: (n x k csr_matrix) defining weights used for assigning cells to SEACells - :param B_prev: (n x k csr_matrix) defining previous SEACells as weighted combinations of cells - :return: (n x k csr_matrix) defining updated SEACells as weighted combinations of cells - """ - k, n = A.shape - - B = B_prev - - # keep track of error - t = 0 - - Ag = cp.array(A) - Bg = cp.array(B) - Kg = cupyx.scipy.sparse.csc_matrix(self.K) - # precompute some terms - t1g = Ag.dot(Ag.T) - t2g = Kg.dot(Ag.T) - - # update rows of B for a given number of iterations - while t < 50: - # compute gradient - Gg = cp.multiply(2, cp.subtract(Kg.dot(Bg).dot(t1g), t2g)) - - # get all argmins - amins = cp.argmin(Gg, axis=0) - - eg = cp.zeros((n, k)) - eg[amins, cp.arange(k)] = 1.0 - - f = 2.0 / (t + 2.0) - Bg = cp.add(Bg, cp.multiply(f, cp.subtract(eg, Bg))) - - t += 1 - - B = Bg.get() - - del ( - t1g, - t2g, - Ag, - Kg, - Gg, - Bg, - eg, - amins, - ) - cp._default_memory_pool.free_all_blocks() - - return B - - def compute_reconstruction(self, A=None, B=None): - """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. - - :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: (n x n csr_matrix) defining reconstructed data matrix. - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - if A is None or B is None: - raise RuntimeError( - "Either assignment matrix A or archetype matrix B is None." - ) - return (self.kernel_matrix.dot(B)).dot(A) - - def compute_RSS(self, A=None, B=None): - """Compute residual sum of squares error in difference between reconstruction and true data matrix. - - :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells - If None provided, self.A is used. - :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells - If None provided, self.B is used. - :return: - ||X-XBA||^2 - (float) square difference between true data and reconstruction. - """ - if A is None: - A = self.A_ - if B is None: - B = self.B_ - - reconstruction = self.compute_reconstruction(A, B) - return np.linalg.norm(self.kernel_matrix - reconstruction) - - def plot_convergence(self, save_as=None, show=True): - """Plot behaviour of squared error over iterations. - - :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. - :param show: (bool) whether to show plot - :return: None. - """ - import matplotlib.pyplot as plt - - plt.figure() - plt.plot(self.RSS_iters) - plt.title("Reconstruction Error over Iterations") - plt.xlabel("Iterations") - plt.ylabel("Squared Error") - if save_as is not None: - plt.savefig(save_as, dpi=150) - if show: - plt.show() - plt.close() - - def step(self): - """Perform one iteration of SEACell algorithm. Update assignment matrix A and archetype matrix B. - - :return: None. - """ - A = self.A_ - B = self.B_ - - if self.K is None: - raise RuntimeError( - "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." - ) - - if A is None: - raise RuntimeError( - "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." - ) - - if B is None: - raise RuntimeError( - "Archetype matrix has not been initialised. Run model.initialize() first." - ) - - A = self._updateA(B, A) - B = self._updateB(A, B) - - self.RSS_iters.append(self.compute_RSS(A, B)) - - self.A_ = A - self.B_ = B - - del A, B - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - return - - def _fit( - self, - max_iter: int = 50, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Internal method to compute archetypes and loadings given kernel matrix K. - - Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. - - Modifies ad.obs in place to add 'SEACell' labels to cells. - :param max_iter: (int) maximum number of iterations to perform - :param min_iter: (int) minimum number of iterations to perform - :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. - :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. - :return: None - """ - self.initialize( - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - converged = False - n_iter = 0 - while (not converged and n_iter < max_iter) or n_iter < min_iter: - n_iter += 1 - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Starting iteration {n_iter}.") - self.step() - - if n_iter == 1 or (n_iter) % 10 == 0: - if self.verbose: - print(f"Completed iteration {n_iter}.") - - # Check for convergence - if ( - np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) - < self.convergence_threshold - ): - if self.verbose: - print(f"Converged after {n_iter} iterations.") - converged = True - - self.Z_ = self.B_.T @ self.K - - # Label cells by SEACells assignment - labels = self.get_hard_assignments() - self.ad.obs["SEACell"] = labels["SEACell"] - - if not converged: - raise RuntimeWarning( - "Warning: Algorithm has not converged - you may need to increase the maximum number of iterations" - ) - return - - def fit( - self, - max_iter: int = 100, - min_iter: int = 10, - initial_archetypes=None, - initial_assignments=None, - ): - """Compute archetypes and loadings given kernel matrix K. - - Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. - :param max_iter: (int) maximum number of iterations to perform (default 100) - :param min_iter: (int) minimum number of iterations to perform (default 10) - :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. - :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. - :return: None. - """ - if max_iter < min_iter: - raise ValueError( - "The maximum number of iterations specified is lower than the minimum number of iterations specified." - ) - self._fit( - max_iter=max_iter, - min_iter=min_iter, - initial_archetypes=initial_archetypes, - initial_assignments=initial_assignments, - ) - - def get_archetype_matrix(self): - """Return k x n matrix of archetypes computed as the product of the archetype matrix B and the kernel matrix K.""" - return self.Z_ - - def get_soft_assignments(self): - """Return soft SEACells assignment. - - Returns a tuple of (labels, weights) where labels is a dataframe with SEACell assignments for the top 5 - SEACell assignments for each cell and weights is an array with the corresponding weights for each assignment. - :return: (pd.DataFrame, np.array) with labels and weights. - """ - import copy - - archetype_labels = self.get_hard_archetypes() - A = copy.deepcopy(self.A_.T) - - labels = [] - weights = [] - for _i in range(5): - l = A.argmax(1) - labels.append(archetype_labels[l]) - weights.append(A[np.arange(A.shape[0]), l]) - A[np.arange(A.shape[0]), l] = -1 - - weights = np.vstack(weights).T - labels = np.vstack(labels).T - - soft_labels = pd.DataFrame(labels) - soft_labels.index = self.ad.obs_names - - return soft_labels, weights - - def get_hard_assignments(self): - """Return a dataframe with the SEACell assignment for each cell. - - The assignment is the SEACell with the highest assignment weight. - - :return: (pd.DataFrame) with SEACell assignments. - """ - # Use argmax to get the index with the highest assignment weight - - df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in self.A_.argmax(0)]}) - df.index = self.ad.obs_names - df.index.name = "index" - - return df - - def get_hard_archetypes(self): - """Return the names of cells most strongly identified as archetypes. - - :return list of archetype names. - """ - return self.ad.obs_names[self.B_.argmax(0)] - - def save_model(self, outdir): - """Save the model to a pickle file. - - :param outdir: (str) path to directory to save to - :return: None. - """ - import pickle - - with open(outdir + "/model.pkl", "wb") as f: - pickle.dump(self, f) - return None - - def save_assignments(self, outdir): - """Save SEACell assignments. - - Saves: - (1) the cell to SEACell assignments to a csv file with the name 'SEACells.csv'. - (2) the kernel matrix to a .npz file with the name 'kernel_matrix.npz'. - (3) the archetype matrix to a .npz file with the name 'A.npz'. - (4) the loading matrix to a .npz file with the name 'B.npz'. - - :param outdir: (str) path to directory to save to - :return: None - """ - import os - - os.makedirs(outdir, exist_ok=True) - save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) - save_npz(outdir + "/A.npz", self.A_.T) - save_npz(outdir + "/B.npz", self.B_) - - labels = self.get_hard_assignments() - labels.to_csv(outdir + "/SEACells.csv") - return None +import cupy as cp +import cupyx +import numpy as np +import palantir +import pandas as pd +from scipy.sparse import save_npz +from tqdm import tqdm + +try: + from . import build_graph +except ImportError: + import build_graph + + +class SEACellsGPU: + """GPU Implementation of SEACells algorithm. + + The implementation uses fast kernel archetypal analysis to find SEACells - groupings + of cells that represent highly granular, distinct cell states. SEACells are found by solving a convex optimization + problem that minimizes the residual sum of squares between the kernel matrix and the weighted sum of the archetypes. + + Modifies annotated data matrix in place to include SEACell assignments in ad.obs['SEACell'] + + """ + + def __init__( + self, + ad, + build_kernel_on: str, + n_SEACells: int, + verbose: bool = True, + n_waypoint_eigs: int = 10, + n_neighbors: int = 15, + convergence_epsilon: float = 1e-3, + l2_penalty: float = 0, + max_franke_wolfe_iters: int = 50, + ): + """GPU Implementation of SEACells algorithm. + + :param ad: (AnnData) annotated data matrix + :param build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + Typically 'X_pca' for scRNA or 'X_svd' for scATAC + :param n_SEACells: (int) number of SEACells to compute + :param verbose: (bool) whether to suppress verbose program logging + :param n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization + :param n_neighbors: (int) number of nearest neighbors to use for graph construction + :param convergence_epsilon: (float) convergence threshold for Franke-Wolfe algorithm + :param l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm + :param max_franke_wolfe_iters: (int) maximum number of iterations for Franke-Wolfe algorithm + + Class Attributes: + ad: (AnnData) annotated data matrix + build_kernel_on: (str) key corresponding to matrix in ad.obsm which is used to compute kernel for metacells + n_cells: (int) number of cells in ad + k: (int) number of SEACells to compute + n_waypoint_eigs: (int) number of eigenvectors to use for waypoint initialization + waypoint_proportion: (float) proportion of cells to use for waypoint initialization + n_neighbors: (int) number of nearest neighbors to use for graph construction + max_FW_iter: (int) maximum number of iterations for Franke-Wolfe algorithm + verbose: (bool) whether to suppress verbose program logging + l2_penalty: (float) L2 penalty for Franke-Wolfe algorithm + RSS_iters: (list) list of residual sum of squares at each iteration of Franke-Wolfe algorithm + convergence_epsilon: (float) algorithm converges when RSS < convergence_epsilon * RSS(0) + convergence_threshold: (float) convergence threshold for Franke-Wolfe algorithm + kernel_matrix: (csr_matrix) kernel matrix of shape (n_cells, n_cells) + K: (csr_matrix) dot product of kernel matrix with itself, K = K @ K.T + archetypes: (list) list of cell indices corresponding to archetypes + A_: (csr_matrix) matrix of shape (k, n) containing final assignments of cells to SEACells + B_: (csr_matrix) matrix of shape (n, k) containing archetype weights + A0: (csr_matrix) matrix of shape (k, n) containing initial assignments of cells to SEACells + B0: (csr_matrix) matrix of shape (n, k) containing initial archetype weights + """ + print("Welcome to SEACells GPU!") + self.ad = ad + self.build_kernel_on = build_kernel_on + self.n_cells = ad.shape[0] + + if not isinstance(n_SEACells, int): + try: + n_SEACells = int(n_SEACells) + except ValueError: + raise ValueError( + f"The number of SEACells specified must be an integer type, not {type(n_SEACells)}" + ) + + self.k = n_SEACells + + self.n_waypoint_eigs = n_waypoint_eigs + self.waypoint_proportion = 1 + self.n_neighbors = n_neighbors + + self.max_FW_iter = max_franke_wolfe_iters + self.verbose = verbose + self.l2_penalty = l2_penalty + + self.RSS_iters = [] + self.convergence_epsilon = convergence_epsilon + self.convergence_threshold = None + + # Parameters to be initialized later in the model + self.kernel_matrix = None + self.K = None + + # Archetypes as list of cell indices + self.archetypes = None + + self.A_ = None + self.B_ = None + self.B0 = None + + return + + def add_precomputed_kernel_matrix(self, K): + """Add precomputed kernel matrix to SEACells object. + + :param K: (np.ndarray) kernel matrix of shape (n_cells, n_cells) + :return: None. + """ + assert K.shape == ( + self.n_cells, + self.n_cells, + ), f"Dimension of kernel matrix must be n_cells = ({self.n_cells},{self.n_cells}), not {K.shape} " + self.kernel_matrix = K + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + + def construct_kernel_matrix( + self, n_neighbors: int = None, graph_construction="union" + ): + """Construct kernel matrix from data matrix using PCA/SVD and nearest neighbors. + + :param n_neighbors: (int) number of nearest neighbors to use for graph construction. + If none, use self.n_neighbors, which has a default value of 15. + :param graph_construction: (str) method for graph construction. Options are 'union' or 'intersection'. + Default is 'union', where the neighborhood graph is made symmetric by adding an edge + (u,v) if either (u,v) or (v,u) is in the neighborhood graph. If 'intersection', the + neighborhood graph is made symmetric by adding an edge (u,v) if both (u,v) and (v,u) + are in the neighborhood graph. + :return: None. + """ + # input to graph construction is PCA/SVD + kernel_model = build_graph.SEACellGraph( + self.ad, self.build_kernel_on, verbose=self.verbose + ) + + # K is a sparse matrix representing input to SEACell alg + if n_neighbors is None: + n_neighbors = self.n_neighbors + + M = kernel_model.rbf(n_neighbors, graph_construction=graph_construction) + self.kernel_matrix = M + + # Pre-compute dot product + self.K = self.kernel_matrix @ self.kernel_matrix.T + return + + def initialize_archetypes(self): + """Initialize B matrix which defines cells as SEACells. + + Uses waypoint analysis for initialization into to fully cover the phenotype space, and then greedily + selects the remaining cells (if redundant cells are selected by waypoint analysis). + + Modifies self.archetypes in-place with the indices of cells that are used as initialization for archetypes. + + By default, the proportion of cells selected by waypoint analysis is 1. This can be changed by setting the + waypoint_proportion parameter in the SEACells object. For example, setting waypoint_proportion = 0.5 will + select half of the cells by waypoint analysis and half by greedy selection. + """ + k = self.k + + if self.waypoint_proportion > 0: + waypoint_ix = self._get_waypoint_centers(k) + waypoint_ix = np.random.choice( + waypoint_ix, + int(len(waypoint_ix) * self.waypoint_proportion), + replace=False, + ) + from_greedy = self.k - len(waypoint_ix) + if self.verbose: + print( + f"Selecting {len(waypoint_ix)} cells from waypoint initialization." + ) + + else: + from_greedy = self.k + + greedy_ix = self._get_greedy_centers(n_SEACells=from_greedy + 10) + if self.verbose: + print(f"Selecting {from_greedy} cells from greedy initialization.") + + if self.waypoint_proportion > 0: + all_ix = np.hstack([waypoint_ix, greedy_ix]) + else: + all_ix = np.hstack([greedy_ix]) + + unique_ix, ind = np.unique(all_ix, return_index=True) + all_ix = unique_ix[np.argsort(ind)][:k] + self.archetypes = all_ix + + def initialize(self, initial_archetypes=None, initial_assignments=None): + """Initialize the model. + + Initializes the B matrix (constructs archetypes from a convex combination of cells) and the A matrix + (defines assignments of cells to archetypes). + + Assumes the kernel matrix has already been constructed. B matrix is of shape (n_cells, n_SEACells) and A matrix + is of shape (n_SEACells, n_cells). + + :param initial_archetypes: (np.ndarray) initial archetypes to use for initialization. If None, use waypoint + analysis and greedy selection to initialize archetypes. + :param initial_assignments: (np.ndarray) initial assignments to use for initialization. If None, use + random initialization. + :return: None + """ + if self.K is None: + raise RuntimeError( + "Must first construct kernel matrix before initializing SEACells." + ) + # initialize B (update this to allow initialization from RRQR) + n = self.K.shape[0] + + if initial_archetypes is not None: + if self.verbose: + print("Using provided list of initial archetypes") + self.archetypes = initial_archetypes + + if self.archetypes is None: + self.initialize_archetypes() + + self.k = len(self.archetypes) + k = self.k + + # Construction of B matrix + B0 = np.zeros((n, k)) + all_ix = self.archetypes + idx1 = list(zip(all_ix, np.arange(k))) + B0[tuple(zip(*idx1))] = 1 + self.B0 = B0 + B = self.B0.copy() + + if initial_assignments is not None: + A0 = initial_assignments + assert A0.shape == ( + k, + n, + ), f"Initial assignment matrix should be of shape (k={k} x n={n})" + + else: + A0 = np.random.random((k, n)) + A0 /= A0.sum(0) + if self.verbose: + print("Randomly initialized A matrix.") + + self.A0 = A0 + A = self.A0.copy() + A = self._updateA(B, A) + + self.A_ = A + self.B_ = B + + # Create convergence threshold + RSS = self.compute_RSS(A, B) + self.RSS_iters.append(RSS) + + if self.convergence_threshold is None: + self.convergence_threshold = self.convergence_epsilon * RSS + if self.verbose: + print( + f"Setting convergence threshold at {self.convergence_threshold:.5f}" + ) + + def _get_waypoint_centers(self, n_waypoints=None): + """Initialize B matrix using waypoint analysis, as described in Palantir. + + From https://www.nature.com/articles/s41587-019-0068-4. + + :param n_waypoints: (int) number of SEACells to initialize using waypoint analysis. If None specified, + all SEACells initialized using this method. + :return: (np.ndarray) indices of cells to use as initial archetypes + """ + if n_waypoints is None: + k = self.k + else: + k = n_waypoints + + ad = self.ad + + if self.build_kernel_on == "X_pca": + pca_components = pd.DataFrame(ad.obsm["X_pca"]).set_index(ad.obs_names) + elif self.build_kernel_on == "X_svd": + # Compute PCA components from ad object + pca_components = pd.DataFrame(ad.obsm["X_svd"]).set_index(ad.obs_names) + else: + pca_components = pd.DataFrame(ad.obsm[self.build_kernel_on]).set_index( + ad.obs_names + ) + + print(f"Building kernel on {self.build_kernel_on}") + + if self.verbose: + print( + f"Computing diffusion components from {self.build_kernel_on} for waypoint initialization ... " + ) + + dm_res = palantir.utils.run_diffusion_maps( + pca_components, n_components=self.n_neighbors + ) + dc_components = palantir.utils.determine_multiscale_space( + dm_res, n_eigs=self.n_waypoint_eigs + ) + if self.verbose: + print("Done.") + + # Initialize SEACells via waypoint sampling + if self.verbose: + print("Sampling waypoints ...") + waypoint_init = palantir.core._max_min_sampling( + data=dc_components, num_waypoints=k + ) + dc_components["iix"] = np.arange(len(dc_components)) + waypoint_ix = dc_components.loc[waypoint_init]["iix"].values + if self.verbose: + print("Done.") + + return waypoint_ix + + def _get_greedy_centers(self, n_SEACells=None): + """Initialize SEACells using fast greedy adaptive CSSP. + + From https://arxiv.org/pdf/1312.6838.pdf + :param n_SEACells: (int) number of SEACells to initialize using greedy selection. If None specified, + all SEACells initialized using this method. + :return: (np.ndarray) indices of cells to use as initial archetypes + """ + n = self.K.shape[0] + + if n_SEACells is None: + k = self.k + else: + k = n_SEACells + + if self.verbose: + print("Initializing residual matrix using greedy column selection") + + # precompute M.T * M + # ATA = M.T @ M + ATA = self.K + + if self.verbose: + print("Initializing f and g...") + + f = np.array((ATA.multiply(ATA)).sum(axis=0)).ravel() + g = np.array(ATA.diagonal()).ravel() + + d = np.zeros((k, n)) + omega = np.zeros((k, n)) + + # keep track of selected indices + centers = np.zeros(k, dtype=int) + + # sampling + for j in tqdm(range(k)): + score = f / g + p = np.argmax(score) + + # print residuals + np.sum(f) + + delta_term1 = ATA[:, p].toarray().squeeze() + # print(delta_term1) + delta_term2 = ( + np.multiply(omega[:, p].reshape(-1, 1), omega).sum(axis=0).squeeze() + ) + delta = delta_term1 - delta_term2 + + # some weird rounding errors + delta[p] = np.max([0, delta[p]]) + + o = delta / np.max([np.sqrt(delta[p]), 1e-6]) + omega_square_norm = np.linalg.norm(o) ** 2 + omega_hadamard = np.multiply(o, o) + term1 = omega_square_norm * omega_hadamard + + # update f (term2) + pl = np.zeros(n) + for r in range(j): + omega_r = omega[r, :] + pl += np.dot(omega_r, o) * omega_r + + ATAo = (ATA @ o.reshape(-1, 1)).ravel() + term2 = np.multiply(o, ATAo - pl) + + # update f + f += -2.0 * term2 + term1 + + # update g + g += omega_hadamard + + # store omega and delta + d[j, :] = delta + omega[j, :] = o + + # add index + centers[j] = int(p) + + return centers + + def _updateA(self, B, A_prev): + """Compute assignment matrix A using constrained gradient descent via Frank-Wolfe algorithm. + + Given archetype matrix B and using kernel matrix K, compute assignment matrix A using constrained gradient + descent via Frank-Wolfe algorithm. + + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + :param A_prev: (n x k csr_matrix) defining previous weights used for assigning cells to SEACells + :return: (n x k csr_matrix) defining updated weights used for assigning cells to SEACells + """ + n, k = B.shape + A = A_prev + + t = 0 # current iteration (determine multiplicative update) + + Ag = cp.array(A) + Bg = cp.array(B) + Kg = cupyx.scipy.sparse.csc_matrix(self.K) + + # precompute some gradient terms + t2g = Kg.dot(Bg).T + t1g = t2g.dot(Bg) + + # update rows of A for given number of iterations + while t < self.max_FW_iter: + # compute gradient (must convert matrix to ndarray) + Gg = cp.multiply(2, cp.subtract(t1g.dot(Ag), t2g)) + + # get argmins + amins = cp.argmin(Gg, axis=0) + + # loop free implementation + eg = cp.zeros((k, n)) + eg[amins, cp.arange(n)] = 1.0 + + f = 2.0 / (t + 2.0) + Ag = cp.add(Ag, cp.multiply(f, cp.subtract(eg, Ag))) + t += 1 + + A = Ag.get() + + del t1g, t2g, Ag, Kg, Gg, Bg, eg, amins + cp._default_memory_pool.free_all_blocks() + + return A + + def _updateB(self, A, B_prev): + """Compute archetype matrix B using constrained gradient descent via Frank-Wolfe algorithm. + + Given assignment matrix A and using kernel matrix K, compute archetype matrix B using constrained gradient + descent via Frank-Wolfe algorithm. + + :param A: (n x k csr_matrix) defining weights used for assigning cells to SEACells + :param B_prev: (n x k csr_matrix) defining previous SEACells as weighted combinations of cells + :return: (n x k csr_matrix) defining updated SEACells as weighted combinations of cells + """ + k, n = A.shape + + B = B_prev + + # keep track of error + t = 0 + + Ag = cp.array(A) + Bg = cp.array(B) + Kg = cupyx.scipy.sparse.csc_matrix(self.K) + # precompute some terms + t1g = Ag.dot(Ag.T) + t2g = Kg.dot(Ag.T) + + # update rows of B for a given number of iterations + while t < 50: + # compute gradient + Gg = cp.multiply(2, cp.subtract(Kg.dot(Bg).dot(t1g), t2g)) + + # get all argmins + amins = cp.argmin(Gg, axis=0) + + eg = cp.zeros((n, k)) + eg[amins, cp.arange(k)] = 1.0 + + f = 2.0 / (t + 2.0) + Bg = cp.add(Bg, cp.multiply(f, cp.subtract(eg, Bg))) + + t += 1 + + B = Bg.get() + + del ( + t1g, + t2g, + Ag, + Kg, + Gg, + Bg, + eg, + amins, + ) + cp._default_memory_pool.free_all_blocks() + + return B + + def compute_reconstruction(self, A=None, B=None): + """Compute reconstructed data matrix using learned archetypes (SEACells) and assignments. + + :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: (n x n csr_matrix) defining reconstructed data matrix. + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + if A is None or B is None: + raise RuntimeError( + "Either assignment matrix A or archetype matrix B is None." + ) + return (self.kernel_matrix.dot(B)).dot(A) + + def compute_RSS(self, A=None, B=None): + """Compute residual sum of squares error in difference between reconstruction and true data matrix. + + :param A: (k x n csr_matrix) defining weights used for assigning cells to SEACells + If None provided, self.A is used. + :param B: (n x k csr_matrix) defining SEACells as weighted combinations of cells + If None provided, self.B is used. + :return: + ||X-XBA||^2 - (float) square difference between true data and reconstruction. + """ + if A is None: + A = self.A_ + if B is None: + B = self.B_ + + reconstruction = self.compute_reconstruction(A, B) + return np.linalg.norm(self.kernel_matrix - reconstruction) + + def plot_convergence(self, save_as=None, show=True): + """Plot behaviour of squared error over iterations. + + :param save_as: (str) name of file which figure is saved as. If None, no plot is saved. + :param show: (bool) whether to show plot + :return: None. + """ + import matplotlib.pyplot as plt + + plt.figure() + plt.plot(self.RSS_iters) + plt.title("Reconstruction Error over Iterations") + plt.xlabel("Iterations") + plt.ylabel("Squared Error") + if save_as is not None: + plt.savefig(save_as, dpi=150) + if show: + plt.show() + plt.close() + + def step(self): + """Perform one iteration of SEACell algorithm. Update assignment matrix A and archetype matrix B. + + :return: None. + """ + A = self.A_ + B = self.B_ + + if self.K is None: + raise RuntimeError( + "Kernel matrix has not been computed. Run model.construct_kernel_matrix() first." + ) + + if A is None: + raise RuntimeError( + "Cell to SEACell assignment matrix has not been initialised. Run model.initialize() first." + ) + + if B is None: + raise RuntimeError( + "Archetype matrix has not been initialised. Run model.initialize() first." + ) + + A = self._updateA(B, A) + B = self._updateB(A, B) + + self.RSS_iters.append(self.compute_RSS(A, B)) + + self.A_ = A + self.B_ = B + + del A, B + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + return + + def _fit( + self, + max_iter: int = 50, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Internal method to compute archetypes and loadings given kernel matrix K. + + Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. + + Modifies ad.obs in place to add 'SEACell' labels to cells. + :param max_iter: (int) maximum number of iterations to perform + :param min_iter: (int) minimum number of iterations to perform + :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. + :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. + :return: None + """ + self.initialize( + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + converged = False + n_iter = 0 + while (not converged and n_iter < max_iter) or n_iter < min_iter: + n_iter += 1 + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Starting iteration {n_iter}.") + self.step() + + if n_iter == 1 or (n_iter) % 10 == 0: + if self.verbose: + print(f"Completed iteration {n_iter}.") + + # Check for convergence + if ( + np.abs(self.RSS_iters[-2] - self.RSS_iters[-1]) + < self.convergence_threshold + ): + if self.verbose: + print(f"Converged after {n_iter} iterations.") + converged = True + + self.Z_ = self.B_.T @ self.K + + # Label cells by SEACells assignment + labels = self.get_hard_assignments() + self.ad.obs["SEACell"] = labels["SEACell"] + + if not converged: + import warnings + + warnings.warn( + "Algorithm has not converged - you may need to increase the maximum number of iterations", + RuntimeWarning, + stacklevel=2, + ) + return + + def fit( + self, + max_iter: int = 100, + min_iter: int = 10, + initial_archetypes=None, + initial_assignments=None, + ): + """Compute archetypes and loadings given kernel matrix K. + + Iteratively updates A and B matrices until maximum number of iterations or convergence has been achieved. + :param max_iter: (int) maximum number of iterations to perform (default 100) + :param min_iter: (int) minimum number of iterations to perform (default 10) + :param initial_archetypes: (array) initial archetypes to use. If None, random initialisation is used. + :param initial_assignments: (array) initial assignments to use. If None, random initialisation is used. + :return: None. + """ + if max_iter < min_iter: + raise ValueError( + "The maximum number of iterations specified is lower than the minimum number of iterations specified." + ) + self._fit( + max_iter=max_iter, + min_iter=min_iter, + initial_archetypes=initial_archetypes, + initial_assignments=initial_assignments, + ) + + def get_archetype_matrix(self): + """Return k x n matrix of archetypes computed as the product of the archetype matrix B and the kernel matrix K.""" + return self.Z_ + + def get_soft_assignments(self): + """Return soft SEACells assignment. + + Returns a tuple of (labels, weights) where labels is a dataframe with SEACell assignments for the top 5 + SEACell assignments for each cell and weights is an array with the corresponding weights for each assignment. + :return: (pd.DataFrame, np.array) with labels and weights. + """ + import copy + + archetype_labels = self.get_hard_archetypes() + A = copy.deepcopy(self.A_.T) + + labels = [] + weights = [] + for _i in range(5): + l = A.argmax(1) + labels.append(archetype_labels[l]) + weights.append(A[np.arange(A.shape[0]), l]) + A[np.arange(A.shape[0]), l] = -1 + + weights = np.vstack(weights).T + labels = np.vstack(labels).T + + soft_labels = pd.DataFrame(labels) + soft_labels.index = self.ad.obs_names + + return soft_labels, weights + + def get_hard_assignments(self): + """Return a dataframe with the SEACell assignment for each cell. + + The assignment is the SEACell with the highest assignment weight. + + :return: (pd.DataFrame) with SEACell assignments. + """ + # Use argmax to get the index with the highest assignment weight + + df = pd.DataFrame({"SEACell": [f"SEACell-{i}" for i in self.A_.argmax(0)]}) + df.index = self.ad.obs_names + df.index.name = "index" + + return df + + def get_hard_archetypes(self): + """Return the names of cells most strongly identified as archetypes. + + :return list of archetype names. + """ + return self.ad.obs_names[self.B_.argmax(0)] + + def save_model(self, outdir): + """Save the model to a pickle file. + + :param outdir: (str) path to directory to save to + :return: None. + """ + import pickle + + with open(outdir + "/model.pkl", "wb") as f: + pickle.dump(self, f) + return None + + def save_assignments(self, outdir): + """Save SEACell assignments. + + Saves: + (1) the cell to SEACell assignments to a csv file with the name 'SEACells.csv'. + (2) the kernel matrix to a .npz file with the name 'kernel_matrix.npz'. + (3) the archetype matrix to a .npz file with the name 'A.npz'. + (4) the loading matrix to a .npz file with the name 'B.npz'. + + :param outdir: (str) path to directory to save to + :return: None + """ + import os + + os.makedirs(outdir, exist_ok=True) + save_npz(outdir + "/kernel_matrix.npz", self.kernel_matrix) + save_npz(outdir + "/A.npz", self.A_.T) + save_npz(outdir + "/B.npz", self.B_) + + labels = self.get_hard_assignments() + labels.to_csv(outdir + "/SEACells.csv") + return None