
Commit fc23567

[Refactor] Move float8_handler initialization from TrainEngine to BaseModel
- Remove float8_handler as a direct attribute of TrainEngine
- Add float8_handler as a lazy-initialized property in BaseModel
- Move Float8Handler.build() logic to Float8Config.build()
- Update _maybe_precompute_float8_dynamic_scale_for_fsdp to iterate through model modules

ghstack-source-id: fdad8ec
Pull-Request: InternLM#1517
1 parent 3067930

5 files changed: 30 additions & 21 deletions
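
In outline, the refactor moves float8 handler construction out of TrainEngine.build_model() and into the model itself; the engine now reaches each handler by walking its module tree. A condensed before/after, paraphrased from the diffs below:

    # Before: TrainEngine built and owned a single handler.
    self.float8_handler = Float8Handler(
        scaling_granularity_gemm=self.model_cfg.float8_cfg.scaling_granularity_gemm,
        scaling_granularity_grouped_gemm=self.model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
    )

    # After: each BaseModel lazily builds its own handler, and the engine
    # simply iterates over its modules.
    for m in self.model.modules():
        if isinstance(m, BaseModel) and m.float8_handler is not None:
            m.float8_handler.precompute_float8_dynamic_scale_for_fsdp(m)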


xtuner/v1/engine/train_engine.py

Lines changed: 7 additions & 19 deletions
@@ -3,7 +3,7 @@
 import threading
 from concurrent.futures import wait
 from pathlib import Path
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, cast
 
 import torch
 import torch.distributed as dist
@@ -17,7 +17,6 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
-from torch.distributed.device_mesh import DeviceMesh
 from torch.nn.utils.clip_grad import _no_grad
 from torch.utils._foreach_utils import (
     _device_has_foreach_support,
@@ -26,7 +25,6 @@
 
 from xtuner.v1.config import FSDPConfig, OptimConfig
 from xtuner.v1.data_proto.sequence_context import SequenceContext
-from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.model.base import BaseModel, ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.utils import ModelForwardExtraLogInfo
 from xtuner.v1.module.router import NoAuxRouterConfig
@@ -138,7 +136,6 @@ class TrainEngine:
     model: BaseModel
     optimizer: torch.optim.Optimizer
     scheduler: torch.optim.lr_scheduler.LRScheduler
-    float8_handler: Optional[Float8Handler]
 
     def __init__(
         self,
@@ -168,19 +165,10 @@ def build_model(self) -> BaseModel:
         with torch.device("meta"):
             model = self.model_cfg.build()
 
-        self.float8_handler = None
-        if self.model_cfg.float8_cfg is not None and self.model_cfg.float8_cfg.enable_float8:
-            self.float8_handler = Float8Handler(
-                scaling_granularity_gemm=self.model_cfg.float8_cfg.scaling_granularity_gemm,
-                scaling_granularity_grouped_gemm=self.model_cfg.float8_cfg.scaling_granularity_grouped_gemm,
-            )
         model = model.fully_shard(self.fsdp_cfg)
 
         if dist.get_rank() == 0:
             logger.info(model)
-
-        if self.float8_handler:
-            self.float8_handler.build_reduce_mesh(model, cast(DeviceMesh, model.fsdp_mesh))
         return model
 
     def build_optimizer(self, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
@@ -218,18 +206,13 @@ def grad_accumulation_steps(self, data_batches_len: int):
         intra_layer_micro_batch = self.intra_layer_micro_batch
         return data_batches_len // intra_layer_micro_batch
 
-    # this method can be called outside, e.g., at the beginning of compute_actor_logprobs or compute_ref_logprobs during rl training
-    def maybe_precompute_float8_dynamic_scale_for_fsdp(self):
-        if self.float8_handler is not None:
-            self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
-
     def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         """Perform a training step with the given data batches and mesh.
 
         Args:
             data_batches (List[Dict]): The input data batches for the training step.
         """
-        self.maybe_precompute_float8_dynamic_scale_for_fsdp()
+        self._maybe_precompute_float8_dynamic_scale_for_fsdp()
 
         loss_log: LossLog = {}  # type: ignore[typeddict-item]
         other_log: OtherLog = {}  # type: ignore[typeddict-item]
@@ -541,3 +524,8 @@ def put_optimizer_to_device(self, device: torch.device | str):
                 state[key] = val.to(device, non_blocking=True)
         DEVICE_MODULE.synchronize()
         return
+
+    def _maybe_precompute_float8_dynamic_scale_for_fsdp(self):
+        for model in self.model.modules():
+            if isinstance(model, BaseModel) and model.float8_handler is not None:
+                model.float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
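
Because nn.Module.modules() yields the module itself plus every nested submodule, the new private helper also covers composite models that hold several BaseModel instances. A self-contained sketch of the traversal pattern; FakeHandler, SubModel, and Wrapper are hypothetical stand-ins, not xtuner classes:

    import torch.nn as nn

    class FakeHandler:
        def precompute_float8_dynamic_scale_for_fsdp(self, model):
            print(f"precompute on {type(model).__name__}")

    class SubModel(nn.Module):  # stands in for BaseModel
        def __init__(self):
            super().__init__()
            self.float8_handler = FakeHandler()

    class Wrapper(nn.Module):  # a composite holding two sub-models
        def __init__(self):
            super().__init__()
            self.actor = SubModel()
            self.critic = SubModel()

    root = Wrapper()
    for module in root.modules():  # yields root, actor, critic, ...
        if isinstance(module, SubModel) and module.float8_handler is not None:
            module.float8_handler.precompute_float8_dynamic_scale_for_fsdp(module)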

xtuner/v1/float8/config.py

Lines changed: 8 additions & 0 deletions
@@ -49,3 +49,11 @@ def is_tilewise(self) -> bool:
     def is_tensorwise(self) -> bool:
         """Whether the scaling granularity is TENSORWISE."""
         return self.scaling_granularity_gemm == ScalingGranularity.TENSORWISE
+
+    def build(self):
+        from .float8_handler import Float8Handler
+
+        return Float8Handler(
+            scaling_granularity_gemm=self.scaling_granularity_gemm,
+            scaling_granularity_grouped_gemm=self.scaling_granularity_grouped_gemm,
+        )
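
Importing Float8Handler inside build() rather than at module top level presumably avoids a circular import (the handler module likely imports from this config module; the diff itself does not state the reason). A self-contained analogue of the config-owns-build pattern with a deferred import, using only the standard library:

    class LoggerConfig:
        # Hypothetical example, not an xtuner class: the config object knows
        # how to construct what it describes, deferring the import until
        # build() runs, as Float8Config.build() does above.
        def __init__(self, name: str, level: str = "INFO"):
            self.name = name
            self.level = level

        def build(self):
            import logging  # deferred import, mirroring Float8Config.build()

            logger = logging.getLogger(self.name)
            logger.setLevel(self.level)
            return logger

    log = LoggerConfig("train", level="DEBUG").build()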

xtuner/v1/model/base.py

Lines changed: 13 additions & 0 deletions
@@ -394,6 +394,19 @@ def compile_cfg(self) -> dict[str, TorchCompileOption]:
 
         return _compile_cfg
 
+    @property
+    def float8_handler(self):
+        if (
+            self.config.float8_cfg is not None
+            and self.config.float8_cfg.enable_float8
+            and self._float8_handler is None
+        ):
+            self._float8_handler = self.config.float8_cfg.build()
+
+            if self.fsdp_mesh is not None:
+                self._float8_handler.build_reduce_mesh(self, self.fsdp_mesh)
+        return self._float8_handler
+
     @torch.no_grad()
     def init_weights(self):
         # TODO: HardCode here. The initialization method should be module specific. All module in model
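
The property is lazy and memoized: the first access with float8 enabled builds the handler (and wires the FSDP reduce mesh when one exists), later accesses return the cached instance, and it stays None when float8 is disabled. This assumes BaseModel.__init__ initializes self._float8_handler to None, which this diff does not show. A minimal self-contained sketch of the access pattern:

    class Model:
        # Simplified stand-in for BaseModel; object() plays the role of
        # Float8Config.build().
        def __init__(self, float8_enabled: bool):
            self.float8_enabled = float8_enabled
            self._float8_handler = None

        @property
        def float8_handler(self):
            if self.float8_enabled and self._float8_handler is None:
                self._float8_handler = object()  # built exactly once
            return self._float8_handler

    m = Model(float8_enabled=True)
    assert m.float8_handler is m.float8_handler  # cached after first access
    assert Model(float8_enabled=False).float8_handler is None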

xtuner/v1/model/moe/moe.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 )
 from torch.distributed.tensor import DTensor, Replicate, distribute_tensor
 from tqdm import tqdm
-from typing_extensions import NotRequired, overload, override
+from typing_extensions import overload, override
 
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.data_proto import SequenceContext

xtuner/v1/rl/base/worker.py

Lines changed: 1 addition & 1 deletion
@@ -388,7 +388,7 @@ def compute_actor_logprobs(
         shifted_labels_list: list[torch.Tensor],
     ) -> list[torch.Tensor]:
         # precompute float8 dynamic scale only once
-        self._engine.maybe_precompute_float8_dynamic_scale_for_fsdp()
+        self._engine._maybe_precompute_float8_dynamic_scale_for_fsdp()
         old_logprobs_list: list[torch.Tensor] = []
         for seq_ctx, shifted_labels in zip(seq_ctx_list, shifted_labels_list):
            output = self._engine.forward_only(seq_ctx=seq_ctx)
