Commit 0242b59

Author: wentiange
Parent: c0eba14

[Feature] offload optimizer states to CPU (reduces NPU memory, trades off performance)

3 files changed: 196 additions & 6 deletions
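
In outline: AdamWConfig gains a swap_optimizer flag. When it is set, build() allocates the AdamW states, mirrors each state tensor into pinned host memory, frees the device copies, and re-binds optimizer.step to swap_adamw_step, which pages the states back onto the NPU around the fused update. The trainer then swaps the states back to host after each step (deferring when a DCP save is about to run), and back to device before loading a checkpoint.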

xtuner/v1/config/optim.py (161 additions & 5 deletions)
```diff
@@ -9,11 +9,10 @@
 
 from xtuner.v1.optim import Muon
 from xtuner.v1.utils import get_logger
-
+import types
 
 logger = get_logger()
 
-
 class OptimConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     lr: Annotated[float, Parameter(help="Learning rate for optimization")] = 1e-5
@@ -32,6 +31,7 @@ class AdamWConfig(OptimConfig):
     betas: Annotated[Tuple[float, float], Parameter(help="Beta coefficients for Adam optimizer")] = (0.9, 0.95)
     eps: Annotated[float, Parameter(help="Epsilon value for numerical stability in Adam optimizer")] = 1e-8
     foreach: Annotated[Optional[bool], Parameter(help="Use foreach implementation for AdamW")] = None
+    swap_optimizer: Annotated[Optional[bool], Parameter(help="Swap optimizer states to host memory.")] = False
 
     def build(self, model):
         params = [p for p in model.parameters() if p.requires_grad]
@@ -52,10 +52,13 @@ def build(self, model):
             f"Total trainable parameters: {num_total_requires_grad // 1e6}M, total parameters: {num_total // 1e6}M"
         )
         logger.info(f"Untrainable parameters names: {untrainable_names}")
-        return torch.optim.AdamW(
+        optimizer = torch.optim.AdamW(
             params, lr=self.lr, betas=self.betas, eps=self.eps, weight_decay=self.weight_decay, foreach=self.foreach
-        )
-
+        )
+        if self.swap_optimizer:
+            SwapOptimizerOperate(optimizer).opt_states_initialization()
+            optimizer.step = types.MethodType(swap_adamw_step, optimizer)
+        return optimizer
 
 class MuonConfig(OptimConfig):
     weight_decay: Annotated[float, Parameter(help="Weight decay coefficient for L2 regularization")] = 0.1
```
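
The build() change above patches the optimizer instance rather than subclassing it: types.MethodType binds a free function as that one object's step method. A minimal standalone sketch of the pattern, with a hypothetical logging_step standing in for swap_adamw_step:

```python
# Sketch of the instance-level method patching used in AdamWConfig.build() above.
# `logging_step` is a hypothetical stand-in for swap_adamw_step.
import types

import torch


def logging_step(self, closure=None):
    # do extra work, then defer to the stock AdamW update
    print(f"stepping {sum(len(g['params']) for g in self.param_groups)} params")
    return torch.optim.AdamW.step(self, closure)


param = torch.nn.Parameter(torch.randn(4, 4))
opt = torch.optim.AdamW([param], lr=1e-5)
opt.step = types.MethodType(logging_step, opt)  # rebinds this instance only; the class is untouched

param.grad = torch.zeros_like(param)
opt.step()  # prints, then performs a normal AdamW update
```

The commit uses the same trick so checkpointing, LR schedulers, and the rest of the training loop keep seeing an ordinary torch.optim.AdamW. The large new block below provides the swap machinery itself: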
```diff
@@ -134,3 +137,156 @@ class LRConfig(BaseModel):
     )
     warmup_ratio: Annotated[float, Parameter(help="Ratio of warmup steps to total training steps")] = 0.03
     lr_min: Annotated[float, Parameter(help="Minimum learning rate for optimization")] = 1e-6
+
+
+class SwapOptimizerOperate():
+
+    swap_to_device_stream = None
+    swap_to_host_stream = None
+
+    swap_to_device_events_map = {}
+    swap_to_host_events_map = {}
+
+    param_to_cpu_states_map = {}
+    param_to_device_states_map = {}
+
+    state_keys = ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']
+
+    def __init__(self, optimizer, swap_optimizer_times=16):
+        self.optimizer = optimizer
+        self.swap_optimizer_times = swap_optimizer_times
+        if SwapOptimizerOperate.swap_to_device_stream is None:
+            SwapOptimizerOperate.swap_to_device_stream = torch.npu.Stream()
+            SwapOptimizerOperate.swap_to_host_stream = torch.npu.Stream()
+
+        # map each trainable parameter to its param group for step()
+        self.optimizer.param_to_group_map = {}
+
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                self.optimizer.param_to_group_map[p] = group
+
+        # report the number and size of the swapped optimizer states
+        swap_num = sum([main_param.to_local().numel() for main_param in self.optimizer.param_to_group_map])
+        swap_numel = swap_num // self.swap_optimizer_times
+        self.optimizer.swap_numel = swap_numel
+
+        swap_memory = swap_num * 8 / 1024 / 1024
+        print('[Rank {}] swap optimizer param num: {}, param size: {}MB\n'.format(torch.npu.current_device(), swap_num, swap_memory), end='')
+
+    def opt_states_initialization(self):
+        for group in self.optimizer.param_groups:
+            for param in group["params"]:
+                device_state_dtensor = self.optimizer.state[param]
+                device_state_tensor = {}
+                cpu_state = {}
+
+                amsgrad = self.optimizer.param_to_group_map[param]['amsgrad']
+
+                for key in self.state_keys:
+                    if key == 'max_exp_avg_sq' and not amsgrad:
+                        device_state_dtensor[key] = None
+                        device_state_tensor[key] = None
+                        cpu_state[key] = None
+                    else:
+                        device_state_dtensor[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                        # convert the DTensor state to a local tensor
+                        device_state_tensor[key] = device_state_dtensor[key].to_local()
+
+                        cpu_state[key] = torch.empty_like(device_state_tensor[key], pin_memory=True, device='cpu')
+                        cpu_state[key].copy_(device_state_tensor[key], non_blocking=True)
+
+                        device_state_tensor[key].storage().resize_(0)
+
+                self.param_to_device_states_map[param] = device_state_tensor
+                self.param_to_cpu_states_map[param] = cpu_state
+        torch.npu.synchronize()
+
+    @classmethod
+    def swap_all_to_host(cls):
+        for param in cls.param_to_cpu_states_map.keys():
+            cls.swap_tensors_to_host(param)
+        for param in cls.param_to_cpu_states_map.keys():
+            event = cls.swap_to_host_events_map.get(param, None)
+            if event is not None:
+                torch.npu.current_stream().wait_event(event)
+                cls.swap_to_host_events_map[param] = None
+
+    @classmethod
+    def swap_all_to_device(cls):
+        for param in cls.param_to_cpu_states_map.keys():
+            cls.swap_tensors_to_device(param)
+        for param in cls.param_to_cpu_states_map.keys():
+            event = cls.swap_to_device_events_map.get(param, None)
+            if event is not None:
+                torch.npu.current_stream().wait_event(event)
+                cls.swap_to_device_events_map[param] = None
+
+    @classmethod
+    def swap_tensors_to_device(cls, param):
+        cpu_state = cls.param_to_cpu_states_map[param]
+
+        if param in cls.param_to_device_states_map:
+            device_state = cls.param_to_device_states_map[param]
+            for key in cls.state_keys:
+                if device_state[key] is not None and device_state[key].storage().size() == 0:
+                    device_state[key].storage().resize_(cpu_state[key].storage().size())
+                    device_state[key].copy_(cpu_state[key], non_blocking=True)
+
+            cls.swap_to_device_events_map[param] = torch.npu.current_stream().record_event()
+
+    @classmethod
+    def wait_swap_to_device_event(cls, param):
+        event = cls.swap_to_device_events_map.get(param, None)
+        if event is not None:
+            torch.npu.current_stream().wait_event(event)
+            cls.swap_to_device_events_map[param] = None
+
+    @classmethod
+    def swap_tensors_to_host(cls, param):
+        cpu_state = cls.param_to_cpu_states_map[param]
+
+        if param in cls.param_to_device_states_map:
+            device_state = cls.param_to_device_states_map[param]
+            for key in cls.state_keys:
+                if key in device_state and device_state[key] is not None and device_state[key].storage().size() != 0:
+                    cpu_state[key].copy_(device_state[key], non_blocking=True)
+                    device_state[key].storage().resize_(0)
+
+            cls.swap_to_host_events_map[param] = torch.npu.current_stream().record_event()
+
+def swap_adamw_step(self, closure=None):
+    loss = None
+    if closure is not None:
+        with torch.enable_grad():
+            loss = closure()
+
+    for group in self.param_groups:
+        if 'step' in group:
+            group['step'] += 1
+            if group['step'].is_cpu:
+                group['step'] = group['step'].npu()
+        else:
+            group['step'] = torch.tensor(1, dtype=torch.int64, device=torch.npu.current_device())
+
+    params_list = list(self.param_to_group_map.keys())
+
+    SwapOptimizerOperate.swap_all_to_device()
+
+    for i, param in enumerate(params_list):
+        if param.grad is None:
+            continue
+        if param.grad.is_sparse:
+            raise RuntimeError('AdamW does not support sparse gradients')
+
+        group = self.param_to_group_map[param]
+        amsgrad = group['amsgrad']
+        beta1, beta2 = group['betas']
+        state = self.state[param]
+
+        torch._fused_adamw_([param.to_local()], [param.grad.to_local()], [state['exp_avg'].to_local()], [state['exp_avg_sq'].to_local()], [state['max_exp_avg_sq']] if amsgrad else [],
+                            [group['step']], amsgrad=amsgrad, lr=group['lr'], beta1=beta1, beta2=beta2, weight_decay=group['weight_decay'],
+                            eps=group['eps'], maximize=group['maximize'])
+
+    # this full-device synchronize may be removable
+    torch.npu.synchronize()
+    return loss
```
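
Two tricks in SwapOptimizerOperate do the actual memory work: every device state tensor gets a pinned CPU twin so copies in both directions can be non_blocking, and once a state is safely on the host its device storage is resized to zero bytes, releasing NPU memory while keeping the tensor object (and its shape/stride metadata) alive. A minimal standalone sketch of just that mechanism, written against torch.cuda in place of torch.npu (assumed to behave equivalently; the tensor names are illustrative, not xtuner API):

```python
# Standalone sketch of the pinned-buffer + storage-resize trick used above,
# with torch.cuda standing in for torch.npu (assumed equivalent).
import torch

state = torch.zeros(1024, 1024, device="cuda")  # stands in for exp_avg etc.

# one-time setup, as in opt_states_initialization(): a pinned host twin
cpu_state = torch.empty_like(state, pin_memory=True, device="cpu")

# swap to host: async D2H copy into pinned memory, then drop the device storage
cpu_state.copy_(state, non_blocking=True)
torch.cuda.synchronize()                 # the commit records/waits on per-param events instead
state.untyped_storage().resize_(0)       # device memory freed; tensor metadata survives

# swap back to device: re-grow the storage, then async H2D copy
state.untyped_storage().resize_(cpu_state.untyped_storage().nbytes())
state.copy_(cpu_state, non_blocking=True)
```

Between the two resizes the tensor occupies no device memory, which is why swap_tensors_to_device must re-grow the storage before copying and why swap_tensors_to_host only copies out states whose storage is still non-empty.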

xtuner/v1/engine/train_engine.py (1 addition & 1 deletion)
```diff
@@ -511,7 +511,7 @@ def put_model_to_device(self, device: torch.device | str):
 
     def put_optimizer_to_device(self, device: torch.device | str):
         """Put the optimizer to the given device."""
-        if self.fsdp_cfg.cpu_offload:
+        if self.fsdp_cfg.cpu_offload or self.optim_cfg.swap_optimizer:
             return
         if not self.optimizer.state:
             return
```
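
With swap_optimizer set, the engine must leave the optimizer alone: the states already live in pinned host memory and are paged in only inside swap_adamw_step. End to end, the feature is driven from the config; a hypothetical usage sketch (the model object is assumed, not part of this commit):

```python
# Hypothetical usage: enabling the new flag. `model` is an assumed module whose
# parameters are DTensors (build() and swap_adamw_step call .to_local() on them).
from xtuner.v1.config.optim import AdamWConfig

optim_cfg = AdamWConfig(lr=1e-5, swap_optimizer=True)
optimizer = optim_cfg.build(model)
# build() allocated the AdamW states, mirrored them to pinned CPU buffers, freed
# the device copies, and re-bound optimizer.step to swap_adamw_step.
```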

xtuner/v1/train/trainer.py (34 additions & 0 deletions)
```diff
@@ -737,6 +737,13 @@ def fit(self):
             grad_norm = self._engine.clip_grad_norm(do_clip=self._do_clip, dtype=self._grad_norm_dtype)
             self._engine.step_optimizer(grad_norm)
 
+            if self._optim_config.swap_optimizer:
+                is_save_dcp = (self._is_save_dcp(is_snapshot=False) or self._is_save_dcp(is_snapshot=True))
+                if not is_save_dcp:
+                    from xtuner.v1.config.optim import SwapOptimizerOperate
+                    SwapOptimizerOperate.swap_all_to_host()
+                    torch.npu.synchronize()
+
             time_after_train_step = time.time()
             ProberList.after_step()
 
@@ -1092,6 +1099,22 @@ def _maybe_check_health(self):
             raise RuntimeError("Health check failed, exit training")
         logger.info(f"Health check passed at step {self.cur_step}")
 
+
+    def _is_save_dcp(self, is_snapshot: bool = False) -> bool:
+        ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval
+        cur_step = self._cur_step + 1
+        if ckp_interval is None:
+            return False
+
+        if ckp_interval == -1:  # only save at the end of training
+            if cur_step != self.total_step:
+                return False
+        else:
+            if cur_step % ckp_interval != 0 and (is_snapshot or cur_step != self.total_step):
+                # if is_snapshot, only save at the interval;
+                # otherwise save at the interval or at the end of training
+                return False
+        return True
+
     def _maybe_save(self, is_snapshot: bool = False) -> bool:
         ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval
         if ckp_interval is None:
@@ -1123,6 +1146,12 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
             optimizer_dir=optimizer_path,
         )
 
+        if self._optim_config.swap_optimizer:
+            torch.npu.synchronize()
+            from xtuner.v1.config.optim import SwapOptimizerOperate
+            SwapOptimizerOperate.swap_all_to_host()
+            torch.npu.synchronize()
+
         # Save dataloader
         self._save_dataloader(dataloader_path)
 
@@ -1759,6 +1788,11 @@ def _load_checkpoint(self):
             else None
         )
 
+        if self._optim_config.swap_optimizer:
+            from xtuner.v1.config.optim import SwapOptimizerOperate
+            SwapOptimizerOperate.swap_all_to_device()
+
+
         self._engine.load_dcp(
             model_dir=model_path,
            optimizer_dir=optimizer_path,
```