A memory-efficient, IO-aware Triton GPU kernel for the Gaussian Mixture Model (GMM) E-step at arbitrary scale.
Computing GMM responsibilities for large datasets is memory-bound. A naive implementation materializes an N×K responsibility matrix, which is already about 4 GB in float32 at N=1M, K=1024, and which exhausts GPU memory long before the dataset sizes that modern applications operate at. Flash-GMM eliminates this matrix entirely.
Inspired by the IO-aware tiling strategy of FlashAttention, Flash-GMM computes the full GMM E-step in a single pass over the data, accumulating only O(KD) sufficient statistics. The N×K responsibility matrix is never written to memory.
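Written out, the E-step quantities look as follows. The first line assumes an isotropic (spherical) Gaussian per component with a single variance σ_k². That is our reading of the `(K,)`-shaped `log_sigma_sq` input and the `D * Nk` divisor in the M-step; the kernel's exact constant terms may differ. Here M_k and S_k stand for the returned `mu_acc` and `sig_acc`.

```math
\begin{aligned}
\log z_{ik} &= \log \pi_k - \tfrac{D}{2}\left(\log 2\pi + \log \sigma_k^2\right) - \frac{\lVert x_i - \mu_k \rVert^2}{2\sigma_k^2} \\
\log Z_i &= \log \sum_{k=1}^{K} \exp\left(\log z_{ik}\right), \qquad r_{ik} = \exp\left(\log z_{ik} - \log Z_i\right) \\
N_k &= \sum_{i=1}^{N} r_{ik}, \qquad M_k = \sum_{i=1}^{N} r_{ik}\, x_i, \qquad S_k = \sum_{i=1}^{N} r_{ik} \lVert x_i - \mu_k \rVert^2
\end{aligned}
```

Everything written back to global memory is the length-N vector of log Z_i plus the three accumulators, which together hold O(KD) numbers.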
| Property | Value |
|---|---|
| Kernel memory (N=1M, K=1024, D=128) | 4.5 MB |
| TorchGMM memory (same config) | 21,006 MB |
| Memory reduction | 4,668× |
| Speedup vs SciPy (CPU) | 766–1,740× |
| Speedup vs TorchGMM (GPU) | 19–32× |
| Max N on A100 80GB (streaming) | 1B+ |
| Validated on | A100 80GB, H100, RTX 5080 |
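The 4.5 MB kernel figure is consistent with a footprint made up only of the kernel's outputs: the per-sample `logZ` vector plus the O(KD) accumulators. The back-of-envelope arithmetic below assumes float32 throughout and ignores allocator overhead, so it is a sanity check rather than a measurement. For comparison, it also sizes a single dense N×K float32 responsibility matrix.

```python
# Back-of-envelope memory arithmetic for the headline config (N=1M, K=1024, D=128).
# Assumptions: float32 (4 bytes) everywhere; the kernel's footprint is only its
# outputs logZ (N,), Nk (K,), mu_acc (K, D) and sig_acc (K,).
BYTES_F32 = 4


def kernel_output_bytes(N: int, K: int, D: int) -> int:
    """Bytes for logZ + Nk + mu_acc + sig_acc."""
    return BYTES_F32 * (N + K + K * D + K)


def responsibility_matrix_bytes(N: int, K: int) -> int:
    """Bytes for one dense N x K float32 responsibility matrix."""
    return BYTES_F32 * N * K


if __name__ == "__main__":
    N, K, D = 1_000_000, 1024, 128
    print(f"kernel outputs : {kernel_output_bytes(N, K, D) / 1e6:.2f} MB")        # 4.53 MB
    print(f"N x K matrix   : {responsibility_matrix_bytes(N, K) / 1e6:,.0f} MB")  # 4,096 MB
```

The logZ vector dominates the kernel side (4 MB of the ~4.5 MB); the accumulators add only about 0.5 MB regardless of N.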
The kernel processes data in tiles of `BLOCK_N` rows. For each tile:

- Pass 1 (log-sum-exp): loads the tile into registers and computes the per-sample log-normaliser `log Z_i` via a numerically stable online log-sum-exp, iterating over all K components in tiles of `BLOCK_K`.
- Pass 2 (accumulation): reuses the tile from registers (no second HBM read), computes the responsibilities `r_ik = exp(log z_ik - log Z_i)`, and atomically accumulates the sufficient statistics `N_k`, `mu_acc`, `sig_acc` into O(KD) global buffers.
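Pass 1 relies on the standard online (streaming) log-sum-exp technique. Here is a minimal plain-Python sketch for a single sample, not the kernel's Triton code. The function name and the `block_k` parameter are illustrative stand-ins for what `BLOCK_K` does in the kernel.

```python
import math
from typing import Sequence


def online_logsumexp(log_z: Sequence[float], block_k: int = 16) -> float:
    """Numerically stable log(sum(exp(log_z))) in one streaming pass.

    Visits log_z in blocks of `block_k` (mirroring BLOCK_K) and keeps only a
    running max `m` and a running sum `s` of exp(log_z - m). Assumes log_z is
    non-empty and contains at least one finite value.
    """
    m = -math.inf  # running maximum seen so far
    s = 0.0        # running sum of exp(log_z - m)
    for start in range(0, len(log_z), block_k):
        block = log_z[start:start + block_k]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    return m + math.log(s)
```

A naive `math.log(sum(math.exp(v) for v in [1000.0, 1000.0]))` raises `OverflowError`, whereas `online_logsumexp([1000.0, 1000.0])` returns 1000 + log 2. Because only `m` and `s` are carried between blocks, the state per sample is constant regardless of K.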
The tile never leaves on-chip memory between the two passes, so X is read from HBM only once per E-step. Each tile's block of responsibilities is discarded as soon as it has been accumulated, so the full N×K responsibility matrix is never materialized.
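To make the two-pass structure concrete, here is a minimal pure-Python reference that follows the same tiling. It is a readability sketch, not the Triton kernel: it assumes a spherical Gaussian per component (one variance each, matching the `(K,)` shape of `log_sigma_sq`), uses plain `+=` where the kernel uses atomic adds, and the names `log_component` and `tiled_estep` are hypothetical.

```python
import math


def log_component(x, mu_k, log_sig_sq_k, log_pi_k):
    """Return (log z_ik, ||x - mu_k||^2) for a spherical Gaussian (assumed form)."""
    sq = sum((a - b) ** 2 for a, b in zip(x, mu_k))
    log_z = (log_pi_k
             - 0.5 * len(x) * (math.log(2 * math.pi) + log_sig_sq_k)
             - 0.5 * sq / math.exp(log_sig_sq_k))
    return log_z, sq


def tiled_estep(X, mu, log_sigma_sq, log_pi, block_n=64, block_k=16):
    """Pure-Python E-step that never stores an N x K matrix.

    X: N rows of D floats; mu: K rows of D floats; log_sigma_sq, log_pi: K floats.
    Returns (logZ, Nk, mu_acc, sig_acc), shaped like the kernel's outputs.
    """
    N, K, D = len(X), len(mu), len(X[0])
    logZ = [0.0] * N
    Nk = [0.0] * K
    mu_acc = [[0.0] * D for _ in range(K)]
    sig_acc = [0.0] * K

    for t0 in range(0, N, block_n):
        rows = range(t0, min(t0 + block_n, N))  # one BLOCK_N tile

        # Pass 1: online log-sum-exp per row, visiting components in BLOCK_K blocks.
        m = {i: -math.inf for i in rows}  # running max of log z_ik
        s = {i: 0.0 for i in rows}        # running sum of exp(log z_ik - m)
        for k0 in range(0, K, block_k):
            ks = range(k0, min(k0 + block_k, K))
            for i in rows:
                blk = [log_component(X[i], mu[k], log_sigma_sq[k], log_pi[k])[0] for k in ks]
                m_new = max(m[i], max(blk))
                s[i] = s[i] * math.exp(m[i] - m_new) + sum(math.exp(v - m_new) for v in blk)
                m[i] = m_new
        for i in rows:
            logZ[i] = m[i] + math.log(s[i])

        # Pass 2: recompute log z_ik, form r_ik and accumulate (atomic adds in the
        # kernel); each r_ik is used once and then discarded.
        for i in rows:
            for k in range(K):
                log_z, sq = log_component(X[i], mu[k], log_sigma_sq[k], log_pi[k])
                r = math.exp(log_z - logZ[i])
                Nk[k] += r
                sig_acc[k] += r * sq
                for j in range(D):
                    mu_acc[k][j] += r * X[i][j]

    return logZ, Nk, mu_acc, sig_acc
```

The only per-tile state is the running max and sum for each row; changing `block_n` or `block_k` changes the visiting order but not the results, up to floating-point rounding.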
```bash
pip install torch triton
```

No other dependencies are required for the kernel itself.
```python
import math
import torch
from flash_gmm import flash_gmm_estep

# Inputs (all GPU tensors, float32)
N, K, D = 1_000_000, 1024, 128
X = torch.randn(N, D, device='cuda')                    # data
mu = torch.randn(K, D, device='cuda')                   # component means
log_sigma_sq = torch.zeros(K, device='cuda')            # log variances
log_pi = torch.full((K,), -math.log(K), device='cuda')  # log weights

# E-step
logZ, Nk, mu_acc, sig_acc = flash_gmm_estep(X, mu, log_sigma_sq, log_pi)

# M-step
pi_new = Nk / N
mu_new = mu_acc / Nk[:, None]
sigma_sq = sig_acc / (D * Nk)
```

| Tensor | Shape | Description |
|---|---|---|
| `logZ` | (N,) | Per-sample log-normaliser log Z_i |
| `Nk` | (K,) | Effective cluster counts Σ_i r_ik |
| `mu_acc` | (K, D) | Weighted sum Σ_i r_ik x_i |
| `sig_acc` | (K,) | Weighted squared distance Σ_i r_ik ‖x_i − μ_k‖² |
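Note that `sig_acc` is measured around the means μ_k that were passed in, not around the updated means. If you want the spread around `mu_new` (the textbook spherical M-step), the identity Σ_i r_ik ‖x_i − μ_k‖² = Σ_i r_ik ‖x_i − m_k‖² + N_k ‖m_k − μ_k‖², with m_k = `mu_acc[k] / Nk[k]`, recentres it without another pass over X. The plain-Python sketch below computes both; `mstep_spherical` is a hypothetical helper and assumes every `Nk[k] > 0`.

```python
from typing import List, Tuple


def mstep_spherical(Nk: List[float], mu_acc: List[List[float]], sig_acc: List[float],
                    mu_old: List[List[float]], N: int, D: int
                    ) -> Tuple[List[float], List[List[float]], List[float], List[float]]:
    """M-step from the E-step accumulators (sketch; assumes every Nk[k] > 0).

    Returns (pi_new, mu_new, sigma_sq_old_centre, sigma_sq_new_centre).
    """
    pi_new = [n / N for n in Nk]
    mu_new = [[a / n for a in row] for row, n in zip(mu_acc, Nk)]
    # Same formula as sig_acc / (D * Nk): spread around the *old* means.
    sigma_sq_old = [s / (D * n) for s, n in zip(sig_acc, Nk)]
    # Recentre on the new means by subtracting ||mu_new - mu_old||^2 / D:
    #   sum_i r_ik ||x_i - mu_old||^2 = sum_i r_ik ||x_i - mu_new||^2 + Nk ||mu_new - mu_old||^2
    shift_sq = [sum((a - b) ** 2 for a, b in zip(mn, mo)) for mn, mo in zip(mu_new, mu_old)]
    sigma_sq_new = [v - d2 / D for v, d2 in zip(sigma_sq_old, shift_sq)]
    return pi_new, mu_new, sigma_sq_old, sigma_sq_new
```

The two estimates coincide once the means stop moving. To run the next E-step with `flash_gmm_estep`, pass `torch.log(pi_new)` and `torch.log(sigma_sq)` back in as `log_pi` and `log_sigma_sq`.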
The defaults BLOCK_N=64, BLOCK_K=16, BLOCK_D=128 work for D≤128. For larger D, increase BLOCK_D to the next power of two ≥ D:
```python
logZ, Nk, mu_acc, sig_acc = flash_gmm_estep(X, mu, log_sigma_sq, log_pi, BLOCK_D=256)
```

For datasets larger than GPU memory, feed the data in chunks. Every accumulator is a sum over samples, so the O(KD) per-chunk results are simply added together:

```python
Nk_total = torch.zeros(K, device='cuda')
mu_total = torch.zeros(K, D, device='cuda')
sig_total = torch.zeros(K, device='cuda')

for chunk in dataloader:  # chunks of shape (n, D) loaded from CPU/SSD
    X_chunk = chunk.cuda()
    _, Nk_c, mu_c, sig_c = flash_gmm_estep(X_chunk, mu, log_sigma_sq, log_pi)
    Nk_total += Nk_c
    mu_total += mu_c
    sig_total += sig_c
    del X_chunk  # free the chunk before loading the next one

# M-step on aggregated statistics
pi_new = Nk_total / Nk_total.sum()  # the responsibilities sum to the total N
mu_new = mu_total / Nk_total[:, None]
sigma_sq = sig_total / (D * Nk_total)
```

This was validated at N=1B vectors (512 GB of data) on a single A100 80GB, completing in ~28 minutes with 1,548 MB peak GPU memory.
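Chunking is exact rather than an approximation: each sample's responsibilities depend only on that sample and the current parameters, and Nk, mu_acc and sig_acc are plain sums over samples. The plain-Python sketch below checks this with a small stand-in E-step; `estep_stats` and `streamed_stats` are hypothetical helpers, and spherical Gaussians are assumed.

```python
import math


def estep_stats(X, mu, log_sigma_sq, log_pi):
    """Stand-in E-step: returns (Nk, mu_acc, sig_acc) for a chunk of rows X."""
    K, D = len(mu), len(mu[0])
    Nk, sig_acc = [0.0] * K, [0.0] * K
    mu_acc = [[0.0] * D for _ in range(K)]
    for x in X:
        sq = [sum((a - b) ** 2 for a, b in zip(x, mu[k])) for k in range(K)]
        lz = [log_pi[k] - 0.5 * D * (math.log(2 * math.pi) + log_sigma_sq[k])
              - 0.5 * sq[k] / math.exp(log_sigma_sq[k]) for k in range(K)]
        m = max(lz)
        logZ = m + math.log(sum(math.exp(v - m) for v in lz))  # stable log-sum-exp
        for k in range(K):
            r = math.exp(lz[k] - logZ)
            Nk[k] += r
            sig_acc[k] += r * sq[k]
            for j in range(D):
                mu_acc[k][j] += r * x[j]
    return Nk, mu_acc, sig_acc


def streamed_stats(chunks, mu, log_sigma_sq, log_pi):
    """Sum per-chunk accumulators, mirroring the streaming loop."""
    K, D = len(mu), len(mu[0])
    Nk_t, sig_t = [0.0] * K, [0.0] * K
    mu_t = [[0.0] * D for _ in range(K)]
    for chunk in chunks:
        Nk_c, mu_c, sig_c = estep_stats(chunk, mu, log_sigma_sq, log_pi)
        for k in range(K):
            Nk_t[k] += Nk_c[k]
            sig_t[k] += sig_c[k]
            for j in range(D):
                mu_t[k][j] += mu_c[k][j]
    return Nk_t, mu_t, sig_t
```

Any chunk boundaries give the same totals, up to floating-point rounding, as one E-step over the whole dataset, which is why peak GPU memory is set by the chunk size rather than by N.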
Runtime of a single E-step (K=1024, D=128, A100 80GB):
| N | Flash-GMM runtime | Speedup vs SciPy (CPU) | Speedup vs TorchGMM (GPU) |
|---|---|---|---|
| 10K | 3 ms | 766× | 32× |
| 50K | 9 ms | 1,260× | 20× |
| 100K | 18 ms | 1,458× | 23× |
| 250K | 46 ms | 1,597× | 19× |
| 500K | 84 ms | 1,571× | 20× |
| 1M | 152 ms | 1,738× | 22× |
| 10M | 1,519 ms | 1,740× | OOM |
| 50M | 35,510 ms | 1,752× | OOM |
TorchGMM runs out of memory beyond N≈1M. Flash-GMM scales to N=10⁸ on the same device.
For NVIDIA H100 on the paper's benchmark workloads (K=1024, D=96–128), we also ship kernels tuned for those specific shapes that run roughly 100× faster than TorchGMM: `flash_gmm_h100.py`, `flash_gmm_diag_h100.py`, and `flash_gmm_full_h100.py`.
For shapes far from the paper benchmarks, or on non-H100 GPUs, use `flash_gmm.py`.
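If you want to pick between these files programmatically, a dispatch helper might look like the sketch below. Everything in it is an assumption on our part, not part of the repository: the helper name, the mapping of the three H100 files to spherical, diagonal and full covariance (inferred from their names), and the reading of "close to the paper benchmarks" as K == 1024 with 96 ≤ D ≤ 128.

```python
def pick_kernel_module(device_name: str, K: int, D: int, covariance: str = "spherical") -> str:
    """Return the name of the Flash-GMM module to import (hypothetical helper).

    Assumptions: the H100 files map to spherical / diag / full covariance, and
    benchmark shapes mean K == 1024 and 96 <= D <= 128.
    """
    h100_modules = {
        "spherical": "flash_gmm_h100",
        "diag": "flash_gmm_diag_h100",
        "full": "flash_gmm_full_h100",
    }
    on_h100 = "H100" in device_name.upper()
    benchmark_shape = K == 1024 and 96 <= D <= 128
    if on_h100 and benchmark_shape and covariance in h100_modules:
        return h100_modules[covariance]
    return "flash_gmm"  # general-purpose kernel for everything else
```

On a live machine you could feed it `torch.cuda.get_device_name()` and load the result with `importlib.import_module(...)`; falling back to `flash_gmm` for every unmatched case follows the recommendation above.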
- Gal Bloch (gal.bloch@ibm.com)
- Ariel Gera (ariel.gera1@ibm.com)
- Matan Orbach (matano@il.ibm.com)
- Ohad Eytan (ohad.eytan1@ibm.com)
- Assaf Toledo (assaf.toledo@ibm.com)
IBM Research
If you use Flash-GMM in your research, please cite:
```bibtex
@article{bloch2026flashgmm,
  title   = {Flash-GMM: A Memory-Efficient Kernel for Scalable Soft Clustering},
  author  = {Bloch, Gal and Gera, Ariel and Orbach, Matan and Eytan, Ohad and Toledo, Assaf},
  journal = {arXiv preprint arXiv:2606.10896},
  year    = {2026},
  url     = {https://arxiv.org/abs/2606.10896}
}
```

Licensed under Apache 2.0; see LICENSE.