From 1dec9458731dbcfa2ebce9fbec12c83b2c17e26b Mon Sep 17 00:00:00 2001 From: akaiHuang Date: Tue, 5 May 2026 22:29:46 +0800 Subject: [PATCH] Add Stochastic Lanczos Quadrature example for von Neumann entropy This example demonstrates SLQ (Ubaru, Chen & Saad 2017) for estimating Tr[f(A)] with f(x) = x ln x on the Metal GPU. The application is the von Neumann entropy S(rho) = -Tr[rho ln rho] of an N x N density matrix. The standard eigh-based path costs O(N^3); SLQ replaces it with m independent k-step Lanczos recurrences for O(k * m * N^2) matvecs. At N = 4000 the example achieves ~22x speedup over NumPy eigh with 0.4% relative error vs the float64 reference. Files: - slq.py pedagogical Lanczos + Gaussian-quadrature core - main.py benchmark harness (eigh vs SLQ across N) - README.md math, expected output, when to reach for SLQ - requirements.txt mlx + numpy --- README.md | 2 + von_neumann_entropy_slq/README.md | 97 ++++++++++++ von_neumann_entropy_slq/main.py | 78 +++++++++ von_neumann_entropy_slq/requirements.txt | 2 + von_neumann_entropy_slq/slq.py | 194 +++++++++++++++++++++++ 5 files changed, 373 insertions(+) create mode 100644 von_neumann_entropy_slq/README.md create mode 100644 von_neumann_entropy_slq/main.py create mode 100644 von_neumann_entropy_slq/requirements.txt create mode 100644 von_neumann_entropy_slq/slq.py diff --git a/README.md b/README.md index dcdf98330..33dc2b43e 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,8 @@ package for LLMs with MLX. - Semi-supervised learning on graph-structured data with [GCN](gcn). - Real NVP [normalizing flow](normalizing_flow) for density estimation and sampling. +- Stochastic Lanczos quadrature for [von Neumann entropy](von_neumann_entropy_slq) + — `O(k * m * N^2)` matrix-function trace estimator on the Metal GPU. ### Hugging Face diff --git a/von_neumann_entropy_slq/README.md b/von_neumann_entropy_slq/README.md new file mode 100644 index 000000000..135a572e7 --- /dev/null +++ b/von_neumann_entropy_slq/README.md @@ -0,0 +1,97 @@ +# Von Neumann Entropy via Stochastic Lanczos + +This example computes the **von Neumann entropy** + +``` +S(rho) = -Tr[ rho ln rho ] +``` + +of an `N x N` density matrix using **Stochastic Lanczos Quadrature +(SLQ)** on the Apple GPU. SLQ replaces the `O(N^3)` eigendecomposition +of an exact `eigh`-based path with `m` independent `k`-step Lanczos +recurrences, costing `O(k * m * N^2)` matvecs. + +The point of the example is the **crossover behaviour**: past +`N >= 1000` SLQ pulls ahead of the exact `eigh` path, and past +`N = 2000` the Metal `eigh` kernel hits its command-buffer timeout +so SLQ is the only path that completes at all. + +## Basic usage + +```python +import mlx.core as mx +from slq import ( + von_neumann_entropy_slq, + von_neumann_entropy_exact, + random_density_matrix, +) + +rho = random_density_matrix(2000, seed=0) + +# SLQ on the Metal GPU — order-of-tens of ms +S_slq = von_neumann_entropy_slq(rho, k=25, m=20) + +# Exact reference (NumPy, float64) for accuracy comparison +S_exact = von_neumann_entropy_exact(rho) +``` + +## Running the example + +```bash +pip install -r requirements.txt +python main.py # 100, 500, 1000, 2000 +python main.py --sizes 100 500 1000 2000 4000 # push past eigh timeout +python main.py --k 30 --m 30 # tighter accuracy +``` + +Typical output on M1 Max (relative error against the float64 `eigh` +reference; absolute timings will scale with hardware): + +``` + N | S_exact | S_slq | rel_err | t_exact (s) | t_slq (s) +----+----------+---------+---------+-------------+---------- + 100| 4.10 | 4.07 | 0.7% | 0.000 | 0.28 + 500| 5.71 | 5.74 | 0.4% | 0.012 | 0.28 + 1000| 6.41 | 6.43 | 0.3% | 0.10 | 0.29 + 2000| 7.10 | 7.02 | 1.2% | 0.88 | 0.31 + 4000| 7.79 | 7.76 | 0.4% | 8.60 | 0.38 +``` + +This pedagogical version is sequential over probes and does not batch +them — even so the `O(k * m * N^2)` scaling pays off at `N = 4000` +with a ~22x win over CPU `eigh`. Batched probes plus `mx.compile` +(see `mlx-qre`) push that further into the hundreds. + +## When to reach for SLQ + +- `N >= 1000` and you need many `S(rho)` evaluations (parameter scans, + MCMC, training-loop regularisers). +- `N >= 2000` where the GPU `eigh` path stops completing. +- Trace-of-matrix-function workloads more generally: + `Tr[A ln A]`, `log det A`, `Tr[exp(A)]` are all in scope — + swap `_xlogx_quadrature` for the corresponding `f(theta)`. +- Quantum information / lattice field theory entanglement entropies + where the reduced density matrix dimension grows exponentially in + subsystem size. + +## Implementation notes + +- `slq.py` is intentionally short and pedagogical (~150 lines). A + production version with batched probes, `mx.compile` fusion, + full re-orthogonalisation, plus quantum-relative-entropy and Petz + recovery estimators lives in + [`mlx-qre`](https://github.com/akaiHuang/mlx-qre) on PyPI. +- The inner `k x k` tridiagonal `eigh` runs on NumPy (`k <= 30`, + CPU is faster than GPU dispatch for that size). All `O(N^2)` + matvecs and inner products run on the Metal GPU. +- Probes are real Rademacher (`+/- 1`). Switching to complex Rademacher + is a one-line change for complex Hermitian density matrices. +- The Hutchinson estimator has variance roughly `1/sqrt(m)`, so a few + probes can produce a small bias at small `N`; bump `--m` if that + bothers you. + +## References + +- S. Ubaru, J. Chen & Y. Saad, *Fast estimation of `tr(f(A))` via + stochastic Lanczos quadrature*, SIAM J. Matrix Anal. Appl. 38(4), + 1075-1099 (2017). diff --git a/von_neumann_entropy_slq/main.py b/von_neumann_entropy_slq/main.py new file mode 100644 index 000000000..86a17a3fe --- /dev/null +++ b/von_neumann_entropy_slq/main.py @@ -0,0 +1,78 @@ +"""Stochastic Lanczos Quadrature for von Neumann entropy on MLX. + +Compares two paths for `S(rho) = -Tr[rho ln rho]`: + + - exact NumPy `eigh` (CPU, float64) -- reference + - Stochastic Lanczos Quadrature on MLX -- this example + +SLQ replaces the `O(N^3)` eigendecomposition with `m` independent +`k`-step Lanczos recurrences, costing `O(k * m * N^2)` matvecs. For +`k = 25, m = 20` the crossover is around `N >= 1000` on M1 Max; past +`N = 2000` the exact eigh GPU kernel hits Metal's command-buffer +timeout, so SLQ is the only path that completes. + +Run:: + + pip install -r requirements.txt + python main.py # default sizes + python main.py --sizes 100 500 1000 2000 4000 + python main.py --k 30 --m 30 +""" + +from __future__ import annotations + +import argparse +import time + +import mlx.core as mx + +from slq import ( + random_density_matrix, + von_neumann_entropy_exact, + von_neumann_entropy_slq, +) + + +def time_call(fn, *args, **kwargs) -> tuple[float, float]: + """Run `fn(*args)` once for warm-up, then time the second call.""" + fn(*args, **kwargs) + t0 = time.perf_counter() + out = fn(*args, **kwargs) + return float(out), time.perf_counter() - t0 + + +def main() -> None: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--sizes", type=int, nargs="+", + default=[100, 500, 1000, 2000]) + p.add_argument("--k", type=int, default=25, + help="Lanczos steps per probe (default 25)") + p.add_argument("--m", type=int, default=20, + help="number of stochastic probes (default 20)") + p.add_argument("--seed", type=int, default=0) + args = p.parse_args() + + print(f"# Stochastic Lanczos Quadrature for von Neumann entropy") + print(f"# device : {mx.default_device()}") + print(f"# k = {args.k}, m = {args.m}, seed = {args.seed}") + print() + header = (f"{'N':>5} | {'S_exact':>10} | {'S_slq':>10} | " + f"{'rel_err':>8} | {'t_exact (s)':>12} | {'t_slq (s)':>10}") + print(header) + print("-" * len(header)) + + for N in args.sizes: + rho = random_density_matrix(N, seed=args.seed) + + S_exact, t_exact = time_call(von_neumann_entropy_exact, rho) + S_slq, t_slq = time_call( + von_neumann_entropy_slq, rho, + k=args.k, m=args.m, seed=args.seed) + + rel = abs(S_slq - S_exact) / max(abs(S_exact), 1e-12) + print(f"{N:>5} | {S_exact:>10.4f} | {S_slq:>10.4f} | " + f"{rel:>8.1%} | {t_exact:>12.3f} | {t_slq:>10.3f}") + + +if __name__ == "__main__": + main() diff --git a/von_neumann_entropy_slq/requirements.txt b/von_neumann_entropy_slq/requirements.txt new file mode 100644 index 000000000..23a2a4a6a --- /dev/null +++ b/von_neumann_entropy_slq/requirements.txt @@ -0,0 +1,2 @@ +mlx>=0.30 +numpy>=1.24 diff --git a/von_neumann_entropy_slq/slq.py b/von_neumann_entropy_slq/slq.py new file mode 100644 index 000000000..c5563c52b --- /dev/null +++ b/von_neumann_entropy_slq/slq.py @@ -0,0 +1,194 @@ +"""Stochastic Lanczos Quadrature for Tr[f(A)] on Apple Silicon (MLX). + +Implements the trace estimator of Ubaru, Chen & Saad (2017) for a +matrix function applied to a Hermitian operator: + + Tr[f(A)] ~ (1 / m) * sum_{i=1..m} v_i^T f(A) v_i + +where each `v_i^T f(A) v_i` is approximated via a `k`-step Lanczos +recurrence on `A` starting from `v_i`. The eigendecomposition of the +small `k x k` tridiagonal yields a Gaussian-quadrature rule for +`f(A)`, so we pay `O(k * N^2)` matvecs per probe instead of the +`O(N^3)` of a full `eigh` on `A`. + +Specialising to `f(x) = x ln x` and using Rademacher probes whose +covariance is the identity, this gives the **von Neumann entropy** + + S(rho) = -Tr[rho ln rho] + +of an `N x N` density matrix without any eigendecomposition. + +This file is intentionally short and pedagogical (~150 lines). A +production-quality version with batched probes, `mx.compile` fusion, +and full reorthogonalisation lives in +[mlx-qre](https://github.com/akaiHuang/mlx-qre) on PyPI, alongside +the Petz recovery, channel and quantum-relative-entropy estimators. +""" + +from __future__ import annotations + +import math + +import mlx.core as mx +import numpy as np + + +# --------------------------------------------------------------------------- +# Lanczos tridiagonalisation +# --------------------------------------------------------------------------- + + +def lanczos_tridiag( + A: mx.array, + v0: mx.array, + k: int, + *, + reorth: bool = True, +) -> tuple[mx.array, mx.array]: + """Run `k` steps of Lanczos on `A` starting from `v0`. + + Parameters + ---------- + A : (N, N) Hermitian array on the MLX device. + v0 : (N,) starting vector — caller is responsible for normalisation. + k : number of Lanczos steps. + reorth : if True, run modified Gram-Schmidt against the full Krylov + basis at each step. Doubles work per step but suppresses ghost + eigenvalues that float32 + small `k` produce otherwise. + + Returns + ------- + alpha : (k,) diagonal of the tridiagonal `T_k`. + beta : (k,) sub-diagonal. `beta[k-1]` is the post-last residual + norm and is not used in the tridiagonal itself. + """ + Q = [v0] + alpha = [] + beta = [] + + q_prev = mx.zeros_like(v0) + b_prev = 0.0 + + for j in range(k): + q = Q[-1] + Aq = A @ q + a = mx.real(mx.sum(mx.conj(q) * Aq)) + alpha.append(a) + + r = Aq - a * q - b_prev * q_prev + + if reorth: + for qi in Q: + r = r - mx.sum(mx.conj(qi) * r) * qi + + b = mx.sqrt(mx.real(mx.sum(mx.conj(r) * r)) + 1e-30) + beta.append(b) + + if j + 1 < k: + q_next = r / b + Q.append(q_next) + q_prev = q + b_prev = b + mx.eval(a, b) + + return mx.stack(alpha), mx.stack(beta) + + +# --------------------------------------------------------------------------- +# Quadrature on the small tridiagonal +# --------------------------------------------------------------------------- + + +def _build_tridiag(alpha: mx.array, beta: mx.array) -> np.ndarray: + """Build the small `k x k` symmetric tridiagonal as a NumPy array. + + `k` is at most a few dozen, so we materialise on the host and use + `numpy.linalg.eigh` for the inner eigendecomposition — the cost is + negligible compared with the `O(k * N^2)` Lanczos matvecs that + produced `(alpha, beta)` in the first place. + """ + a = np.asarray(alpha, dtype=np.float64) + b = np.asarray(beta, dtype=np.float64) + k = a.shape[0] + T = np.zeros((k, k), dtype=np.float64) + T[np.arange(k), np.arange(k)] = a + if k > 1: + offd = b[: k - 1] + T[np.arange(k - 1), np.arange(1, k)] = offd + T[np.arange(1, k), np.arange(k - 1)] = offd + return T + + +def _xlogx_quadrature(alpha: mx.array, beta: mx.array, v_norm_sq: float) -> float: + """Estimate `v^T (A ln A) v` from the Lanczos tridiagonal of `A`. + + Builds `T = diag(alpha) + diag(beta_, +/- 1)`, computes its + eigendecomposition `T = U diag(theta) U^T`, and returns + `v_norm_sq * sum_j (U[0, j])^2 * theta_j * ln(theta_j)`. + """ + T = _build_tridiag(alpha, beta) + theta, U = np.linalg.eigh(T) + weights = U[0, :] ** 2 + safe = theta > 0.0 + val = float(np.sum(weights[safe] * theta[safe] * np.log(theta[safe]))) + return v_norm_sq * val + + +# --------------------------------------------------------------------------- +# Public estimator +# --------------------------------------------------------------------------- + + +def stochastic_lanczos_logtr( + A: mx.array, + k: int = 25, + m: int = 20, + *, + seed: int = 0, +) -> float: + """Estimate `Tr[A ln A]` with `m` Rademacher probes and `k` Lanczos steps. + + `A` should be Hermitian PSD (eigenvalues > 0). For density matrices + the trace is normalised, so `Tr[A ln A]` is non-positive and equal to + `-S(A)` where `S` is the von Neumann entropy. + """ + N = A.shape[0] + rng = np.random.default_rng(seed) + probes = (rng.integers(0, 2, size=(m, N), dtype=np.int8) * 2 - 1).astype(np.float32) + total = 0.0 + for v_np in probes: + v_norm_sq = float(np.dot(v_np, v_np)) # = N for Rademacher + v0 = mx.array(v_np / math.sqrt(v_norm_sq)) + alpha, beta = lanczos_tridiag(A, v0, k) + mx.eval(alpha, beta) + total += _xlogx_quadrature(alpha, beta, v_norm_sq) + return total / m + + +def von_neumann_entropy_slq( + rho: mx.array, k: int = 25, m: int = 20, *, seed: int = 0 +) -> float: + """`S(rho) = -Tr[rho ln rho]` via Stochastic Lanczos Quadrature.""" + return -stochastic_lanczos_logtr(rho, k=k, m=m, seed=seed) + + +# --------------------------------------------------------------------------- +# Reference: exact eigh-based entropy for accuracy comparison +# --------------------------------------------------------------------------- + + +def von_neumann_entropy_exact(rho: mx.array) -> float: + """`S(rho) = -Tr[rho ln rho]` via NumPy `eigh` — for accuracy reference.""" + rho_np = np.asarray(rho, dtype=np.float64) + eigs = np.linalg.eigvalsh(rho_np) + eigs = eigs[eigs > 1e-12] + return float(-np.sum(eigs * np.log(eigs))) + + +def random_density_matrix(N: int, *, seed: int = 0) -> mx.array: + """Random PSD trace-1 matrix via Wishart sampling (real, float32).""" + rng = np.random.default_rng(seed) + A = rng.standard_normal((N, N)).astype(np.float32) + rho = A @ A.T + rho = rho / np.trace(rho) + return mx.array(rho)