99
1010from xtuner .v1 .optim import Muon
1111from xtuner .v1 .utils import get_logger
12-
12+ import types
1313
1414logger = get_logger ()
1515
16-
1716class OptimConfig (BaseModel ):
1817 model_config = ConfigDict (extra = "forbid" )
1918 lr : Annotated [float , Parameter (help = "Learning rate for optimization" )] = 1e-5
@@ -32,6 +31,7 @@ class AdamWConfig(OptimConfig):
3231 betas : Annotated [Tuple [float , float ], Parameter (help = "Beta coefficients for Adam optimizer" )] = (0.9 , 0.95 )
3332 eps : Annotated [float , Parameter (help = "Epsilon value for numerical stability in Adam optimizer" )] = 1e-8
3433 foreach : Annotated [Optional [bool ], Parameter (help = "Use foreach implementation for AdamW" )] = None
34+ swap_optimizer : Annotated [Optional [bool ], Parameter (help = "Swap optimizer states to host memory." )] = False
3535
3636 def build (self , model ):
3737 params = [p for p in model .parameters () if p .requires_grad ]
@@ -52,10 +52,13 @@ def build(self, model):
5252 f"Total trainable parameters: { num_total_requires_grad // 1e6 } M, total parameters: { num_total // 1e6 } M"
5353 )
5454 logger .info (f"Untrainable parameters names: { untrainable_names } " )
55- return torch .optim .AdamW (
55+ optimizer = torch .optim .AdamW (
5656 params , lr = self .lr , betas = self .betas , eps = self .eps , weight_decay = self .weight_decay , foreach = self .foreach
57- )
58-
57+ )
58+ if self .swap_optimizer :
59+ SwapOptimizerOperate (optimizer ).opt_states_initialization ()
60+ optimizer .step = types .MethodType (swap_adamw_step , optimizer )
61+ return optimizer
5962
6063class MuonConfig (OptimConfig ):
6164 weight_decay : Annotated [float , Parameter (help = "Weight decay coefficient for L2 regularization" )] = 0.1
@@ -134,3 +137,156 @@ class LRConfig(BaseModel):
134137 )
135138 warmup_ratio : Annotated [float , Parameter (help = "Ratio of warmup steps to total training steps" )] = 0.03
136139 lr_min : Annotated [float , Parameter (help = "Minimum learning rate for optimization" )] = 1e-6
140+
class SwapOptimizerOperate():
    """Swap AdamW optimizer states between device (NPU) and pinned host memory.

    A pinned CPU mirror of every parameter's optimizer states is kept in
    ``param_to_cpu_states_map``; the matching device storages are released
    (``storage().resize_(0)``) while the states live on the host, and refilled
    before the optimizer step.  All bookkeeping maps are class attributes, so
    the swap state is shared process-wide across instances.
    """

    # Side streams for async copies; lazily created on first construction.
    # NOTE(review): these streams are never used by the copy/event code
    # below, which runs on the current stream — confirm whether they are
    # consumed elsewhere or are dead code.
    swap_to_device_stream = None
    swap_to_host_stream = None

    # param -> event recording the most recent copy in that direction
    # (reset to None once waited on).
    swap_to_device_events_map = {}
    swap_to_host_events_map = {}

    # param -> {state key: pinned CPU tensor} and
    # param -> {state key: local (sharded) device tensor}.
    param_to_cpu_states_map = {}
    param_to_device_states_map = {}

    # AdamW state entries handled; 'max_exp_avg_sq' only exists with amsgrad.
    state_keys = ['exp_avg', 'exp_avg_sq', 'max_exp_avg_sq']

    def __init__(self, optimizer, swap_optimizer_times=16):
        """Attach swap bookkeeping to *optimizer*.

        Args:
            optimizer: optimizer whose parameters are DTensors (the code
                calls ``.to_local()`` on them).
            swap_optimizer_times: chunk count the total parameter numel is
                divided by; the result is stored as ``optimizer.swap_numel``.
        """
        self.optimizer = optimizer
        self.swap_optimizer_times = swap_optimizer_times
        if SwapOptimizerOperate.swap_to_device_stream is None:
            SwapOptimizerOperate.swap_to_device_stream = torch.npu.Stream()
            SwapOptimizerOperate.swap_to_host_stream = torch.npu.Stream()

        # create all parameters list for step
        self.optimizer.param_to_group_map = {}

        for group in self.optimizer.param_groups:
            for p in group['params']:
                self.optimizer.param_to_group_map[p] = group

        # print swap param num and size
        swap_num = sum([main_param.to_local().numel() for main_param in self.optimizer.param_to_group_map])
        swap_numel = swap_num // self.swap_optimizer_times
        self.optimizer.swap_numel = swap_numel

        # 8 bytes per element — presumably exp_avg + exp_avg_sq at fp32
        # (4 B each); TODO confirm against actual state dtypes.
        swap_memory = swap_num * 8 / 1024 / 1024
        print('[Rank {}] swap optimizer param num: {}, param size: {}MB\n'.format(torch.npu.current_device(), swap_num, swap_memory), end='')

    def opt_states_initialization(self):
        """Allocate zeroed device states, mirror them into pinned host memory,
        then free the device storage so the states initially live on the host.

        The local tensor returned by ``to_local()`` shares storage with the
        DTensor kept in ``optimizer.state``, so resizing its storage to 0
        releases the device memory behind both views.
        """
        for group in self.optimizer.param_groups:
            for param in group["params"]:
                device_state_dtensor = self.optimizer.state[param]
                device_state_tensor = {}
                cpu_state = {}

                amsgrad = self.optimizer.param_to_group_map[param]['amsgrad']

                for key in self.state_keys:
                    if key == 'max_exp_avg_sq' and not amsgrad:
                        # No amsgrad: keep explicit None placeholders so later
                        # loops can test `is not None` uniformly.
                        device_state_dtensor[key] = None
                        device_state_tensor[key] = None
                        cpu_state[key] = None
                    else:
                        device_state_dtensor[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
                        # convert dtensor to tensor
                        device_state_tensor[key] = device_state_dtensor[key].to_local()

                        # Pinned host mirror enables async (non_blocking) copies.
                        cpu_state[key] = torch.empty_like(device_state_tensor[key], pin_memory=True, device='cpu')
                        cpu_state[key].copy_(device_state_tensor[key], non_blocking=True)

                        # Release the device-side storage until the next step.
                        device_state_tensor[key].storage().resize_(0)

                self.param_to_device_states_map[param] = device_state_tensor
                self.param_to_cpu_states_map[param] = cpu_state
        # Make sure all async host copies finished before the mirrors are used.
        torch.npu.synchronize()

    @classmethod
    def swap_all_to_host(cls):
        """Copy every parameter's states to host, then wait for all copies."""
        for param in cls.param_to_cpu_states_map.keys():
            cls.swap_tensors_to_host(param)
        for param in cls.param_to_cpu_states_map.keys():
            event = cls.swap_to_host_events_map.get(param, None)
            if event is not None:
                torch.npu.current_stream().wait_event(event)
                cls.swap_to_host_events_map[param] = None

    @classmethod
    def swap_all_to_device(cls):
        """Copy every parameter's states back to device, then wait for all copies."""
        for param in cls.param_to_cpu_states_map.keys():
            cls.swap_tensors_to_device(param)
        for param in cls.param_to_cpu_states_map.keys():
            event = cls.swap_to_device_events_map.get(param, None)
            if event is not None:
                torch.npu.current_stream().wait_event(event)
                cls.swap_to_device_events_map[param] = None

    @classmethod
    def swap_tensors_to_device(cls, param):
        """Re-grow the device storages for *param* and refill them from the
        pinned host mirrors; records a completion event on the current stream."""
        cpu_state = cls.param_to_cpu_states_map[param]

        if param in cls.param_to_device_states_map:
            device_state = cls.param_to_device_states_map[param]
            for key in cls.state_keys:
                # storage().size() == 0 means the state is currently swapped out.
                if device_state[key] is not None and device_state[key].storage().size() == 0:
                    device_state[key].storage().resize_(cpu_state[key].storage().size())
                    device_state[key].copy_(cpu_state[key], non_blocking=True)

        cls.swap_to_device_events_map[param] = torch.npu.current_stream().record_event()

    @classmethod
    def wait_swap_to_device_event(cls, param):
        """Block the current stream until *param*'s device copy has finished."""
        event = cls.swap_to_device_events_map.get(param, None)
        if event is not None:
            torch.npu.current_stream().wait_event(event)
            cls.swap_to_device_events_map[param] = None

    @classmethod
    def swap_tensors_to_host(cls, param):
        """Copy *param*'s device states to the pinned host mirrors and free the
        device storages; records a completion event on the current stream."""
        cpu_state = cls.param_to_cpu_states_map[param]

        if param in cls.param_to_device_states_map:
            device_state = cls.param_to_device_states_map[param]
            for key in cls.state_keys:
                # Only swap out states that are currently resident on device.
                if key in device_state and device_state[key] is not None and device_state[key].storage().size() != 0:
                    cpu_state[key].copy_(device_state[key], non_blocking=True)
                    device_state[key].storage().resize_(0)

        cls.swap_to_host_events_map[param] = torch.npu.current_stream().record_event()
256+
def swap_adamw_step(self, closure=None):
    """AdamW step that first swaps optimizer states back onto the device.

    Bound onto a ``torch.optim.AdamW`` instance via ``types.MethodType``
    (see ``AdamWConfig.build``), so ``self`` is the optimizer.

    Args:
        closure: optional callable that re-evaluates the model and returns
            the loss; executed under ``torch.enable_grad()``.

    Returns:
        The closure's loss, or ``None`` when no closure is given.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # One shared step counter per param group, kept on device as an int64
    # tensor for the fused kernel.
    for group in self.param_groups:
        if 'step' in group:
            group['step'] += 1
            if group['step'].is_cpu:
                group['step'] = group['step'].npu()
        else:
            group['step'] = torch.tensor(1, dtype=torch.int64, device=torch.npu.current_device())

    params_list = list(self.param_to_group_map.keys())

    # Bring every parameter's optimizer states back from pinned host memory.
    # NOTE(review): nothing here swaps the states back to host after the
    # step — presumably SwapOptimizerOperate.swap_all_to_host() is called
    # elsewhere in the training loop; confirm.
    SwapOptimizerOperate.swap_all_to_device()

    for param in params_list:
        if param.grad is None:
            continue
        if param.grad.is_sparse:
            raise RuntimeError('AdamW does not support sparse gradients')

        group = self.param_to_group_map[param]
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']
        state = self.state[param]

        # Fix: pass the local shard of max_exp_avg_sq, consistent with the
        # other arguments — optimizer.state holds DTensors (created via
        # zeros_like(param) in opt_states_initialization), and the fused
        # kernel expects plain local tensors.
        torch._fused_adamw_(
            [param.to_local()],
            [param.grad.to_local()],
            [state['exp_avg'].to_local()],
            [state['exp_avg_sq'].to_local()],
            [state['max_exp_avg_sq'].to_local()] if amsgrad else [],
            [group['step']],
            amsgrad=amsgrad,
            lr=group['lr'],
            beta1=beta1,
            beta2=beta2,
            weight_decay=group['weight_decay'],
            eps=group['eps'],
            maximize=group['maximize'],
        )

    # it maybe removed
    torch.npu.synchronize()
    return loss
0 commit comments