Skip to content

Commit 456b565

Browse files
committed
feat(zero2): add CPU offload support for Muon optimizer
Enable Muon optimizer with ZeRO Stage 2 CPU offload. The Newton-Schulz orthogonalization always runs on GPU for performance (momentum is temporarily moved to GPU), while the momentum buffer stays on CPU to save GPU memory. The _apply_muon_update_for_cpu_offload method intercepts the gradient copy path in copy_grads_in_partition to apply muon_update before writing to the CPU FP32 grad buffer. Cross-boundary parameters are handled by processing the full gradient on each involved rank. Includes cosimulation test verifying offload vs non-offload produce consistent results. Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
1 parent aa3cfd1 commit 456b565

2 files changed

Lines changed: 219 additions & 1 deletion

File tree

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,79 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
15021502
return torch.tensor(total_norm, device=self.device, dtype=torch.float)
15031503

15041504
############################################################################################
1505+
def _apply_muon_update_for_cpu_offload(self, param):
    """Apply muon_update for a parameter in the CPU offload path.

    For Muon parameters (``use_muon=True``), runs Newton-Schulz
    orthogonalization on GPU (the CPU momentum slice is temporarily
    copied to GPU) and writes only this rank's partition slice of the
    resulting update back to the CPU FP32 grad buffer.  Parameters that
    straddle a partition boundary are redundantly processed with the
    full gradient on each involved rank, matching the non-offload path
    behavior in get_flat_partition.

    Returns:
        bool: True if muon_update was applied, in which case the caller
        must skip the normal grad copy for this param.
    """
    # Only intercept Muon-enabled params driven by a Muon optimizer;
    # everything else falls through to the normal copy path.
    if not getattr(param, 'use_muon', False):
        return False
    if 'muon' not in self.optimizer.__class__.__name__.lower():
        return False

    param_id = self.get_param_id(param)
    [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

    grad_accum = self.get_param_gradient_attribute(param)
    if grad_accum is None:
        return False

    # Momentum for the whole partition lives in one flat CPU buffer,
    # keyed on the flattened FP32 partition param; create it lazily so
    # no GPU memory is spent on Muon state.
    flatten_copy = self.optimizer.param_groups[i]['params'][0]
    if "momentum_buffer" not in self.optimizer.state[flatten_copy]:
        total_size = sum(p.numel() for p in self.params_in_partition[i])
        self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size,
                                                                           dtype=torch.float32,
                                                                           device=self.device)

    momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"]

    # Locate this param's slice inside the flat momentum buffer by
    # summing the sizes of the params that precede it in the partition.
    muon_offset = 0
    for p in self.params_in_partition[i]:
        if p is param:
            break
        muon_offset += p.numel()

    # NOTE: slicing + .view() both return views, so momentum_cpu shares
    # storage with momentum_flat; in-place writes below persist directly.
    momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())

    beta = self.optimizer.param_groups[i].get('momentum', 0.95)
    ns_method = self.optimizer.param_groups[i].get('ns_method', 'gram')

    # Run Newton-Schulz on GPU for performance: the grad stays on GPU and
    # the momentum is round-tripped CPU -> GPU -> CPU, so it only
    # occupies GPU memory transiently.
    gpu_device = grad_accum.device
    grad_gpu = grad_accum.detach().clone().to(dtype=torch.float32)
    momentum_gpu = momentum_cpu.to(device=gpu_device, dtype=torch.float32)
    update = muon_update(grad_gpu.view(param.size()), momentum_gpu, beta=beta, ns_method=ns_method)
    # Because momentum_cpu is a view of momentum_flat, this copy_ is the
    # write-back; no separate assignment into momentum_flat is needed.
    momentum_cpu.copy_(momentum_gpu.to(device='cpu'))
    update_cpu = update.to(device='cpu')
    del grad_gpu, momentum_gpu

    # Write only this rank's partition slice of the update to the CPU
    # FP32 grad buffer; source_offset skips the leading elements owned
    # by lower ranks for a cross-boundary param.
    tensor_offset = 0
    actual_num_elements = param.numel()
    if source_offset > 0:
        tensor_offset = source_offset
        actual_num_elements = param.numel() - tensor_offset
    if actual_num_elements > num_elements:
        actual_num_elements = num_elements

    dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements)
    update_slice = update_cpu.view(-1).narrow(0, tensor_offset, actual_num_elements)
    dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype))

    self.clear_grad_attribute(param)
    return True
