From 7d2872c9baa02d85463473081ace5dd615e4e83c Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Thu, 12 Mar 2026 02:49:25 +0000 Subject: [PATCH] diloco fixes --- src/maxtext/common/data_loader.py | 9 +++--- .../configs/models/deepseek3-671b-2dfsdp.yml | 3 +- src/maxtext/configs/types.py | 24 ++++++++++++++ src/maxtext/trainers/diloco/diloco.py | 31 ++++++++++++------- src/maxtext/trainers/pre_train/train.py | 25 +++++++-------- .../trainers/pre_train/train_compile.py | 4 ++- src/maxtext/utils/train_utils.py | 4 +-- tests/unit/diloco_test.py | 18 +++++++++++ 8 files changed, 85 insertions(+), 33 deletions(-) diff --git a/src/maxtext/common/data_loader.py b/src/maxtext/common/data_loader.py index d40bab2ef1..21bd870bc8 100644 --- a/src/maxtext/common/data_loader.py +++ b/src/maxtext/common/data_loader.py @@ -71,13 +71,10 @@ def load_next_batch_pre_sharding(self): def load_next_batch(self, *args, **kwargs): """Loads the next batch with sharding hint""" - example_batch = jax.device_put( - self.load_next_batch_pre_sharding(), - self.input_data_shardings, - ) + example_batch = self.load_next_batch_pre_sharding() if self.config.enable_diloco: example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch) - return example_batch + return jax.device_put(example_batch, self.input_data_shardings) def check_example_batch(self): if self.config.max_checkify: @@ -157,6 +154,8 @@ def _slice(data): self.buffer_start = slice_end output = jax.tree.map(_slice, self.batch_buffer) self.rampup_active = rampup_manager.update() + if self.config.enable_diloco: + output = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, output) return jax.device_put(output, self.input_data_shardings) diff --git a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml index c137d94c98..939587b888 100644 --- a/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml +++ b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml @@ -56,7 +56,7 @@ rope_truncate: True rope_attention_scaling: False override_logical_axis_rules: True -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] +mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context'] data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']], @@ -79,4 +79,5 @@ logical_axis_rules: [ ['mlp', ['fsdp_transpose', 'expert']], ['mlp_only_fsdp_transpose', ['fsdp_transpose']], ['mlp_only_tensor', ['expert']], + ['diloco', 'diloco'], ] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 21296e965d..f391126c50 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2643,8 +2643,32 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] # Diloco params + # Resolve dcn_diloco_parallelism=-1 if left unspecified, using the same convention as dcn_data_parallelism. + # num_diloco_replicas must be computed after this resolution, so we resolve it here rather than + # relying on fill_unspecified_mesh_axes (which runs later during mesh creation). + if self.dcn_diloco_parallelism == -1: + other_dcn_product = prod(v for v in self.dcn_parallelism if v != -1) + assert other_dcn_product > 0 and self.num_slices % other_dcn_product == 0, ( + f"Cannot resolve dcn_diloco_parallelism=-1: num_slices={self.num_slices} is not divisible " + f"by the product of other DCN parallelism values ({other_dcn_product})." + ) + self.dcn_diloco_parallelism = self.num_slices // other_dcn_product + # Keep dcn_parallelism list consistent with the resolved value. + diloco_idx = self.dcn_parallelism.index(-1) + self.dcn_parallelism[diloco_idx] = self.dcn_diloco_parallelism self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism) + # use_tokamax_gmm is incompatible with enable_diloco: drjax.map_fn wraps the train step in + # jax.vmap over the diloco axis, which causes JAX to batch through lax.scan (layer scan). + # Tokamax's vmap_rule then tries to reconstruct GroupSizes with a batched 2-D value, but + # GroupSizes.__post_init__ requires exactly a 1-D shape. + if self.enable_diloco and self.use_tokamax_gmm: + raise ValueError( + "use_tokamax_gmm=True is not compatible with enable_diloco=True due to a known " + "incompatibility between tokamax's GroupSizes vmap_rule and JAX's scan batching. " + "Please set use_tokamax_gmm=False." + ) + # Final string-to-enum conversions if they haven't been coerced by pydantic yet. if isinstance(self.decoder_block, str): self.decoder_block = DecoderBlockType(self.decoder_block.lower()) diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index 39e2b70793..a9ef64631a 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -108,9 +108,11 @@ def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str def reshape_for_diloco(arr): batch_dim, *example_shape = arr.shape diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape) - s = arr.sharding - s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec)) - return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s) + if hasattr(arr, "sharding"): + s = arr.sharding + s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec)) + return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s) + return jnp.reshape(arr, shape=diloco_shape) return jax.tree.map(reshape_for_diloco, pytree) @@ -166,9 +168,11 @@ def add_diloco_dim(x): # Build shardings inner_state_shardings = add_diloco_to_sharding(state_mesh_shardings) - outer_opt_state_sharding = jax.tree.map( - lambda _: jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), - outer_opt_state, + # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) + # We shard the momentum trace the same way as the parameters. + outer_opt_state_sharding = ( + optax.TraceState(trace=state_mesh_shardings.params), + optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, @@ -183,6 +187,7 @@ def add_diloco_dim(x): def build_diloco_state( config: "pyconfig.HyperParameters", initialize_state: Callable[[], train_state.TrainState], + mesh: jax.sharding.Mesh | None = None, ) -> tuple[DiLoCoTrainState, PyTree]: """Given a non-DiLoCo train state, construct a DiLoCo training state.""" outer_optimizer = optax.sgd( @@ -195,7 +200,10 @@ def build_diloco_state( def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: state = initialize_state() # Inner state must be broadcast across clients. - inner_state = drjax.broadcast(state) + # Pass mesh explicitly because jax.set_mesh() uses a different thread-local + # than pxla.thread_resources (which drjax reads), so drjax cannot find the + # mesh automatically when jax.set_mesh is used. + inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. outer_params = state.params outer_opt_state = outer_optimizer.init(outer_params) @@ -211,6 +219,7 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: def build_diloco_train_step( config: pyconfig.HyperParameters, train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]], + mesh: jax.sharding.Mesh | None = None, ) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]: """Convert a local state and train step into DiLoCo-compatible versions. @@ -234,7 +243,7 @@ def build_diloco_train_step( def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). - broadcast_outer_params = drjax.broadcast(state.params) + broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. @@ -244,7 +253,7 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state) + new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -259,8 +268,8 @@ def typed_reduce_mean(in_tree): @drjax.program(placements={"diloco": config.num_diloco_replicas}) def diloco_train_step(state, batch, prng): # Broadcast the RNG across replicas. - broadcast_rng = drjax.broadcast(prng) - inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng)) + broadcast_rng = drjax.broadcast(prng, mesh=mesh) + inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) state = state.replace( inner_state=inner_state, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 4b3505b224..fd97cb7a17 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -490,19 +490,18 @@ def train_loop(config, recorder, state=None): params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) - - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) shaped_batch = maxtext_utils.get_shaped_batch(config) if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 408340016e..74f36ea045 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -132,7 +132,9 @@ def jit_and_compile( logical_axis_rules, ): """Jit, lower, and compile func.""" - with jax.set_mesh(mesh), logical_axis_rules: + # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax, + # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh. + with jax.set_mesh(mesh), mesh, logical_axis_rules: jitted = jax.jit( func, in_shardings=in_shardings, diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 00eb408ad3..2574ccf9aa 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -162,7 +162,7 @@ def jit_train_and_eval_step( """Returns a JIT-compiled train and eval step function.""" if config.enable_diloco: train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) - train_step = diloco.build_diloco_train_step(config, train_step_partial) + train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh) data_sharding = sharding.get_input_data_sharding(config, mesh) p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) p_eval_step = None @@ -229,7 +229,7 @@ def setup_train_loop(config, recorder, devices=None): if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state) + state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state, mesh=mesh) # create state_mesh_shardings for the DilocoState inner_state_shardings = diloco.add_diloco_to_sharding(state_mesh_shardings) diff --git a/tests/unit/diloco_test.py b/tests/unit/diloco_test.py index 042216eb10..177fbac98a 100644 --- a/tests/unit/diloco_test.py +++ b/tests/unit/diloco_test.py @@ -268,6 +268,24 @@ def loss_fn(params, batch): # synchronization). chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params) + @pytest.mark.cpu_only + def test_diloco_qwen3_moe_two_slices(self): + temp_dir = gettempdir() + compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_diloco_qwen3_moe.pickle") + train_compile_main( + ( + None, + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=tpu7x-16", + "compile_topology_num_slices=2", + "ici_fsdp_parallelism=-1", + "dcn_diloco_parallelism=2", + "enable_diloco=true", + "model_name=qwen3-30b-a3b", + ) + ) + @pytest.mark.tpu_only def test_diloco_two_slices(self): temp_dir = gettempdir()