[PyTorch] Add distributed Muon optimizer #2920
vcherepanov-nv wants to merge 10 commits into NVIDIA:main
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: Adds a distributed Muon optimizer based on newton_schulz orthogonalization.
Confidence Score: 4/5. Safe to merge after addressing the breaking module-path removal; the optimizer logic itself is correct. One P1 finding relates to that module-path renaming.
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant MuonOptimizer
    participant DistNormalize as _distributed_normalize_p2_
    participant NS as newton_schulz_tp
    participant CusolverMp as CusolverMpCtx
    Caller->>MuonOptimizer: step()
    MuonOptimizer->>MuonOptimizer: weight decay (decoupled or L2)
    MuonOptimizer->>MuonOptimizer: momentum_buffer.lerp_(grad, 1-β)
    MuonOptimizer->>MuonOptimizer: Nesterov update = grad.lerp(buf, β)
    MuonOptimizer->>DistNormalize: _distributed_normalize_p2_(update, eps)
    DistNormalize-->>DistNormalize: local norm²
    DistNormalize->>DistNormalize: all_reduce(SUM) → global norm²
    DistNormalize-->>MuonOptimizer: update ÷= global_norm (in-place)
    MuonOptimizer->>NS: newton_schulz_tp(update, ctx, partition_dim, tp_mode=distributed)
    alt partition_dim == 0
        NS->>NS: x_t = update.mT.contiguous()
        NS->>CusolverMp: newton_schulz(x_t, ctx, num_iters)
        NS-->>NS: update.copy_(x_t.mT)
    else partition_dim == 1
        NS->>CusolverMp: newton_schulz(update, ctx, num_iters)
    end
    CusolverMp-->>NS: in-place orthogonalized shard
    NS-->>MuonOptimizer: orthogonalized update
    MuonOptimizer->>MuonOptimizer: update *= scale_factor * extra_scale
    MuonOptimizer->>Caller: p -= lr * orth_update
```python
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()
```
**Closure called inside `@torch.no_grad()`, preventing gradient computation**

`closure()` is invoked while `torch.no_grad()` is active. Any `loss.backward()` call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in `with torch.enable_grad():`.
Suggested change:

```diff
+@torch.no_grad()
 def step(self, closure=None):
     """Perform a single optimization step."""
     loss = None
     if closure is not None:
-        loss = closure()
+        with torch.enable_grad():
+            loss = closure()
```
```python
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    global_shape = [grad.size(0), grad.size(1)]
    global_shape[partition_dim] *= world_size
```
**Reference `global_shape` incorrectly scales an already-full tensor**
`_reference_orthogonalize` receives the full matrix (shape `full_shape`) but then multiplies `global_shape[partition_dim]` by `world_size` a second time. For `partition_dim=1` with `world_size=2` and `full_shape=(96, 128)` this gives `global_shape=[96, 256]`, so `get_muon_scale_factor` returns `max(96, 256)**0.5 = 16`. The optimizer, operating on the shard `(96, 64)`, correctly reconstructs `global_shape=[96, 128]` and computes `max(96, 128)**0.5 ≈ 11.3`. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.

The `global_shape[partition_dim] *= world_size` line should be removed since the input is already the full matrix.
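A quick numeric check of the discrepancy described above, using an illustrative stand-in for the spectral rule (`max(rows, cols) ** 0.5`); all names here are hypothetical, not from the PR:

```python
# Spectral scale rule assumed from the discussion: scale = max(rows, cols) ** 0.5.
def spectral_scale(rows: int, cols: int) -> float:
    return max(rows, cols) ** 0.5

world_size = 2
full_shape = (96, 128)  # full matrix, partitioned along dim 1

# Buggy reference path: multiplies the already-full dim by world_size again.
ref_scale = spectral_scale(full_shape[0], full_shape[1] * world_size)  # max(96, 256)**0.5 = 16.0

# Optimizer path: shard (96, 64) correctly reconstructed to (96, 128).
opt_scale = spectral_scale(96, 64 * world_size)  # max(96, 128)**0.5 ≈ 11.31

ratio = ref_scale / opt_scale  # ≈ 1.414, i.e. exactly sqrt(2) here
```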
```python
if mode == "unit_rms_norm":
    return (size_out / size_in) ** 0.5
```
**`unit_rms_norm` mode can divide by zero when `size_in == 0`**
`(size_out / size_in) ** 0.5` raises `ZeroDivisionError` when `size_in` is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    update = momentum_buffer
```
**Non-Nesterov `update` is an alias to `momentum_buffer`, not a copy**
`update = momentum_buffer` holds a reference. If `_orthogonalize` ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. `_orthogonalize` currently clones the input immediately so this is safe today, but a defensive `.clone()` or comment would make the intent explicit.
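A defensive variant of the branch above, sketched as a standalone helper (the function name is illustrative; the PR inlines this logic in the step loop):

```python
import torch


def compute_update(grad: torch.Tensor, momentum_buffer: torch.Tensor,
                   momentum: float, nesterov: bool) -> torch.Tensor:
    """Select the raw update tensor without aliasing the momentum buffer."""
    if nesterov:
        # lerp allocates a fresh tensor, so there is no aliasing on this path.
        return grad.lerp(momentum_buffer, momentum)
    # Clone so a future in-place _orthogonalize cannot corrupt the buffer.
    return momentum_buffer.clone()
```

With the clone, later in-place edits of `update` leave the buffer untouched.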
**skyw** left a comment
I'd advise NOT exposing it in the public API. Keep it in tests only, if that is the purpose.
Having an optimizer with most of its code copied invites fragmentation.
Before this, all the optimizers TE provides are more optimized fused versions. I'd say a highly optimized fused Muon with a similar concept can be justified, but it would need more consideration because it has more dependencies on other parts of the training pipeline than elementwise optimizers do.
```python
    on tensor-parallel parameter shards. The local parameter shard must represent a
    partition of a logical 2D matrix across the provided NCCL process group.

    Args:
```
Q: Does TE use numpy style docstring instead of Google style?
```python
def __init__(
    self,
    params: Iterable[torch.nn.Parameter | dict],
```
Nit: The type here doesn't match PyTorch's internal one. Should be fine for the purposes of this class.
```python
    scale_mode: MuonScaleT = "spectral",
    extra_scale_factor: float = 1.0,
    process_group: Optional[dist.ProcessGroup] = None,
    partition_dim: int = 1,
```
```python
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if num_ns_steps < 1:
    raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
if partition_dim not in (0, 1):
```
Q: Does this class intend to support the non-distributed case? `partition_dim` would be -1 in TE in that case.
```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
```
Same question as above regarding single-GPU support.
```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
    process_group = dist.group.WORLD
```
Suggestion: This silent behavior is dangerous. If the user forgets to pass the correct TP group, the wrong group will be used.
```python
    global_shape[partition_dim] *= world_size

    orth_grad = grad.clone()
    transposed = partition_dim == 0
```
Attn: This assumes the common row- and column-wise tensor parallelism used in most LLMs. It would be suboptimal for anything other than that. Add a comment if this assumption is made.
The idea was to give something to users who use TE but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin thing on top of the newton_schulz call, and users should have no trouble creating it themselves? I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR, or should abandon it altogether. @cyanguwa
Fragmentation means there will be different flavors of Muon in emerging-optimizers and TE, and also a lot of copied code. TE can end up with a stalled feature when emerging-optimizers updates. Megatron-LM will always have its own version because there are implementation-specific things that need to be hooked together, for example how QKV is implemented, or fused swiglu.
Should we move newton_schulz.py to this directory? Also, how do we expect Megatron to call us for this functionality? Thanks.
> Should we move newton_schulz.py to this directory?

No, don't think so.

> Also, how do we expect Megatron to call us for this functionality?

Megatron will call newton_schulz directly from their optimizers. This one is for other users.
I'd prefer moving them into something like transformer_engine/pytorch/cusolver. But I suppose that is orthogonal to this PR.
@skyw Just following up on the discussion above - our purpose for this PR was two-fold. One was to provide an equivalent …
We should make sure to include this in the QA script: https://github.com/NVIDIA/TransformerEngine/blob/main/qa/L1_pytorch_distributed_unittest/test.sh
```python
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None:
```
Each torchrun launch is somewhat expensive. Instead of launching a separate torchrun for each test case, it's better to launch a single torchrun instance and to perform multiple tests internally. See distributed/test_fusible_ops.py for an example.
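The amortized pattern suggested here — one torchrun launch, with the worker iterating all cases internally — can be sketched as follows. The parameter grid and function names are hypothetical; the real grid lives in the PR's test file:

```python
import itertools

# Hypothetical grid mirroring the (dtype, partition_dim, weight_decay_mode) cases.
DTYPES = ("float32", "bfloat16")
PARTITION_DIMS = (0, 1)
WEIGHT_DECAY_MODES = ("decoupled", "l2")


def run_all_cases(run_one) -> int:
    """Run every combination inside a single torchrun worker; return the failure count."""
    failures = 0
    for dtype, pdim, wd in itertools.product(DTYPES, PARTITION_DIMS, WEIGHT_DECAY_MODES):
        try:
            run_one(dtype, pdim, wd)
        except AssertionError:
            failures += 1
    return failures
```

pytest then launches torchrun once and checks the worker's exit status, instead of paying process-group initialization per parameter combination.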
Megatron is wrapped over emerging-optimizers with Megatron-specific details, like TP and how QKV is organized. Most of the optimizer logic is in emerging-optimizers. Could TE do the same? I understand introducing a new dependency may be a concern, let me know. The biggest concern is actually the large portion of duplicated code. What I would favor is having …
Description
Add a distributed Muon optimizer, based on newton_schulz orthogonalization
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: