
Commit c208b26: Train with cpu only (#443)
Co-authored-by: Oleksiy Ostapenko <ostapy2@gmail.com>
1 parent: 17cd7ed

47 files changed
Lines changed: 385 additions & 287 deletions


fast_llm/core/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -185,7 +185,11 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta
 @contextlib.contextmanager
 def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]:
     """Use the generator as default, for ops that don't support a generator argument."""
-    default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()]
+    default_generator: torch.Generator = (
+        torch.cuda.default_generators[generator.device.index]
+        if generator.device.type == "cuda"
+        else torch.default_generator
+    )
     assert generator is not default_generator
     old_state = default_generator.get_state()
     default_generator.set_state(generator.get_state())
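
This makes set_generator usable off-GPU: for a CUDA generator the matching per-device default generator is swapped, otherwise torch.default_generator is. A minimal usage sketch on CPU (only set_generator and its import path are from this commit; the rest is illustrative):

    import torch

    from fast_llm.core.distributed import set_generator

    generator = torch.Generator(device="cpu").manual_seed(42)

    with set_generator(generator):
        # Ops without a generator= argument now draw from `generator`,
        # whose state was copied into torch.default_generator.
        sample = torch.randn(4)

    # On exit the context manager restores the saved default-generator state.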

fast_llm/core/kernels.py

Lines changed: 52 additions & 28 deletions
@@ -12,30 +12,35 @@
     from amp_C import multi_tensor_scale as _multi_tensor_scale  # noqa
     from apex.multi_tensor_apply import multi_tensor_applier as _multi_tensor_applier  # noqa
 
-    _apex_available = True
+    _apex_available = torch.cuda.is_available()
 except ImportError:
     _apex_available = False
 
 
 def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor:
-    assert _apex_available
-    norm, _ = _multi_tensor_applier(
-        _multi_tensor_l2norm,
-        noop_flag,
-        [tensors],
-        False,  # no per-parameter norm
-    )
+    if _apex_available:
+        norm, _ = _multi_tensor_applier(
+            _multi_tensor_l2norm,
+            noop_flag,
+            [tensors],
+            False,  # no per-parameter norm
+        )
+    else:
+        norm = sum(torch.norm(tensor) ** 2 for tensor in tensors) ** 0.5
     return norm
 
 
 def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None:
-    assert _apex_available
-    _multi_tensor_applier(
-        _multi_tensor_scale,
-        noop_flag,
-        [tensors, tensors],
-        scale,
-    )
+    if _apex_available:
+        _multi_tensor_applier(
+            _multi_tensor_scale,
+            noop_flag,
+            [tensors, tensors],
+            scale,
+        )
+    else:
+        for tensor in tensors:
+            tensor.mul_(scale)
 
 
 # TODO: Same as torch._fused_adam_?
@@ -52,16 +57,35 @@ def fused_adam(
     eps: float,
     step: int,
 ) -> None:
-    _multi_tensor_applier(
-        _multi_tensor_adam,
-        noop_flag,
-        [grads, params, exp_avgs, exp_avg_sqs],
-        lr,
-        beta1,
-        beta2,
-        eps,
-        step,
-        1,  # adamw
-        1,  # bias correction
-        wd,
-    )
+    if _apex_available:
+        _multi_tensor_applier(
+            _multi_tensor_adam,
+            noop_flag,
+            [grads, params, exp_avgs, exp_avg_sqs],
+            lr,
+            beta1,
+            beta2,
+            eps,
+            step,
+            1,  # adamw
+            1,  # bias correction
+            wd,
+        )
+    else:
+        import torch.optim.adamw as adamw
+
+        adamw.adamw(
+            params,
+            grads,
+            exp_avgs,
+            exp_avg_sqs,
+            None,
+            lr=lr,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            state_steps=torch.full([len(params)], step, dtype=torch.int64, device=params[0].device).unbind(),
+            weight_decay=wd,
+            amsgrad=False,
+            maximize=False,
+        )
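
Each kernel now degrades to a plain-PyTorch path when Apex (or CUDA itself) is unavailable. A sanity-check sketch of the fallbacks on CPU, assuming no Apex install (shapes and tolerance are illustrative, not from the commit):

    import torch

    from fast_llm.core.kernels import l2_norm, scale_

    tensors = [torch.randn(10), torch.randn(3, 5)]
    noop_flag = torch.zeros(1, dtype=torch.int32)

    # Fallback: sqrt of the summed squared per-tensor norms, i.e. the
    # global L2 norm over all elements concatenated.
    norm = l2_norm(tensors, noop_flag)
    expected = torch.cat([t.flatten() for t in tensors]).norm()
    assert torch.isclose(norm, expected, rtol=1e-5)

    # Fallback scaling is a simple loop of in-place mul_ calls.
    # Note: noop_flag is only consulted on the Apex path.
    scale_(tensors, noop_flag, 0.5)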

fast_llm/engine/checkpoint/convert.py

Lines changed: 2 additions & 6 deletions
@@ -8,7 +8,6 @@
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig
 from fast_llm.engine.config_utils.runnable import RunnableConfig
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
-from fast_llm.functional.config import TritonConfig
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -64,8 +63,8 @@ def _convert_model_partial(
         logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...")
         model = model_class.from_pretrained(
             self.input,
+            {("distributed", "use_cuda"): not self.use_cpu},
             mode=StageMode.weights,
-            use_cpu=self.use_cpu,
             stage_filter=stage_filter,
         )
         logger.info(f"Saving {output.format} checkpoint to {output.path}...")
@@ -78,9 +77,6 @@ def run(self):
         # TODO: Set logging in tests
         logging.getLogger().setLevel(logging.INFO)
         self.to_logs()
-        # Disable Triton to convert model on CPU
-        if self.use_cpu:
-            TritonConfig.TRITON_ENABLED = False
         # Skip on exist_ok=False if the model has already been processed
         if not self.exist_ok and (self.output.path / "ok").exists():
             logger.info(
@@ -100,8 +96,8 @@ def run(self):
         # Create a dummy version to determine the stage split.
         model = model_class.from_pretrained(
             self.input.to_copy({"model_weights": False}),
+            {("distributed", "use_cuda"): not self.use_cpu},
             mode=StageMode.off_device,
-            use_cpu=self.use_cpu,
         )
         stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage)
         num_stages = len(model.stages)
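
The converter no longer threads a use_cpu keyword through from_pretrained; it passes a config-override mapping keyed by the nested field path ("distributed", "use_cuda"). A hedged sketch of the same call pattern from user code (MyModel and load_config are placeholders for a concrete FastLLMModel subclass and a prepared CheckpointLoadConfig, not from the commit):

    from fast_llm.engine.multi_stage.config import StageMode

    # Hypothetical names: MyModel is any concrete FastLLMModel subclass,
    # load_config a prepared CheckpointLoadConfig.
    model = MyModel.from_pretrained(
        load_config,
        {("distributed", "use_cuda"): False},  # nested-field override: place weights on CPU
        mode=StageMode.weights,
    )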

fast_llm/engine/config_utils/run.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def configure_logging(
     def get_run(self, distributed: "Distributed") -> "Run":
         from fast_llm.functional.config import TritonConfig
 
-        TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels
+        TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels  # and distributed.config.use_cuda
         TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels
         run = Run(config=self, distributed=distributed)
         set_global_variables(not self.run.torch_dynamo_enable)

fast_llm/engine/distributed/config.py

Lines changed: 5 additions & 0 deletions
@@ -227,6 +227,11 @@ class DistributedConfig(Config):
         hint=FieldHint.optional,
         valid=check_field(Assert.gt, 0),
     )
+    use_cuda: bool = Field(
+        default=True,
+        desc="Enable CUDA device.",
+        hint=FieldHint.expert,
+    )
     seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional)
     # TODO: Rename to compute_dtype (not just for training), move elsewhere
     compute_dtype: DataType = Field(
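
With this field, CPU-only runs become a config decision rather than a separate code path. A minimal sketch of flipping it programmatically (only the field itself is from the commit; from_dict is assumed to be the usual Fast-LLM Config constructor):

    from fast_llm.engine.distributed.config import DistributedConfig

    # Defaults to True; set False to keep tensors and process groups on CPU.
    config = DistributedConfig.from_dict({"use_cuda": False})
    assert not config.use_cuda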

fast_llm/engine/distributed/distributed.py

Lines changed: 12 additions & 11 deletions
@@ -27,7 +27,7 @@ def __init__(
         world_size: int | None = None,
         local_world_size: int | None = None,
         timeout: float = 60,
-        use_cpu: bool = False,
+        use_cuda: bool = True,
         init_method: str = "env://",
         backend: DistributedBackend = DistributedBackend.nccl,
     ):
@@ -38,19 +38,20 @@ def __init__(
             DistributedConfig.default_local_world_size if local_world_size is None else local_world_size
         )
         self._timeout = timeout
-        self._use_cpu = use_cpu
+        self._use_cuda = use_cuda
         self._backend = backend
         self._process_groups = {}
 
-        if self._use_cpu:
-            if backend == DistributedBackend.nccl:
-                Assert.eq(self._world_size, 1)
-            self._device = torch.device("cpu")
-        else:
+        if self._use_cuda:
+            assert torch.cuda.is_available()
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
             torch.cuda.init()
             self._device = torch.device(self._rank % self._local_world_size)
             torch.cuda.set_device(self._device)
+        else:
+            if backend == DistributedBackend.nccl:
+                Assert.eq(self._world_size, 1)
+            self._device = torch.device("cpu")
 
         if self._world_size > 1:
             if self._rank == 0:
@@ -153,7 +154,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
     TODO: Clarify cpu support.
     """
 
-    def __init__(self, config: DistributedConfig, use_cpu: bool = False):
+    def __init__(self, config: DistributedConfig):
         super().__init__(config)
         assert self._config.reference_config is None
 
@@ -164,15 +165,15 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
                 self._config.world_size,
                 self._config.local_world_size,
                 self._config.timeout,
-                use_cpu,
+                self._config.use_cuda,
                 backend=self._config.backend,
             )
         else:
             self._pool = _default_pool
             Assert.geq(self._pool.world_size, self._config.world_size)
             Assert.eq(self._pool.rank, self._config.rank)
             Assert.geq(self._pool.local_world_size, self._config.local_world_size)
-            Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
+            Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu")
             Assert.eq(self._pool.backend, self._config.backend)
 
         self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
@@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None:
         self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED)
 
     def __del__(self):
-        if self._local_pool:
+        if getattr(self, "_local_pool", False) and hasattr(self, "_pool"):
             self._pool.shutdown()
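
With the keyword gone, Distributed reads the device choice from its config and the pool asserts consistency. A sketch of a single-rank CPU setup (construction details assumed, as above; the diff keeps the NCCL restriction to world_size == 1 when running on CPU):

    from fast_llm.engine.distributed.config import DistributedConfig
    from fast_llm.engine.distributed.distributed import Distributed

    # Single process on CPU: no torch.cuda calls are made, and with the
    # default NCCL backend the pool asserts world_size == 1.
    config = DistributedConfig.from_dict({"use_cuda": False})
    distributed = Distributed(config)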

fast_llm/engine/inference/huggingface.py

Lines changed: 0 additions & 2 deletions
@@ -85,7 +85,6 @@ def from_pretrained(
         optimizer_state_names: tuple[str, ...] | None = None,
         # setup: bool = True,
         mode: StageMode = StageMode.training,
-        use_cpu: bool = False,
         stage_filter: set | None = None,
         **kwargs,
     ) -> typing.Self:
@@ -104,7 +103,6 @@ def from_pretrained(
             optimizer_state_names=optimizer_state_names,
             setup=True,
             mode=mode,
-            use_cpu=use_cpu,
             stage_filter=stage_filter,
         )

fast_llm/engine/multi_stage/fast_llm_model.py

Lines changed: 1 addition & 2 deletions
@@ -48,7 +48,6 @@ def from_pretrained(
         optimizer_state_names: tuple[str, ...] | None = None,
         setup: bool = True,
         mode: StageMode = StageMode.training,
-        use_cpu: bool = False,
         stage_filter: set | None = None,
     ) -> typing.Self:
         metadata = cls.config_class.load_metadata(pretrained_config)
@@ -69,7 +68,7 @@
         )
 
         if setup:
-            model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode)
+            model.setup(Distributed(config.distributed), mode=mode)
 
         if mode.on_device:
             if pretrained_config.model_weights:

fast_llm/engine/schedule/config.py

Lines changed: 18 additions & 0 deletions
@@ -191,3 +191,21 @@ class EventType(str, enum.Enum):
     send = "send"
     recv = "recv"
     pipe_wait_compute = "pipe_wait_compute"
+
+
+class MockStream:
+    stream_id: int = 0
+
+    def wait_stream(self, stream):
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, MockStream)
+
+
+class MockEvent:
+    def record(self, stream=None):
+        pass
+
+    def wait(self):
+        pass
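
MockStream and MockEvent are no-op stand-ins so the schedule's stream/event choreography runs unchanged without CUDA. A sketch of the substitution pattern (the selection logic below is assumed, not shown in this commit):

    import torch

    from fast_llm.engine.schedule.config import MockEvent, MockStream

    # Hypothetical selection: real CUDA primitives when available,
    # otherwise the no-op mocks with the same call surface.
    if torch.cuda.is_available():
        stream, event = torch.cuda.current_stream(), torch.cuda.Event()
    else:
        stream, event = MockStream(), MockEvent()

    event.record(stream)        # no-op on CPU
    event.wait()                # no-op on CPU
    stream.wait_stream(stream)  # also a no-op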
