
[PyTorch] Add distributed Muon optimizer #2920

Open
vcherepanov-nv wants to merge 10 commits into NVIDIA:main from vcherepanov-nv:muon

Conversation

@vcherepanov-nv
Collaborator

Description

Add a distributed Muon optimizer, based on newton_schulz orthogonalization

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add an optimizer class and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vcherepanov-nv and others added 2 commits April 23, 2026 18:50
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 23, 2026

Greptile Summary

Adds MuonOptimizer, a distributed Muon optimizer that applies SGD-momentum followed by Newton-Schulz orthogonalization on tensor-parallel parameter shards, and a newton_schulz_tp convenience wrapper that supports both distributed and duplicated TP modes. The existing newton_schulz.py is moved into optimizers/.
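
For orientation, a minimal usage sketch (illustrative only: the constructor arguments are inferred from the diff hunks quoted later in this review, and exact names and defaults may differ):

```python
import torch
import torch.distributed as dist
from transformer_engine.pytorch.optimizers import MuonOptimizer

dist.init_process_group("nccl")
tp_group = dist.group.WORLD  # in real use, pass your actual TP group

# Local shard of a logical 2D weight matrix, partitioned along dim 1.
weight = torch.nn.Parameter(torch.randn(96, 64, device="cuda"))

opt = MuonOptimizer(
    [weight],
    lr=0.02,                # assumed name; the diagram shows p -= lr * orth_update
    momentum=0.95,
    nesterov=True,
    weight_decay=0.01,
    scale_mode="spectral",  # or "unit_rms_norm"
    process_group=tp_group,
    partition_dim=1,        # which dim of the logical matrix this rank shards
)

(weight ** 2).sum().backward()
opt.step()
opt.zero_grad()
```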

  • P1 — breaking module rename: transformer_engine/pytorch/newton_schulz.py is renamed to transformer_engine/pytorch/optimizers/newton_schulz.py with no backward-compat stub, silently removing the transformer_engine.pytorch.newton_schulz module path. Direct imports such as from transformer_engine.pytorch.newton_schulz import get_coefficients will raise ModuleNotFoundError, yet the PR is marked non-breaking.

Confidence Score: 4/5

Safe to merge after addressing the breaking module-path removal; the optimizer logic itself is correct.

One P1 finding: renaming newton_schulz.py without a backward-compat stub breaks the transformer_engine.pytorch.newton_schulz module path for any downstream direct imports. The optimizer's numerical correctness (momentum, Nesterov, weight decay, distributed normalization, scale factor) is sound and well-tested.

The score reflects the updated import in transformer_engine/pytorch/__init__.py and the absence of a stub at transformer_engine/pytorch/newton_schulz.py.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/optimizers/muon.py | New MuonOptimizer class with distributed Newton-Schulz orthogonalization; momentum, Nesterov, and weight-decay branches are correct; get_coefficients is re-allocated on every step (P2). |
| transformer_engine/pytorch/optimizers/newton_schulz.py | Renamed from transformer_engine/pytorch/newton_schulz.py; adds newton_schulz_tp and _orthogonalize_replicated helpers; logic for the distributed/duplicated TP modes looks correct. |
| transformer_engine/pytorch/__init__.py | Import updated to the new module path; the old transformer_engine.pytorch.newton_schulz module path is silently dropped, breaking any direct module-level imports from that path (P1). |
| tests/pytorch/distributed/run_muon_optimizer.py | New distributed test worker; the reference implementation correctly uses the full matrix shape without double-scaling the partition dimension. |
| tests/pytorch/distributed/test_muon_optimizer.py | New pytest harness covering both partition dims, bfloat16, L2 weight decay, the missing-process-group guard, and per-parameter partition_dim resolution. |
| transformer_engine/pytorch/optimizers/__init__.py | Adds MuonOptimizer and get_muon_scale_factor to the optimizers package exports. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant MuonOptimizer
    participant DistNormalize as _distributed_normalize_p2_
    participant NS as newton_schulz_tp
    participant CusolverMp as CusolverMpCtx

    Caller->>MuonOptimizer: step()
    MuonOptimizer->>MuonOptimizer: weight decay (decoupled or L2)
    MuonOptimizer->>MuonOptimizer: momentum_buffer.lerp_(grad, 1-β)
    MuonOptimizer->>MuonOptimizer: Nesterov update = grad.lerp(buf, β)
    MuonOptimizer->>DistNormalize: _distributed_normalize_p2_(update, eps)
    DistNormalize-->>DistNormalize: local norm²
    DistNormalize->>DistNormalize: all_reduce(SUM) → global norm²
    DistNormalize-->>MuonOptimizer: update ÷= global_norm (in-place)
    MuonOptimizer->>NS: newton_schulz_tp(update, ctx, partition_dim, tp_mode=distributed)
    alt partition_dim == 0
        NS->>NS: x_t = update.mT.contiguous()
        NS->>CusolverMp: newton_schulz(x_t, ctx, num_iters)
        NS-->>NS: update.copy_(x_t.mT)
    else partition_dim == 1
        NS->>CusolverMp: newton_schulz(update, ctx, num_iters)
    end
    CusolverMp-->>NS: in-place orthogonalized shard
    NS-->>MuonOptimizer: orthogonalized update
    MuonOptimizer->>MuonOptimizer: update *= scale_factor * extra_scale
    MuonOptimizer->>Caller: p -= lr * orth_update
```
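
For readers skimming the diagram, here is a sketch of the distributed normalization step. The helper name is taken from the diagram, but the body (including where eps enters) is illustrative, not the PR's actual code:

```python
import torch
import torch.distributed as dist

def _distributed_normalize_p2_(update: torch.Tensor, eps: float,
                               group: dist.ProcessGroup) -> torch.Tensor:
    # The global L2 norm of a sharded tensor is the square root of the
    # all-reduced sum of each rank's local squared norm.
    sq = update.pow(2).sum()
    dist.all_reduce(sq, op=dist.ReduceOp.SUM, group=group)
    update.div_(sq.sqrt() + eps)  # in-place, as in the diagram
    return update
```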

Comments Outside Diff (1)

  1. transformer_engine/pytorch/__init__.py, lines 62-66

    P1 transformer_engine.pytorch.newton_schulz module path is removed without a backward-compat stub

    transformer_engine/pytorch/newton_schulz.py existed in the base branch and was accessible as the module transformer_engine.pytorch.newton_schulz. This PR renames it to transformer_engine/pytorch/optimizers/newton_schulz.py and updates the __init__.py re-exports for the public symbols, but leaves no stub at the old path. Any downstream code that imported directly from the old module path — e.g., from transformer_engine.pytorch.newton_schulz import get_coefficients — will now fail with ModuleNotFoundError. Adding a shim module at transformer_engine/pytorch/newton_schulz.py that re-exports from the new location would preserve backward compatibility while the PR is marked non-breaking.
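
    A sketch of such a shim (hypothetical; it assumes the public symbols named in this review):

```python
# transformer_engine/pytorch/newton_schulz.py -- hypothetical backward-compat shim.
# Re-exports from the new location so the old import path keeps working.
import warnings

from transformer_engine.pytorch.optimizers.newton_schulz import *  # noqa: F401,F403
from transformer_engine.pytorch.optimizers.newton_schulz import get_coefficients  # noqa: F401

warnings.warn(
    "transformer_engine.pytorch.newton_schulz has moved to "
    "transformer_engine.pytorch.optimizers.newton_schulz; "
    "this stub will be removed in a future release.",
    DeprecationWarning,
    stacklevel=2,
)
```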


Comment on lines +186 to +191
```python
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()
```

Contributor

P1 Closure called inside @torch.no_grad(), preventing gradient computation

closure() is invoked while torch.no_grad() is active. Any loss.backward() call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in with torch.enable_grad():.

Suggested change

Original:

```python
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()
```

Suggested:

```python
@torch.no_grad()
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()
```
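
For context, a runnable sketch of why this matters; plain torch.optim.SGD stands in for MuonOptimizer here:

```python
import torch

model = torch.nn.Linear(8, 8)
inputs = torch.randn(4, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def closure():
    optimizer.zero_grad()
    loss = model(inputs).pow(2).mean()
    # If step() calls this under torch.no_grad() without re-enabling grad,
    # the loss is built without an autograd graph and backward() cannot
    # populate .grad for the parameters.
    loss.backward()
    return loss

loss = optimizer.step(closure)
```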

Comment on lines +28 to +33
```python
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    global_shape = [grad.size(0), grad.size(1)]
    global_shape[partition_dim] *= world_size
```
Contributor

P1 Reference global_shape incorrectly scales an already-full tensor

_reference_orthogonalize receives the full matrix (shape full_shape) but then multiplies global_shape[partition_dim] by world_size a second time. For partition_dim=1 with world_size=2 and full_shape=(96, 128) this gives global_shape=[96, 256], so get_muon_scale_factor returns max(96,256)^0.5 = 16. The optimizer, operating on the shard (96, 64), correctly reconstructs global_shape=[96, 128] and computes max(96,128)^0.5 ≈ 11.3. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.

The global_shape[partition_dim] *= world_size line should be removed since the input is already the full matrix.
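
The arithmetic is easy to reproduce in isolation (assuming the spectral scale is max(rows, cols) ** 0.5, as described above):

```python
# Reproduces the scale-factor discrepancy described above, assuming the
# spectral mode computes max(rows, cols) ** 0.5.
def spectral_scale(rows: int, cols: int) -> float:
    return max(rows, cols) ** 0.5

world_size = 2
# Optimizer: sees the (96, 64) shard and reconstructs the true global shape.
print(spectral_scale(96, 64 * world_size))    # ~11.31 (sqrt(128))
# Buggy reference: receives the full (96, 128) matrix but scales dim 1 again.
print(spectral_scale(96, 128 * world_size))   # 16.0 (sqrt(256)), sqrt(2) too large
```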

Comment on lines +33 to +34
```python
    if mode == "unit_rms_norm":
        return (size_out / size_in) ** 0.5
```
Contributor

P1 unit_rms_norm mode can divide by zero when size_in == 0

(size_out / size_in) ** 0.5 raises ZeroDivisionError when size_in is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
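
One possible guard, sketched under the assumption that the spectral branch mirrors the max(size_out, size_in) ** 0.5 behavior described earlier in this review:

```python
def get_muon_scale_factor(size_out: int, size_in: int, mode: str) -> float:
    # Validate both dimensions up front instead of letting unit_rms_norm
    # divide by zero on a degenerate shape.
    if size_out <= 0 or size_in <= 0:
        raise ValueError(f"Both dimensions must be positive, got ({size_out}, {size_in})")
    if mode == "unit_rms_norm":
        return (size_out / size_in) ** 0.5
    return max(size_out, size_in) ** 0.5  # assumed "spectral" behavior
```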

Comment on lines +218 to +221
```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    update = momentum_buffer
```
Contributor

P2 Non-Nesterov update is an alias to momentum_buffer, not a copy

update = momentum_buffer holds a reference. If _orthogonalize ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. _orthogonalize currently clones the input immediately so this is safe today, but a defensive .clone() or comment would make the intent explicit.
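
The defensive variant would look like this (a sketch of the suggestion, not the PR's code):

```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    # Explicit copy: protects the momentum buffer even if a future
    # _orthogonalize stops cloning its input.
    update = momentum_buffer.clone()
```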

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@vcherepanov-nv vcherepanov-nv changed the title from "[Draft] [PyTorch] Add distributed Muon optimizer" to "[PyTorch] Add distributed Muon optimizer" Apr 27, 2026
@vcherepanov-nv vcherepanov-nv requested a review from cyanguwa April 27, 2026 18:12

@skyw skyw left a comment


I'd advise NOT exposing it in the public API. Keep it in tests only, if that is the purpose.

Having an optimizer with most code copied invites fragmentation.

Before this, all the optimizers TE provides are more optimized fused versions. I'd say a highly optimized fused Muon with a similar concept can be justified, but it would need more consideration, because it has more dependencies on other parts of the training pipeline than elementwise optimizers.

```
on tensor-parallel parameter shards. The local parameter shard must represent a
partition of a logical 2D matrix across the provided NCCL process group.

Args:
```

Q: Does TE use NumPy-style docstrings instead of Google style?


```python
    def __init__(
        self,
        params: Iterable[torch.nn.Parameter | dict],
```

Nit: The type here doesn't match PyTorch's internal one. Should be fine for the purposes of this class.

```python
        scale_mode: MuonScaleT = "spectral",
        extra_scale_factor: float = 1.0,
        process_group: Optional[dist.ProcessGroup] = None,
        partition_dim: int = 1,
```

Fix: partition_dim is per parameter.

```python
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if num_ns_steps < 1:
    raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
if partition_dim not in (0, 1):
```

Q: Does this class intend to support the non-distributed case? partition_dim would be -1 in TE in that case.


```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
```

Same question as above regarding single-GPU support.

```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
    process_group = dist.group.WORLD
```

Suggestion: This silent behavior is dangerous. If the user forgets to pass the correct TP group, the wrong group will be used.
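
A sketch of the stricter alternative (illustrative only):

```python
if process_group is None:
    # Fail loudly instead of silently assuming the WORLD group, which may
    # not be the tensor-parallel group the shards are partitioned over.
    raise ValueError(
        "MuonOptimizer requires an explicit process_group; pass the TP group "
        "your parameter shards are partitioned across."
    )
```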

Comment thread transformer_engine/pytorch/optimizers/muon.py
```python
global_shape[partition_dim] *= world_size

orth_grad = grad.clone()
transposed = partition_dim == 0
```

Attn: This assumes the common row- and column-wise tensor parallelism used in most LLMs. It would be suboptimal for anything other than that. Add a comment if this assumption is made.

@vcherepanov-nv
Collaborator Author

> Having an optimizer with most code copied invites fragmentation.

The idea was to give something to users who use TE, but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin layer on top of the newton_schulz call, and users should have no trouble creating it themselves?

I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR or should abandon it altogether. @cyanguwa

@skyw

skyw commented Apr 28, 2026

> > Having an optimizer with most code copied invites fragmentation.
>
> The idea was to give something to users who use TE, but not Megatron-LM. By fragmentation, do you mean that we want to encourage everyone to use Megatron-LM? Or that the optimizer is a relatively thin layer on top of the newton_schulz call, and users should have no trouble creating it themselves?
>
> I don't think we gain anything by putting it into tests, since we already have tests for the newton_schulz call. So we need to decide whether we want this PR or should abandon it altogether. @cyanguwa

Fragmentation means there will be different flavors of Muon in emerging-optimizers and TE, and also a lot of copied code. TE can end up with a stale feature when emerging-optimizers updates. Megatron-LM will always have its own version because there are implementation-specific things that need to be hooked together, for example how QKV is implemented, or fused SwiGLU.
For TE, I think an example of how to build a version of the emerging-optimizers Muon using the TE Newton-Schulz backend would be good to have. But providing an optimizer (not a fusion-optimized version) confuses customers.
Having said that, I would love for TE to have a more optimized version, in the same spirit as FusedAdam, etc.

Collaborator


Should we move newton_schulz.py to this directory? Also, how do we expect Megatron to call us for this functionality? Thanks.

Collaborator Author


> Should we move newton_schulz.py to this directory?

No, I don't think so.

> Also, how do we expect Megatron to call us for this functionality?

Megatron will call newton_schulz directly from their optimizers. This one is for other users.

Collaborator


I'd prefer moving them into something like transformer_engine/pytorch/cusolver. But I suppose that is orthogonal to this PR.

vcherepanov-nv and others added 3 commits May 1, 2026 07:27
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@cyanguwa
Collaborator

cyanguwa commented May 4, 2026

  1. run CI here: "/te-ci torch L1"; add tests to qa/Lx_pytorch_unittest/test.sh?
  2. please create a newton_schulz_tp API that includes the partition_dim/mode params, for Megatron integration
  3. please test non-distributed cases (per the comment above)
  4. please move newton_schulz.py to te/pytorch/optimizers; if we have more solvers to integrate into TE or more use cases for Newton-Schulz, we can definitely restructure the code, but we don't see that in the near future

@cyanguwa
Collaborator

cyanguwa commented May 4, 2026

@skyw Just following up on the discussion above - our purpose for this PR was two-fold. One was to provide an equivalent newton_schulz_tp API for Megatron; the other was to provide a dialed-down Muon optimizer class so direct TE users can access the Newton-Schulz solver. I understand this may cause divergence in TE's and Megatron's Muon support, but we do want to expose this feature to direct users of TE as well. Hope that helps. Thanks.



```python
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None:
```
Collaborator

Each torchrun launch is somewhat expensive. Instead of launching a separate torchrun for each test case, it's better to launch a single torchrun instance and to perform multiple tests internally. See distributed/test_fusible_ops.py for an example.
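
A sketch of that pattern (helper names hypothetical, loosely following the structure of distributed/test_fusible_ops.py):

```python
import subprocess
import torch

NUM_PROCS = min(4, torch.cuda.device_count())

# Test side: one torchrun launch for the whole suite.
def test_muon_optimizer_distributed():
    cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "run_muon_optimizer.py"]
    subprocess.run(cmd, check=True)

# Worker side (run_muon_optimizer.py): loop over all cases in-process so each
# parameterization does not pay a fresh torchrun startup.
def main():
    for dtype in ("float32", "bfloat16"):
        for partition_dim in (0, 1):
            for weight_decay_mode in ("decoupled", "l2"):
                _run_case(dtype, partition_dim, weight_decay_mode)

def _run_case(dtype: str, partition_dim: int, weight_decay_mode: str) -> None:
    ...  # hypothetical: run one optimizer-vs-reference comparison
```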

vcherepanov-nv and others added 4 commits May 4, 2026 23:21
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@skyw

skyw commented May 5, 2026

> @skyw Just following up on the discussion above - our purpose for this PR was two-fold. One was to provide an equivalent newton_schulz_tp API for Megatron; the other was to provide a dialed-down Muon optimizer class so direct TE users can access the Newton-Schulz solver. I understand this may cause divergence in TE's and Megatron's Muon support, but we do want to expose this feature to direct users of TE as well. Hope that helps. Thanks.

Megatron is a wrapper over emerging-optimizers with Megatron-specific details, like TP and how QKV is organized. Most of the optimizer logic is in emerging-optimizers. Could TE do the same? I understand introducing a new dependency may be a concern, let me know. The biggest concern is actually the large portion of duplicated code.

What I would favor is having newton_schulz_tp in one release and testing it out, then having a well-optimized version of the Muon optimizer class in the next. There are a lot of optimizations (fusion, batching, graph capturability, etc.) that can, and I believe should, go into TE.