1577+
15051578
def copy_grads_in_partition(self, param):
15061579
if self.cpu_offload:
15071580

@@ -1513,7 +1586,8 @@ def copy_grads_in_partition(self, param):
15131586

15141587
self.update_offload_overflow_tracker_for_param_grad(param)
15151588

1516-
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
1589+
if not self._apply_muon_update_for_cpu_offload(param):
1590+
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
15171591

15181592
return
15191593
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# DeepSpeed Team
4+
5+
import deepspeed
6+
import torch
7+
import pytest
8+
9+
from unit.common import DistributedTest
10+
from unit.simple_model import SimpleModel
11+
from deepspeed.accelerator import get_accelerator
12+
13+
# Skip the whole module on accelerators without fp16 support.
# (Plain string: the original f-string had no placeholders.)
if torch.half not in get_accelerator().supported_dtypes():
    pytest.skip("fp16 not supported", allow_module_level=True)
15+
16+
17+
@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffload(DistributedTest):

    def test_momentum_buffer_on_cpu(self, zero_stage):
        """Verify Muon CPU offload keeps the momentum buffer on CPU.

        Key invariant: after a training step with CPU offload, the Muon
        momentum buffer must reside on CPU (not GPU).  Newton-Schulz
        itself runs on GPU, but the momentum only visits the GPU
        transiently during muon_update, so no GPU memory is held for
        Muon state between steps.
        """
        hidden_dim = 32
        batch_size = 8
        config_dict = {
            "train_batch_size": batch_size,
            "optimizer": {
                "type": "muon",
                "params": {
                    "lr": 0.01
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": zero_stage,
                "reduce_scatter": False,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
            },
        }

        model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
        engine, optimizer, _, _ = deepspeed.initialize(
            config=config_dict,
            model=model,
            model_parameters=model.parameters(),
            dist_init_required=False,
        )

        x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
        y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
        loss = engine(x, y)
        engine.backward(loss)
        engine.step()

        # Muon momentum buffer must exist and be on CPU.
        # If muon_update was silently skipped, momentum_buffer would not be created.
        flatten_copy = optimizer.optimizer.param_groups[0]['params'][0]
        state = optimizer.optimizer.state[flatten_copy]
        assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. "
                                            "muon_update was not called in the CPU offload path.")
        assert state['momentum_buffer'].device.type == 'cpu', (
            f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU")
74+
@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffloadCosim(DistributedTest):

    def test_cosim_offload_vs_no_offload(self, zero_stage):
        """Verify CPU offload produces results consistent with GPU path.

        With the same random seed, offload and non-offload should produce
        close parameters. If muon_update is skipped or wrong in either path,
        the results diverge significantly.
        """
        hidden_dim = 32
        batch_size = 8

        def train(offload):
            # Re-seed so both runs build identical models and data streams.
            torch.manual_seed(42)
            config_dict = {
                "train_batch_size": batch_size,
                "optimizer": {
                    "type": "muon",
                    "params": {
                        "lr": 0.01
                    }
                },
                "fp16": {
                    "enabled": True
                },
                "zero_optimization": {
                    "stage": zero_stage,
                    "reduce_scatter": False,
                },
            }
            if offload:
                config_dict["zero_optimization"]["offload_optimizer"] = {
                    "device": "cpu",
                    "pin_memory": True,
                }

            model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
            engine, _, _, _ = deepspeed.initialize(
                config=config_dict,
                model=model,
                model_parameters=model.parameters(),
                dist_init_required=False,
            )

            for _ in range(3):
                x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
                y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
                loss = engine(x, y)
                engine.backward(loss)
                engine.step()

            return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()}

        params_offload = train(offload=True)
        params_no_offload = train(offload=False)

        for name in params_offload:
            p_off = params_offload[name]
            p_no = params_no_offload[name]
            # Both paths should produce the same NaN pattern.  Compare the
            # two masks directly: checking `(off | no) == off` (as before)
            # only verified that non-offload NaNs are a subset of offload
            # NaNs, not that the patterns actually match.
            assert p_off.isnan().equal(p_no.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. "
                                                       "muon_update produced different results.")
            # On non-NaN elements, cosine similarity should be very high
            valid = ~p_off.isnan()
            if valid.sum() > 0:
                cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0),
                                                                p_no[valid].unsqueeze(0)).item()
                assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and "
                                        f"non-offload is too low, indicating muon_update results diverge.")

0 commit comments

Comments
 (0)