215 changes: 215 additions & 0 deletions tests/pytorch/distributed/run_muon_optimizer.py
@@ -0,0 +1,215 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Distributed Muon optimizer test worker.

Launched via torchrun from test_muon_optimizer.py.
"""

import argparse
import sys

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record

import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers.newton_schulz import get_coefficients
from transformer_engine.pytorch.optimizers.muon import get_muon_scale_factor


def _reference_orthogonalize(
grad: torch.Tensor,
*,
partition_dim: int,
coefficients: list[tuple[float, float, float]],
scale_mode: str,
extra_scale_factor: float,
eps: float,
) -> torch.Tensor:
global_shape = [grad.size(0), grad.size(1)]

    x = grad.clone()
    if partition_dim == 0:
        # Work with the partitioned dimension as columns; transposed back below.
        x = x.mT.contiguous()

    # Normalize by the Frobenius norm (clamped by eps); the spectral norm is
    # bounded by the Frobenius norm, so all singular values start in [0, 1].
    x = x / torch.sqrt((x.float() * x.float()).sum()).clamp_min(eps).to(dtype=x.dtype)

    # Quintic Newton-Schulz step: x <- a*x + b*(x x^T) x + c*((x x^T)^2) x
    for a, b, c in coefficients:
        xxt = x @ x.mT
        x = a * x + b * (xxt @ x) + c * ((xxt @ xxt) @ x)

if partition_dim == 0:
x = x.mT.contiguous()

scale = get_muon_scale_factor(global_shape[0], global_shape[1], mode=scale_mode)
return x * (scale * extra_scale_factor)


def _reference_step(
param: torch.Tensor,
grad: torch.Tensor,
momentum_buffer: torch.Tensor,
*,
lr: float,
momentum: float,
nesterov: bool,
weight_decay: float,
use_decoupled_weight_decay: bool,
partition_dim: int,
coefficients: list[tuple[float, float, float]],
scale_mode: str,
extra_scale_factor: float,
eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
param = param.clone()
grad = grad.clone()
momentum_buffer = momentum_buffer.clone()

    if use_decoupled_weight_decay:
        # Decoupled (AdamW-style) decay scales the parameter directly.
        param = param * (1.0 - lr * weight_decay)
    elif weight_decay != 0:
        # Classic L2 decay folds the penalty into the gradient.
        grad = grad + weight_decay * param

    # EMA-style momentum; the Nesterov branch blends the fresh gradient back in.
    momentum_buffer = momentum * momentum_buffer + (1.0 - momentum) * grad
    if nesterov:
        update = (1.0 - momentum) * grad + momentum * momentum_buffer
    else:
        update = momentum_buffer

orth_update = _reference_orthogonalize(
update,
partition_dim=partition_dim,
coefficients=coefficients,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
eps=eps,
)
param = param - lr * orth_update
return param, momentum_buffer


@record
def main():
parser = argparse.ArgumentParser(description="Distributed Muon optimizer test")
parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1])
parser.add_argument(
"--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"]
)
parser.add_argument("--num-steps", type=int, default=2)
args = parser.parse_args()

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
if args.partition_dim == 0:
full_shape = (world_size * 64, 96)
else:
full_shape = (96, world_size * 64)

lr = 3e-4
momentum = 0.95
nesterov = True
weight_decay = 0.01
use_decoupled_weight_decay = args.weight_decay_mode == "decoupled"
coefficient_type = "quintic"
num_ns_steps = 5
scale_mode = "spectral"
extra_scale_factor = 1.0
eps = 1e-7
coefficients = get_coefficients(num_ns_steps, coefficient_type)

if rank == 0:
torch.manual_seed(1234)
full_param = torch.randn(full_shape, device="cuda", dtype=dtype)
full_grads = [
torch.randn(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps)
]
else:
full_param = torch.empty(full_shape, device="cuda", dtype=dtype)
full_grads = [
torch.empty(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps)
]

dist.broadcast(full_param, src=0)
for grad in full_grads:
dist.broadcast(grad, src=0)

shard_size = full_shape[args.partition_dim] // world_size
shard_slice = slice(rank * shard_size, (rank + 1) * shard_size)
if args.partition_dim == 0:
local_param_init = full_param[shard_slice, :].contiguous()
else:
local_param_init = full_param[:, shard_slice].contiguous()

param = torch.nn.Parameter(local_param_init.clone())
param.partition_dim = args.partition_dim
optimizer = te.optimizers.MuonOptimizer(
[param],
lr=lr,
momentum=momentum,
nesterov=nesterov,
weight_decay=weight_decay,
use_decoupled_weight_decay=use_decoupled_weight_decay,
coefficient_type=coefficient_type,
num_ns_steps=num_ns_steps,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
process_group=dist.group.WORLD,
eps=eps,
)

    # The reference runs in float32 on the full, unsharded matrices.
    ref_param = full_param.float()
    ref_momentum = torch.zeros_like(ref_param)
for full_grad in full_grads:
if args.partition_dim == 0:
param.grad = full_grad[shard_slice, :].contiguous()
else:
param.grad = full_grad[:, shard_slice].contiguous()
optimizer.step()

ref_param, ref_momentum = _reference_step(
ref_param,
full_grad.float(),
ref_momentum,
lr=lr,
momentum=momentum,
nesterov=nesterov,
weight_decay=weight_decay,
use_decoupled_weight_decay=use_decoupled_weight_decay,
partition_dim=args.partition_dim,
coefficients=coefficients,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
eps=eps,
)

    # Reassemble the full parameter from all ranks' shards.
    gathered = [torch.empty_like(param) for _ in range(world_size)]
dist.all_gather(gathered, param)
if args.partition_dim == 0:
test_param = torch.cat(gathered, dim=0)
else:
test_param = torch.cat(gathered, dim=1)

if rank == 0:
expected = ref_param.to(dtype)
atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2e-3, 2e-3)
if torch.allclose(test_param, expected, atol=atol, rtol=rtol):
print("MUON OPTIMIZER CHECK PASSED", flush=True)
else:
max_diff = (test_param - expected).abs().max().item()
print(f"Max |optimizer - reference|: {max_diff:.6e}", flush=True)
print("MUON OPTIMIZER CHECK FAILED", flush=True, file=sys.stderr)
sys.exit(1)

optimizer.destroy()
dist.destroy_process_group()


if __name__ == "__main__":
main()
43 changes: 35 additions & 8 deletions tests/pytorch/distributed/run_newton_schulz.py
@@ -14,10 +14,11 @@
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record

from transformer_engine.pytorch.newton_schulz import (
from transformer_engine.pytorch.optimizers.newton_schulz import (
CusolverMpCtx,
get_coefficients,
newton_schulz,
newton_schulz_tp,
)


@@ -43,6 +44,11 @@ def main():
parser.add_argument("--matrix-cols", type=int, default=None)
parser.add_argument("--num-iterations", type=int, default=5)
parser.add_argument("--coeff-type", type=str, default="quintic")
parser.add_argument("--api", type=str, default="base", choices=["base", "tp"])
parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1])
parser.add_argument(
"--tp-mode", type=str, default="distributed", choices=["duplicated", "distributed"]
)
parser.add_argument("--atol", type=float, default=1e-2)
parser.add_argument("--rtol", type=float, default=1e-2)
args = parser.parse_args()
@@ -57,8 +63,13 @@
n = args.matrix_cols if args.matrix_cols is not None else args.matrix_rows
coefficients = get_coefficients(args.num_iterations, args.coeff_type)

# Ensure the distributed column dimension is divisible by world_size.
assert n % world_size == 0, f"Matrix columns {n} must be divisible by world_size {world_size}"
if args.api == "base" or args.partition_dim == 1:
# Ensure the distributed column dimension is divisible by world_size.
assert (
n % world_size == 0
), f"Matrix columns {n} must be divisible by world_size {world_size}"
else:
assert m % world_size == 0, f"Matrix rows {m} must be divisible by world_size {world_size}"

# Create a random matrix on rank 0 with singular values in (0, 1),
# which keeps the Newton-Schulz iterations in the convergence regime.
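
The construction itself is collapsed in this view. Purely as an illustrative sketch (the helper name and value range are assumptions, not the PR's code), one standard recipe is to draw a random matrix and reassign its singular values:

```python
import torch

def random_contraction(m: int, n: int, device: str = "cuda") -> torch.Tensor:
    """Random matrix whose singular values all lie strictly inside (0, 1)."""
    u, _, vh = torch.linalg.svd(torch.randn(m, n, device=device), full_matrices=False)
    s = 0.05 + 0.9 * torch.rand(min(m, n), device=device)  # values in (0.05, 0.95)
    return (u * s) @ vh  # scale U's columns by s, then recompose with V^H
```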
@@ -80,20 +91,36 @@
# Broadcast the full matrix to all ranks
dist.broadcast(A, src=0)

# Scatter columns to each rank
local_cols = n // world_size
x_local = A[:, rank * local_cols : (rank + 1) * local_cols].contiguous()
# Scatter columns for the base API. Scatter along partition_dim for the TP API.
if args.api == "tp" and args.partition_dim == 0:
local_rows = m // world_size
x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous()
gather_dim = 0
else:
local_cols = n // world_size
x_local = A[:, rank * local_cols : (rank + 1) * local_cols].contiguous()
gather_dim = 1

ctx = CusolverMpCtx(dist.group.WORLD)
try:
newton_schulz(x_local, ctx, args.num_iterations, coefficients=coefficients)
if args.api == "tp":
newton_schulz_tp(
x_local,
ctx,
args.num_iterations,
coefficients=coefficients,
partition_dim=args.partition_dim,
tp_mode=args.tp_mode,
)
else:
newton_schulz(x_local, ctx, args.num_iterations, coefficients=coefficients)
finally:
ctx.destroy()

# Gather results
gathered = [torch.empty_like(x_local) for _ in range(world_size)]
dist.all_gather(gathered, x_local)
X = torch.cat(gathered, dim=1)
X = torch.cat(gathered, dim=gather_dim)

# Check: the resulting matrix should be orthogonal, or match a local reference.
if rank == 0:
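The body of this check is collapsed in this view. As a hedged sketch of the orthogonality half only (the helper name is assumed, not the PR's code), the gathered result's Gram matrix over its short dimension should be close to the identity:

```python
import torch

def assert_semi_orthogonal(X: torch.Tensor, atol: float, rtol: float) -> None:
    """Check X^T X ~= I for tall X, or X X^T ~= I for wide X."""
    gram = X.mT @ X if X.size(0) >= X.size(1) else X @ X.mT
    eye = torch.eye(gram.size(0), device=X.device, dtype=torch.float32)
    torch.testing.assert_close(gram.float(), eye, atol=atol, rtol=rtol)
```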
84 changes: 84 additions & 0 deletions tests/pytorch/distributed/test_muon_optimizer.py
@@ -0,0 +1,84 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for distributed Muon optimizer."""

import os
import subprocess
from pathlib import Path

import pytest
import torch

from transformer_engine.pytorch.optimizers.muon import MuonOptimizer

MULTI_GPU_AVAILABLE = torch.cuda.device_count() >= 2
requires_multi_gpu = pytest.mark.skipif(
not MULTI_GPU_AVAILABLE,
reason="Muon optimizer distributed tests require at least 2 GPUs.",
)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None:
Review comment (Collaborator):

Each torchrun launch is somewhat expensive. Instead of launching a separate torchrun for each test case, it's better to launch a single torchrun instance and to perform multiple tests internally. See distributed/test_fusible_ops.py for an example.
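
A minimal sketch of that consolidation, assuming the worker's per-case body is factored into a callable (all names here are illustrative, not the PR's code):

```python
import itertools
from typing import Callable

def run_all_cases(run_case: Callable[[str, int, str], None]) -> None:
    """Run every (dtype, partition_dim, weight_decay_mode) combination inside
    one torchrun launch, collecting failures instead of stopping at the first."""
    failures = []
    for dtype, pdim, wd_mode in itertools.product(
        ("float32", "bfloat16"), (0, 1), ("decoupled", "l2")
    ):
        try:
            run_case(dtype, pdim, wd_mode)
        except AssertionError as exc:
            failures.append((dtype, pdim, wd_mode, str(exc)))
    assert not failures, f"{len(failures)} Muon case(s) failed: {failures}"
```

The test module would then spawn a single torchrun subprocess, amortizing NCCL process-group setup across all cases, as distributed/test_fusible_ops.py does.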

test_path = TEST_ROOT / "run_muon_optimizer.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
f"--dtype={dtype}",
f"--partition-dim={partition_dim}",
f"--weight-decay-mode={weight_decay_mode}",
]
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300)
if (
result.returncode != 0
or "MUON OPTIMIZER CHECK FAILED" in result.stderr.decode()
or "MUON OPTIMIZER CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(
"Muon optimizer test failed.\n"
f"stdout: {result.stdout.decode()}\n"
f"stderr: {result.stderr.decode()}"
)


@requires_multi_gpu
@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("partition_dim", [0, 1])
def test_muon_optimizer_matches_reference(dtype: str, partition_dim: int) -> None:
"""Compare distributed Muon updates with a full-matrix reference."""
_run_test(dtype, partition_dim, "decoupled")


@requires_multi_gpu
def test_muon_optimizer_l2_weight_decay() -> None:
"""Exercise the L2 weight decay branch against the same reference."""
_run_test("float32", 1, "l2")


def test_muon_optimizer_requires_explicit_process_group() -> None:
"""Muon should not silently fall back to the world process group."""
param = torch.nn.Parameter(torch.empty(2, 2))
with pytest.raises(ValueError, match="explicit NCCL tensor-parallel process_group"):
MuonOptimizer([param], process_group=None, partition_dim=0)


def test_muon_optimizer_resolves_partition_dim_per_parameter() -> None:
"""TE tensor-parallel metadata should provide per-parameter partition dims."""
param = torch.empty(2, 2)
param.partition_dim = 0

assert MuonOptimizer._resolve_partition_dim(param, None) == 0

param_without_metadata = torch.empty(2, 2)
assert MuonOptimizer._resolve_partition_dim(param_without_metadata, 1) == 1

with pytest.raises(ValueError, match="Conflicting partition_dim"):
MuonOptimizer._resolve_partition_dim(param, 1)

param.partition_dim = -1
with pytest.raises(ValueError, match="Non-parallel parameters are not supported"):
MuonOptimizer._resolve_partition_dim(param, None)