@@ -3,7 +3,7 @@
 import threading
 from concurrent.futures import wait
 from pathlib import Path
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 
 import torch
 import torch.distributed as dist
@@ -17,7 +17,6 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
-from torch.distributed.device_mesh import DeviceMesh
 from torch.nn.utils.clip_grad import _no_grad
 from torch.utils._foreach_utils import (
     _device_has_foreach_support,
@@ -26,7 +25,6 @@
 
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
-from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.model.base import BaseModel, ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.module.router import NoAuxRouterConfig
@@ -138,7 +136,6 @@ class TrainEngine:
     model: BaseModel
     optimizer: torch.optim.Optimizer
     scheduler: torch.optim.lr_scheduler.LRScheduler
-    float8_handler: Optional[Float8Handler]
 
     def __init__(
         self,
@@ -168,19 +165,10 @@ def build_model(self) -> BaseModel:
         with torch.device("meta"):
             model = self.model_cfg.build()
 
-        self.float8_handler = None
-        if self.model_cfg.float8_cfg is not None and self.model_cfg.float8_cfg.enable_float8:
-            self.float8_handler = Float8Handler(
-                scaling_granularity_gemm=self.model_cfg.float8_cfg.scaling_granularity_gemm,
-                scaling_granularity_grouped_gemm=self.model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
-            )
         model = model.fully_shard(self.fsdp_cfg)
 
         if dist.get_rank() == 0:
             logger.info(model)
-
-        if self.float8_handler:
-            self.float8_handler.build_reduce_mesh(model, cast(DeviceMesh, model.fsdp_mesh))
         return model
 
     def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
@@ -218,18 +206,13 @@ def grad_accumulation_steps(self, data_batches_len: int):
         intra_layer_micro_batch = self.intra_layer_micro_batch
         return data_batches_len // intra_layer_micro_batch
 
-    # this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training
-    def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
-        if self.float8_handler is not None:
-            self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
-
     def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         """Perform a training step with the given data batches and mesh.
 
         Args:
             data_batches (List[Dict]): The input data batches for the training step.
         """
-        self.maybe_precompute_float8_dynamic_scale_for_fsdp()
+        self._maybe_precompute_float8_dynamic_scale_for_fsdp()
 
         loss_log: LossLog = {}  # type: ignore[typeddict-item]
         other_log: OtherLog = {}  # type: ignore[typeddict-item]
@@ -541,3 +524,8 @@ def put_optimizer_to_device(self, device: torch.device | str):
                 state[key] = val.to(device, non_blocking=True)
         DEVICE_MODULE.synchronize()
         return
+
+    def _maybe_precompute_float8_dynamic_scale_for_fsdp(self):
+        for model in self.model.modules():
+            if isinstance(model, BaseModel) and model.float8_handler is not None:
+                model.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
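
For context, a minimal sketch of the discovery pattern the new private helper relies on: instead of the engine owning a single Float8Handler, each model now carries its own float8_handler, and the engine finds them by walking nn.Module.modules(). All class and function names below are toy placeholders for illustration, not xtuner or PyTorch float8 APIs.

import torch.nn as nn


class ToyFloat8Handler:
    def precompute_float8_dynamic_scale_for_fsdp(self, module: nn.Module) -> None:
        # Stand-in for the real dynamic-scale recomputation.
        print(f"refreshing float8 scales for {type(module).__name__}")


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.float8_handler = ToyFloat8Handler()  # per-model handler, mirroring this diff


def refresh_all_float8_scales(root: nn.Module) -> None:
    # Delegate to every submodule that owns a handler, as the renamed
    # _maybe_precompute_float8_dynamic_scale_for_fsdp does for BaseModel instances.
    for module in root.modules():
        handler = getattr(module, "float8_handler", None)
        if handler is not None:
            handler.precompute_float8_dynamic_scale_for_fsdp(module)


refresh_all_float8_scales(nn.Sequential(ToyModel(), ToyModel()))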