 import threading
 from concurrent.futures import wait
 from pathlib import Path
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast

 import torch
 import torch.distributed as dist
 ...
     set_model_state_dict,
     set_optimizer_state_dict,
 )
-from torch.distributed.device_mesh import DeviceMesh
 from torch.nn.utils.clip_grad import _no_grad
 from torch.utils._foreach_utils import (
     _device_has_foreach_support,
     ...
 )

 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
-from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.model.base import BaseModel, ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.module.router import NoAuxRouterConfig
@@ -138,7 +136,6 @@ class TrainEngine:
     model: BaseModel
     optimizer: torch.optim.Optimizer
     scheduler: torch.optim.lr_scheduler.LRScheduler
-    float8_handler: Optional[Float8Handler]

     def __init__(
         self,
@@ -168,19 +165,10 @@ def build_model(self) -> BaseModel:
         with torch.device("meta"):
             model = self.model_cfg.build()

-        self.float8_handler = None
-        if self.model_cfg.float8_cfg is not None and self.model_cfg.float8_cfg.enable_float8:
-            self.float8_handler = Float8Handler(
-                scaling_granularity_gemm=self.model_cfg.float8_cfg.scaling_granularity_gemm,
-                scaling_granularity_grouped_gemm=self.model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
-            )
         model = model.fully_shard(self.fsdp_cfg)

         if dist.get_rank() == 0:
             logger.info(model)
-
-        if self.float8_handler:
-            self.float8_handler.build_reduce_mesh(model, cast(DeviceMesh, model.fsdp_mesh))
         return model

     def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
@@ -200,18 +188,13 @@ def grad_accumulation_steps(self, data_batches_len: int):
         intra_layer_micro_batch = self.intra_layer_micro_batch
         return data_batches_len // intra_layer_micro_batch

-    # this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training
-    def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
-        if self.float8_handler is not None:
-            self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
-
     def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         """Perform a training step with the given data batches and mesh.

         Args:
             data_batches (List[Dict]): The input data batches for the training step.
         """
-        self.maybe_precompute_float8_dynamic_scale_for_fsdp()
+        self._maybe_precompute_float8_dynamic_scale_for_fsdp()

         loss_log: LossLog = {}  # type: ignore[typeddict-item]
         other_log: OtherLog = {}  # type: ignore[typeddict-item]
@@ -523,3 +506,8 @@ def put_optimizer_to_device(self, device: torch.device | str):
                     state[key] = val.to(device, non_blocking=True)
         DEVICE_MODULE.synchronize()
         return
+
+    def _maybe_precompute_float8_dynamic_scale_for_fsdp(self):
+        for model in self.model.modules():
+            if isinstance(model, BaseModel) and model.float8_handler is not None:
+                model.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
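
The relocated helper in the last hunk is the heart of this change: instead of the engine constructing and owning a single Float8Handler, each BaseModel now carries its own (possibly None) float8_handler, and the engine discovers the handlers at train_step time by walking self.model.modules(). That traversal also covers composed models in which several BaseModel instances are nested under one root module. The sketch below illustrates only that discovery pattern; TinyModel, DummyFloat8Handler, and Engine are illustrative stand-ins, not xtuner APIs.

from torch import nn


class DummyFloat8Handler:
    """Stand-in for the real Float8Handler; it only records the call."""

    def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module) -> None:
        print(f"precomputing fp8 dynamic scales for {type(model).__name__}")


class TinyModel(nn.Module):
    """Stand-in for BaseModel: a module that may own a float8 handler."""

    def __init__(self, enable_float8: bool = False):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.float8_handler = DummyFloat8Handler() if enable_float8 else None


class Engine:
    def __init__(self, model: nn.Module):
        self.model = model

    def _maybe_precompute_float8_dynamic_scale_for_fsdp(self) -> None:
        # nn.Module.modules() yields the root module first and then every
        # submodule, so both a bare TinyModel and one nested inside a larger
        # composite are found by the same traversal.
        for module in self.model.modules():
            if isinstance(module, TinyModel) and module.float8_handler is not None:
                module.float8_handler.precompute_float8_dynamic_scale_for_fsdp(module)


# A composite with one fp8-enabled branch and one plain branch: only the
# fp8-enabled submodule triggers a precompute call.
composed = nn.ModuleDict({"actor": TinyModel(enable_float8=True), "ref": TinyModel()})
Engine(composed)._maybe_precompute_float8_dynamic_scale_for_fsdp()

One behavioural note: the removed comment shows the old public method was also called from outside the engine (e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during RL training); the underscore-prefixed replacement signals that the precompute is now an internal detail of train_step, so any remaining external callers would have to reach the per-model handlers directly.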