- First-class SAM via `SAMWrapper` (closure-based)
- More robust checkpoint/restore with HeavyBall-internal state
- New optimizers: `SGD`, `AdamC`, `MSAMLaProp`
- Overhauled chainable pipeline: indexed transforms, branching, internal gradient accumulation, and `SqueezeGrad`
- Faster, more accurate code paths
- New `heavyball.helpers` with Optuna-compatible samplers and utilities
- `SAMWrapper` applies sharpness-aware minimization to any HeavyBall optimizer while preserving the wrapped step logic; requires a closure
- `SGD`, built on the chainable internals
- `AdamC`, a "corrected version of Adam" with weight decay normalized by the maximum LR
- `MSAMLaProp`, built on top of Momentum-SAM
- Chainable pipeline:
  - Every transform carries a `transform_idx`; state keys include this index
  - Optimizer branching, for more freedom in optimizer design and native grafting support
  - `PrecondGradAccumGuard` enables gradient accumulation for preconditioner fitting
  - `SqueezeGrad` removes size-1 dims before functional transforms, improving PSGD's speed and preconditioner fitting
- PSGD and SOAP speedups through new, more accurate SVD calculation and better compilation
- New `heavyball.helpers` module with Optuna-compatible samplers and sweep utilities
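`SAMWrapper` needs a closure because sharpness-aware minimization evaluates the gradient twice per step: once at the current parameters and once at an adversarially perturbed point. A minimal, library-free sketch of that two-step logic on plain Python floats (the `rho` name and the plain-SGD base update are illustrative, not HeavyBall's API):

```python
def sam_step(params, grad_fn, lr=0.1, rho=0.05):
    """One sharpness-aware step on a list of floats.

    grad_fn(params) plays the role of the closure: it must be
    re-callable so the gradient can be recomputed at the
    perturbed parameters.
    """
    grads = grad_fn(params)
    norm = sum(g * g for g in grads) ** 0.5 or 1.0
    # 1) Ascend to the worst case within a rho-ball around params.
    perturbed = [p + rho * g / norm for p, g in zip(params, grads)]
    # 2) Apply the base update using the gradient at the perturbed point.
    sharp_grads = grad_fn(perturbed)
    return [p - lr * g for p, g in zip(params, sharp_grads)]

# Toy quadratic loss L(x) = sum(x_i^2), with gradient 2 * x.
params = [1.0, -2.0]
for _ in range(100):
    params = sam_step(params, lambda ps: [2.0 * p for p in ps])
```

A stateless loss function would not suffice here, which is why the wrapper is closure-based rather than reusing gradients from a single `backward()` call.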
- Default orthogonalization switches to Newton-Schulz, impacting Muon; SOAP relies on `precise_zeroth_power_mode="qr"` and remains unchanged
- Optimizer state keys include the per-transform index (e.g., `exp_avg_3`), breaking old checkpoints
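The reason for indexed keys is that a chainable pipeline may contain the same transform more than once, and un-suffixed buffer names would collide. A hypothetical illustration of the naming scheme (not HeavyBall's actual key-generation code):

```python
def state_key(name: str, transform_idx: int) -> str:
    # The buffer "exp_avg" of the transform at pipeline index 3
    # is stored under "exp_avg_3".
    return f"{name}_{transform_idx}"

# Two momentum-style transforms in one chain keep separate buffers:
state = {}
for idx in (1, 3):
    state[state_key("exp_avg", idx)] = 0.0
```

A 1.x checkpoint that stored a bare `exp_avg` key therefore no longer matches what a 2.x optimizer expects to load.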
- Re-test optimizers sensitive to orthogonalization; set `utils.zeroth_power_mode="qr"` to restore 1.x behavior
- Migrate checkpoints with `python scripts/migrate_optimizer_state.py <checkpoint_path> heavyball.<OptimizerClass>`
- Update any custom state-dict tooling to handle transform-indexed keys
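Custom tooling that matched buffer names like `exp_avg` exactly will now see keys with a numeric suffix. One way to parse them back into a base name and transform index (a sketch; the helper name is hypothetical, and keys without a trailing index fall through unchanged):

```python
import re

# Matches "<base>_<idx>", e.g. "exp_avg_3" -> ("exp_avg", 3).
_INDEXED = re.compile(r"^(?P<base>.+)_(?P<idx>\d+)$")

def split_indexed_key(key: str):
    """Split a transform-indexed state key into (base_name, index).

    Returns (key, None) for keys without a trailing index, so
    pre-2.x keys and scalar entries pass through untouched.
    """
    m = _INDEXED.match(key)
    if m is None:
        return key, None
    return m.group("base"), int(m.group("idx"))
```

Grouping keys by the returned index also gives a quick way to inspect which pipeline position each buffer belongs to when debugging a migrated checkpoint.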