-
Notifications
You must be signed in to change notification settings - Fork 717
[PyTorch] Add distributed Muon optimizer #2920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vcherepanov-nv
wants to merge
10
commits into
NVIDIA:main
Choose a base branch
from
vcherepanov-nv:muon
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a2df6f8
Add distributed Muon optimizer
vcherepanov-nv e332a8e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1304712
Fix Muon closure and reference test
vcherepanov-nv 958923d
Fix Muon optimizer distributed API handling
vcherepanov-nv 860d6a7
Fix Muon optimizer docs and params typing
vcherepanov-nv 037375c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6d043e3
Add tensor-parallel Newton-Schulz wrapper
vcherepanov-nv e9758a3
Move Newton-Schulz wrapper into optimizers
vcherepanov-nv d7597b6
Use tensor-parallel Newton-Schulz in Muon
vcherepanov-nv 47a02d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,215 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Distributed Muon optimizer test worker. | ||
|
|
||
| Launched via torchrun from test_muon_optimizer.py. | ||
| """ | ||
|
|
||
| import argparse | ||
| import sys | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch.distributed.elastic.multiprocessing.errors import record | ||
|
|
||
| import transformer_engine.pytorch as te | ||
| from transformer_engine.pytorch.optimizers.newton_schulz import get_coefficients | ||
| from transformer_engine.pytorch.optimizers.muon import get_muon_scale_factor | ||
|
|
||
|
|
||
def _reference_orthogonalize(
    grad: torch.Tensor,
    *,
    partition_dim: int,
    coefficients: list[tuple[float, float, float]],
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    """Orthogonalize *grad* via Newton-Schulz iterations on the full matrix.

    Single-device mirror of the distributed optimizer's orthogonalization,
    used to compare the sharded result elementwise against a reference.
    """
    n_rows, n_cols = grad.size(0), grad.size(1)

    work = grad.clone()
    # Row-parallel layouts are transposed so the iteration always sees the
    # partitioned dimension trailing; transposed back before returning.
    if partition_dim == 0:
        work = work.mT.contiguous()

    # Normalize by the Frobenius norm, computed in fp32 and clamped by eps.
    work = work / torch.sqrt((work.float() * work.float()).sum()).clamp_min(eps).to(
        dtype=work.dtype
    )

    # Quintic Newton-Schulz step: X <- aX + b(XX^T)X + c(XX^T)^2 X.
    for a, b, c in coefficients:
        gram = work @ work.mT
        work = a * work + b * (gram @ work) + c * ((gram @ gram) @ work)

    if partition_dim == 0:
        work = work.mT.contiguous()

    scale = get_muon_scale_factor(n_rows, n_cols, mode=scale_mode)
    return work * (scale * extra_scale_factor)
|
|
||
|
|
||
def _reference_step(
    param: torch.Tensor,
    grad: torch.Tensor,
    momentum_buffer: torch.Tensor,
    *,
    lr: float,
    momentum: float,
    nesterov: bool,
    weight_decay: float,
    use_decoupled_weight_decay: bool,
    partition_dim: int,
    coefficients: list[tuple[float, float, float]],
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply one full-matrix Muon step; return (new_param, new_momentum)."""
    # Operate on copies so the caller's tensors are never mutated.
    new_param = param.clone()
    eff_grad = grad.clone()
    buf = momentum_buffer.clone()

    # Decoupled weight decay scales the weights directly; classic L2 folds
    # the penalty into the gradient before the momentum update.
    if use_decoupled_weight_decay:
        new_param = new_param * (1.0 - lr * weight_decay)
    elif weight_decay != 0:
        eff_grad = eff_grad + weight_decay * new_param

    # Exponential moving average of gradients, optionally with Nesterov
    # look-ahead for the update direction.
    buf = momentum * buf + (1.0 - momentum) * eff_grad
    if nesterov:
        update = (1.0 - momentum) * eff_grad + momentum * buf
    else:
        update = buf

    orth_update = _reference_orthogonalize(
        update,
        partition_dim=partition_dim,
        coefficients=coefficients,
        scale_mode=scale_mode,
        extra_scale_factor=extra_scale_factor,
        eps=eps,
    )
    return new_param - lr * orth_update, buf
|
|
||
|
|
||
@record
def main():
    """Entry point for one torchrun rank of the distributed Muon test.

    Each rank optimizes its shard of a tensor-parallel parameter with
    ``te.optimizers.MuonOptimizer``; every rank also replays the same steps
    on the full (unsharded) matrix with ``_reference_step``, and rank 0
    compares the all-gathered shards against that reference.
    """
    parser = argparse.ArgumentParser(description="Distributed Muon optimizer test")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
    parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1])
    parser.add_argument(
        "--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"]
    )
    parser.add_argument("--num-steps", type=int, default=2)
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # NOTE(review): assumes a single-node launch where global rank == local
    # device index; a multi-node run would need LOCAL_RANK here — confirm.
    torch.cuda.set_device(rank)

    dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
    # Size the partitioned dimension as world_size * 64 so each rank owns an
    # equal 64-wide shard.
    if args.partition_dim == 0:
        full_shape = (world_size * 64, 96)
    else:
        full_shape = (96, world_size * 64)

    # Hyperparameters passed verbatim to both the optimizer under test and
    # the reference implementation.
    lr = 3e-4
    momentum = 0.95
    nesterov = True
    weight_decay = 0.01
    use_decoupled_weight_decay = args.weight_decay_mode == "decoupled"
    coefficient_type = "quintic"
    num_ns_steps = 5
    scale_mode = "spectral"
    extra_scale_factor = 1.0
    eps = 1e-7
    coefficients = get_coefficients(num_ns_steps, coefficient_type)

    # Rank 0 draws the random data; broadcasts below make every rank agree
    # on the full parameter and the per-step gradients.
    if rank == 0:
        torch.manual_seed(1234)
        full_param = torch.randn(full_shape, device="cuda", dtype=dtype)
        full_grads = [
            torch.randn(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps)
        ]
    else:
        full_param = torch.empty(full_shape, device="cuda", dtype=dtype)
        full_grads = [
            torch.empty(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps)
        ]

    dist.broadcast(full_param, src=0)
    for grad in full_grads:
        dist.broadcast(grad, src=0)

    # Slice out this rank's contiguous shard along the partitioned dim.
    shard_size = full_shape[args.partition_dim] // world_size
    shard_slice = slice(rank * shard_size, (rank + 1) * shard_size)
    if args.partition_dim == 0:
        local_param_init = full_param[shard_slice, :].contiguous()
    else:
        local_param_init = full_param[:, shard_slice].contiguous()

    param = torch.nn.Parameter(local_param_init.clone())
    # TE tensor-parallel metadata: tells the optimizer which dim is sharded.
    param.partition_dim = args.partition_dim
    optimizer = te.optimizers.MuonOptimizer(
        [param],
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
        weight_decay=weight_decay,
        use_decoupled_weight_decay=use_decoupled_weight_decay,
        coefficient_type=coefficient_type,
        num_ns_steps=num_ns_steps,
        scale_mode=scale_mode,
        extra_scale_factor=extra_scale_factor,
        process_group=dist.group.WORLD,
        eps=eps,
    )

    # Reference state lives in fp32 on the full matrix (only rank 0's copy
    # is actually checked at the end).
    ref_param = full_param.float()
    ref_momentum = torch.zeros_like(ref_param)
    for full_grad in full_grads:
        # Feed this rank's gradient shard to the optimizer under test.
        if args.partition_dim == 0:
            param.grad = full_grad[shard_slice, :].contiguous()
        else:
            param.grad = full_grad[:, shard_slice].contiguous()
        optimizer.step()

        ref_param, ref_momentum = _reference_step(
            ref_param,
            full_grad.float(),
            ref_momentum,
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            use_decoupled_weight_decay=use_decoupled_weight_decay,
            partition_dim=args.partition_dim,
            coefficients=coefficients,
            scale_mode=scale_mode,
            extra_scale_factor=extra_scale_factor,
            eps=eps,
        )

    # Reassemble the sharded parameter for comparison with the reference.
    gathered = [torch.empty_like(param) for _ in range(world_size)]
    dist.all_gather(gathered, param)
    if args.partition_dim == 0:
        test_param = torch.cat(gathered, dim=0)
    else:
        test_param = torch.cat(gathered, dim=1)

    if rank == 0:
        expected = ref_param.to(dtype)
        # Looser tolerances for bf16: the optimizer and reference round at
        # different points of the computation.
        atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2e-3, 2e-3)
        if torch.allclose(test_param, expected, atol=atol, rtol=rtol):
            print("MUON OPTIMIZER CHECK PASSED", flush=True)
        else:
            max_diff = (test_param - expected).abs().max().item()
            print(f"Max |optimizer - reference|: {max_diff:.6e}", flush=True)
            print("MUON OPTIMIZER CHECK FAILED", flush=True, file=sys.stderr)
            sys.exit(1)

    optimizer.destroy()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Tests for distributed Muon optimizer.""" | ||
|
|
||
| import os | ||
| import subprocess | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from transformer_engine.pytorch.optimizers.muon import MuonOptimizer | ||
|
|
||
# Distributed cases launch one torchrun worker per visible GPU; skip the
# whole suite on machines with fewer than two.
MULTI_GPU_AVAILABLE = torch.cuda.device_count() >= 2
requires_multi_gpu = pytest.mark.skipif(
    not MULTI_GPU_AVAILABLE,
    reason="Muon optimizer distributed tests require at least 2 GPUs.",
)

# The worker script lives next to this test file.
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
|
|
||
|
|
||
def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None:
    """Launch the torchrun worker for one configuration and assert success."""
    worker_path = TEST_ROOT / "run_muon_optimizer.py"
    cmd = [
        *LAUNCH_CMD,
        str(worker_path),
        f"--dtype={dtype}",
        f"--partition-dim={partition_dim}",
        f"--weight-decay-mode={weight_decay_mode}",
    ]
    proc = subprocess.run(cmd, env=os.environ, capture_output=True, check=False, timeout=300)
    stdout_text = proc.stdout.decode()
    stderr_text = proc.stderr.decode()
    # Success requires a zero exit code AND the explicit PASSED marker from
    # the worker, with no FAILED marker on stderr.
    succeeded = (
        proc.returncode == 0
        and "MUON OPTIMIZER CHECK FAILED" not in stderr_text
        and "MUON OPTIMIZER CHECK PASSED" in stdout_text
    )
    if not succeeded:
        raise AssertionError(
            "Muon optimizer test failed.\n"
            f"stdout: {stdout_text}\n"
            f"stderr: {stderr_text}"
        )
|
|
||
|
|
||
@requires_multi_gpu
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("partition_dim", [0, 1])
def test_muon_optimizer_matches_reference(dtype: str, partition_dim: int) -> None:
    """Compare distributed Muon updates with a full-matrix reference.

    Covers both sharding dimensions and both dtypes, always with decoupled
    weight decay; the L2 branch is exercised by a separate test.
    """
    _run_test(dtype, partition_dim, "decoupled")
|
|
||
|
|
||
@requires_multi_gpu
def test_muon_optimizer_l2_weight_decay() -> None:
    """Exercise the L2 weight decay branch against the same reference.

    One configuration (fp32, column-parallel) is enough to cover the branch.
    """
    _run_test("float32", 1, "l2")
|
|
||
|
|
||
def test_muon_optimizer_requires_explicit_process_group() -> None:
    """Muon should not silently fall back to the world process group."""
    # An uninitialized CPU tensor suffices here: the ValueError is expected
    # from constructor argument validation, before any values are read.
    param = torch.nn.Parameter(torch.empty(2, 2))
    with pytest.raises(ValueError, match="explicit NCCL tensor-parallel process_group"):
        MuonOptimizer([param], process_group=None, partition_dim=0)
|
|
||
|
|
||
def test_muon_optimizer_resolves_partition_dim_per_parameter() -> None:
    """TE tensor-parallel metadata should provide per-parameter partition dims."""
    tagged = torch.empty(2, 2)
    tagged.partition_dim = 0

    # Metadata attached to the tensor wins when no override is supplied.
    assert MuonOptimizer._resolve_partition_dim(tagged, None) == 0

    # Without metadata, the explicit argument is used as-is.
    untagged = torch.empty(2, 2)
    assert MuonOptimizer._resolve_partition_dim(untagged, 1) == 1

    # Disagreement between metadata and the explicit argument is an error.
    with pytest.raises(ValueError, match="Conflicting partition_dim"):
        MuonOptimizer._resolve_partition_dim(tagged, 1)

    # partition_dim == -1 marks a non-parallel parameter, which is rejected.
    tagged.partition_dim = -1
    with pytest.raises(ValueError, match="Non-parallel parameters are not supported"):
        MuonOptimizer._resolve_partition_dim(tagged, None)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make sure to include this in the QA script: https://github.com/NVIDIA/TransformerEngine/blob/main/qa/L1_pytorch_distributed_unittest/test.sh