@@ -27,7 +27,7 @@ def __init__(
         world_size: int | None = None,
         local_world_size: int | None = None,
         timeout: float = 60,
-        use_cpu: bool = False,
+        use_cuda: bool = True,
         init_method: str = "env://",
         backend: DistributedBackend = DistributedBackend.nccl,
     ):
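Note: the defaults are equivalent across the rename (use_cpu=False corresponds to use_cuda=True), so only callers that explicitly requested CPU need updating. A hypothetical before/after call, assuming the pool class is named ProcessGroupPool (the class name does not appear in this diff):

# Before: pool = ProcessGroupPool(rank=0, world_size=1, use_cpu=True)
# After:  pool = ProcessGroupPool(rank=0, world_size=1, use_cuda=False)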
@@ -38,19 +38,20 @@ def __init__(
             DistributedConfig.default_local_world_size if local_world_size is None else local_world_size
         )
         self._timeout = timeout
-        self._use_cpu = use_cpu
+        self._use_cuda = use_cuda
         self._backend = backend
         self._process_groups = {}

-        if self._use_cpu:
-            if backend == DistributedBackend.nccl:
-                Assert.eq(self._world_size, 1)
-            self._device = torch.device("cpu")
-        else:
+        if self._use_cuda:
+            assert torch.cuda.is_available()
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
             torch.cuda.init()
             self._device = torch.device(self._rank % self._local_world_size)
             torch.cuda.set_device(self._device)
+        else:
+            if backend == DistributedBackend.nccl:
+                Assert.eq(self._world_size, 1)
+            self._device = torch.device("cpu")

         if self._world_size > 1:
             if self._rank == 0:
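Note: inverting the branch also tightens both paths: the CUDA path now fails fast via assert torch.cuda.is_available() rather than erroring later inside torch.cuda.init(), and the CPU path keeps the restriction that NCCL requires world_size == 1, since NCCL has no CPU transport. A minimal sketch of requesting a multi-process CPU pool under these rules, assuming the pool class is named ProcessGroupPool and that DistributedBackend defines a gloo member:

pool = ProcessGroupPool(
    rank=0,
    world_size=4,
    use_cuda=False,                   # CPU device, no torch.cuda calls
    backend=DistributedBackend.gloo,  # gloo supports CPU; nccl would require world_size == 1
)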
@@ -153,7 +154,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
     TODO: Clarify cpu support.
     """

-    def __init__(self, config: DistributedConfig, use_cpu: bool = False):
+    def __init__(self, config: DistributedConfig):
         super().__init__(config)
         assert self._config.reference_config is None

@@ -164,15 +165,15 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
                 self._config.world_size,
                 self._config.local_world_size,
                 self._config.timeout,
-                use_cpu,
+                self._config.use_cuda,
                 backend=self._config.backend,
             )
         else:
             self._pool = _default_pool
             Assert.geq(self._pool.world_size, self._config.world_size)
             Assert.eq(self._pool.rank, self._config.rank)
             Assert.geq(self._pool.local_world_size, self._config.local_world_size)
-            Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
+            Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu")
             Assert.eq(self._pool.backend, self._config.backend)

         self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
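Note: Distributed no longer takes a per-instance use_cpu argument; the device choice now travels with DistributedConfig. A hypothetical caller, assuming the config class gained a use_cuda field defaulting to True (implied by the self._config.use_cuda reads above) and accepts it as a constructor keyword:

config = DistributedConfig(use_cuda=False, backend=DistributedBackend.gloo)
distributed = Distributed(config)  # was: Distributed(config, use_cpu=True)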
@@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None:
         self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED)

     def __del__(self):
-        if self._local_pool:
+        if getattr(self, "_local_pool", False) and hasattr(self, "_pool"):
             self._pool.shutdown()
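Note: the extra guards matter because __del__ runs even when __init__ raises before _local_pool or _pool are assigned, in which case the old check of self._local_pool would itself raise AttributeError during garbage collection. A self-contained sketch of the failure mode (the Holder class is illustrative only, not from this PR):

import types

class Holder:
    def __init__(self, fail: bool = False):
        if fail:
            # __del__ still runs on this half-built instance.
            raise RuntimeError("failing before attributes are set")
        self._local_pool = True
        self._pool = types.SimpleNamespace(shutdown=lambda: print("pool shut down"))

    def __del__(self):
        # getattr/hasattr keep teardown safe for partially initialized objects;
        # a bare read of self._local_pool would raise AttributeError here.
        if getattr(self, "_local_pool", False) and hasattr(self, "_pool"):
            self._pool.shutdown()

try:
    Holder(fail=True)  # collected without an AttributeError from __del__
except RuntimeError:
    pass
Holder()  # prints "pool shut down" when collected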