TL;DR: In modern LLM architectures (RMSNorm, QK-Norm), standard Weight Decay (L2) fails to control model complexity due to scale invariance. We introduce Low-Rank Decay, an efficient Nuclear Norm regularization technique using Newton-Schulz iteration (no SVD required). It aggressively prunes the "memorization circuit," forcing Rank Collapse and accelerating Grokking by orders of magnitude.
Low-Rank Decay significantly expands the grokking region. While L2 decay fails to generalize in data-scarce regimes (40%-50% data), Low-Rank Decay successfully forces the model to learn the algorithm, achieving perfect accuracy.
Rank collapse is a pre-condition for generalization. Our method induces a spectral phase transition (orange line) within the first 500 steps, reducing the effective rank of QK matrices to near-unity. This creates a low-dimensional manifold that facilitates learning the structured solution.
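One common way to quantify the effective rank referenced above is the entropy of the normalized singular-value spectrum; a minimal sketch assuming that definition (the repository's exact metric may differ):

```python
import torch

def effective_rank(W: torch.Tensor) -> float:
    # Entropy-based effective rank: exp of the Shannon entropy of the
    # normalized singular-value distribution (one common definition;
    # the repository's metric may differ).
    s = torch.linalg.svdvals(W)
    p = s / s.sum()
    entropy = -(p * torch.log(p + 1e-12)).sum()
    return torch.exp(entropy).item()

# A near-rank-1 matrix has effective rank close to 1
u, v = torch.randn(64, 1), torch.randn(1, 64)
print(effective_rank(u @ v + 1e-3 * torch.randn(64, 64)))
```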
Modern Transformers (e.g., Gemma 3, LLaMA) are heavily normalized (RMSNorm, QK-Norm). In these Scale-Invariant Architectures, scaling the weight matrix $W \to \alpha W$ does not change the network's function: the normalization divides the scale right back out. L2 weight decay therefore only shrinks the magnitude of $W$ without reducing the complexity of the function it computes.
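A minimal sketch of why this happens (illustrative only, not the repository's code): with RMSNorm applied to the output of a linear map, rescaling the weights is invisible to the network.

```python
import torch

def rmsnorm(x, eps=1e-8):
    # RMSNorm without a learned gain, for illustration
    return x / (x.pow(2).mean(dim=-1, keepdim=True).sqrt() + eps)

torch.manual_seed(0)
W = torch.randn(64, 64)
x = torch.randn(8, 64)

out = rmsnorm(x @ W.T)
out_scaled = rmsnorm(x @ (0.01 * W).T)   # 100x smaller L2 norm, same function

print(torch.allclose(out, out_scaled, atol=1e-5))  # True: the scale is invisible
```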
To control complexity in a scale-invariant network, we must control the Rank, not the Magnitude. We apply Nuclear Norm Regularization:

$$\mathcal{L}_{\text{reg}} = \lambda \,\|W\|_* = \lambda \sum_i \sigma_i(W),$$

where $\sigma_i(W)$ are the singular values of $W$. The nuclear norm is the standard convex surrogate for rank, so penalizing it drives singular values toward zero and collapses the effective rank.
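For reference, a naive version of this penalty with an explicit SVD, as a sketch (this is the per-step cost the Newton-Schulz variant below avoids):

```python
import torch

def nuclear_norm_penalty(W: torch.Tensor, lam: float = 1e-3) -> torch.Tensor:
    # ||W||_* = sum of singular values; torch.linalg.svdvals is differentiable
    return lam * torch.linalg.svdvals(W).sum()

W = torch.randn(128, 128, requires_grad=True)
penalty = nuclear_norm_penalty(W)
penalty.backward()
# The gradient of ||W||_* is U V^T -- exactly the quantity Newton-Schulz approximates.
print(W.grad.shape)
```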
Computing an SVD at every step is too slow. Instead, we use a Newton-Schulz iteration to approximate the matrix sign / orthogonal polar factor $UV^\top$ of $W$, which is exactly the (sub)gradient of the nuclear norm, using only a handful of matrix multiplications.
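A quick sanity check, as a sketch rather than repository code, that the normalized Newton-Schulz iteration converges to the same $UV^\top$ an SVD would produce:

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 64)

# Reference: exact polar factor from the SVD
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
exact = U @ Vh

# Newton-Schulz: normalize so singular values are in (0, 1], then iterate
Y = W / (W.norm() + 1e-8)
I = torch.eye(W.shape[-1])
for _ in range(30):  # more iterations than the fast training-time setting, for accuracy
    Y = 0.5 * Y @ (3.0 * I - Y.T @ Y)

print((Y - exact).abs().max())  # approaches 0 for well-conditioned W
```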
This repository is self-contained in a single file for reproducibility.
```bash
pip install torch numpy matplotlib tqdm seaborn
```

To see the Rank Collapse and Attention Patterns (generating the mechanism plot):

```bash
python experiment.py --mode mechanism
```

To run the grid search over data fractions and decay strengths:

```bash
python experiment.py --mode phase_diagram
```

The core logic is decoupled from the optimizer, making it easy to plug into existing training loops (e.g., HF Trainer); a usage sketch follows the class below.
```python
import torch

class NewtonSchulzLowRankDecay:
    """Decoupled low-rank decay: W <- W - lambda * UV^T, applied after the optimizer step."""

    def __init__(self, params, decay_rate=1e-3, ns_steps=5):
        self.params = [p for p in params if p.ndim == 2]  # matrices only
        self.decay_rate = decay_rate
        self.ns_steps = ns_steps

    @torch.no_grad()
    def step(self):
        for W in self.params:
            # 1. Normalize so singular values lie in (0, 1]
            #    (the Frobenius norm upper-bounds the spectral norm)
            X = W / (W.norm() + 1e-8)
            # 2. Newton-Schulz Iteration (approximates UV^T)
            #    Y_{k+1} = 0.5 * Y_k * (3I - Y_k^T * Y_k)
            Y = X
            I = torch.eye(W.shape[-1], device=W.device, dtype=W.dtype)
            for _ in range(self.ns_steps):
                A = Y.T @ Y
                Y = 0.5 * Y @ (3.0 * I - A)
            # 3. Apply update: W <- W - lambda * UV^T (the matrix "sign" of W)
            W.sub_(self.decay_rate * Y)
```
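A minimal usage sketch (the model, data, and hyperparameters here are placeholders, not the repository's defaults), applying the decay after each optimizer step:

```python
import torch

# Placeholder setup: any nn.Module and data source will do
model = torch.nn.Linear(128, 128)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.0)
low_rank_decay = NewtonSchulzLowRankDecay(model.parameters(), decay_rate=1e-3)

for step in range(1000):
    x = torch.randn(32, 128)
    loss = model(x).pow(2).mean()   # stand-in for the task loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    low_rank_decay.step()           # decoupled decay, applied after the update
```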
- Task: Modular Addition ($a + b \pmod P$); see the data sketch below
- Architecture: Scale-Invariant Transformer (Pre-RMSNorm + QK-Norm)
- Baselines:
  - AdamW + L2 Decay (Standard)
  - AdamW + Low-Rank Decay (Ours)
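For concreteness, a sketch of how such a modular-addition dataset can be generated; the modulus and split fraction here are assumptions, not necessarily the repository's defaults:

```python
import torch

P = 97                      # assumed modulus; the repo may use a different prime
frac_train = 0.4            # data fraction, as in the grid search

# All (a, b) pairs with label (a + b) mod P
a, b = torch.meshgrid(torch.arange(P), torch.arange(P), indexing="ij")
inputs = torch.stack([a.flatten(), b.flatten()], dim=1)   # (P*P, 2) token pairs
labels = (inputs[:, 0] + inputs[:, 1]) % P                # (P*P,)

perm = torch.randperm(P * P)
n_train = int(frac_train * P * P)
train_set = (inputs[perm[:n_train]], labels[perm[:n_train]])
test_set = (inputs[perm[n_train:]], labels[perm[n_train:]])
```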
If you find this useful for your research on Grokking, Mechanistic Interpretability, or LLM Optimization, please cite:
```bibtex
@misc{lowrank2025,
  title={Grokking by Rank Collapse: Nuclear Norm Regularization as a Complexity Prior},
  year={2025},
  publisher={GitHub},
  howpublished={\url{https://github.com/Chunjiang-Intelligence/low-rank-decay/}}
}
```

This is a research prototype.

