Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/maxtext/common/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,10 @@ def load_next_batch_pre_sharding(self):

def load_next_batch(self, *args, **kwargs):
"""Loads the next batch with sharding hint"""
example_batch = jax.device_put(
self.load_next_batch_pre_sharding(),
self.input_data_shardings,
)
example_batch = self.load_next_batch_pre_sharding()
if self.config.enable_diloco:
example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch)
return example_batch
return jax.device_put(example_batch, self.input_data_shardings)

def check_example_batch(self):
if self.config.max_checkify:
Expand Down Expand Up @@ -157,6 +154,8 @@ def _slice(data):
self.buffer_start = slice_end
output = jax.tree.map(_slice, self.batch_buffer)
self.rampup_active = rampup_manager.update()
if self.config.enable_diloco:
output = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, output)
return jax.device_put(output, self.input_data_shardings)


Expand Down
3 changes: 2 additions & 1 deletion src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ rope_truncate: True
rope_attention_scaling: False

override_logical_axis_rules: True
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
Expand All @@ -79,4 +79,5 @@ logical_axis_rules: [
['mlp', ['fsdp_transpose', 'expert']],
['mlp_only_fsdp_transpose', ['fsdp_transpose']],
['mlp_only_tensor', ['expert']],
['diloco', 'diloco'],
]
24 changes: 24 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,8 +2643,32 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

# Diloco params
# Resolve dcn_diloco_parallelism=-1 if left unspecified, using the same convention as dcn_data_parallelism.
# num_diloco_replicas must be computed after this resolution, so we resolve it here rather than
# relying on fill_unspecified_mesh_axes (which runs later during mesh creation).
if self.dcn_diloco_parallelism == -1:
other_dcn_product = prod(v for v in self.dcn_parallelism if v != -1)
assert other_dcn_product > 0 and self.num_slices % other_dcn_product == 0, (
f"Cannot resolve dcn_diloco_parallelism=-1: num_slices={self.num_slices} is not divisible "
f"by the product of other DCN parallelism values ({other_dcn_product})."
)
self.dcn_diloco_parallelism = self.num_slices // other_dcn_product
# Keep dcn_parallelism list consistent with the resolved value.
diloco_idx = self.dcn_parallelism.index(-1)
self.dcn_parallelism[diloco_idx] = self.dcn_diloco_parallelism
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)

# use_tokamax_gmm is incompatible with enable_diloco: drjax.map_fn wraps the train step in
# jax.vmap over the diloco axis, which causes JAX to batch through lax.scan (layer scan).
# Tokamax's vmap_rule then tries to reconstruct GroupSizes with a batched 2-D value, but
# GroupSizes.__post_init__ requires exactly a 1-D shape.
if self.enable_diloco and self.use_tokamax_gmm:
raise ValueError(
"use_tokamax_gmm=True is not compatible with enable_diloco=True due to a known "
"incompatibility between tokamax's GroupSizes vmap_rule and JAX's scan batching. "
"Please set use_tokamax_gmm=False."
)

# Final string-to-enum conversions if they haven't been coerced by pydantic yet.
if isinstance(self.decoder_block, str):
self.decoder_block = DecoderBlockType(self.decoder_block.lower())
Expand Down
31 changes: 20 additions & 11 deletions src/maxtext/trainers/diloco/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str
def reshape_for_diloco(arr):
batch_dim, *example_shape = arr.shape
diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
s = arr.sharding
s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
if hasattr(arr, "sharding"):
s = arr.sharding
s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
return jnp.reshape(arr, shape=diloco_shape)

return jax.tree.map(reshape_for_diloco, pytree)

Expand Down Expand Up @@ -166,9 +168,11 @@ def add_diloco_dim(x):

# Build shardings
inner_state_shardings = add_diloco_to_sharding(state_mesh_shardings)
outer_opt_state_sharding = jax.tree.map(
lambda _: jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
outer_opt_state,
# Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState())
# We shard the momentum trace the same way as the parameters.
outer_opt_state_sharding = (
optax.TraceState(trace=state_mesh_shardings.params),
optax.EmptyState(),
)
diloco_state_shardings = DiLoCoTrainState(
inner_state=inner_state_shardings,
Expand All @@ -183,6 +187,7 @@ def add_diloco_dim(x):
def build_diloco_state(
config: "pyconfig.HyperParameters",
initialize_state: Callable[[], train_state.TrainState],
mesh: jax.sharding.Mesh | None = None,
) -> tuple[DiLoCoTrainState, PyTree]:
"""Given a non-DiLoCo train state, construct a DiLoCo training state."""
outer_optimizer = optax.sgd(
Expand All @@ -195,7 +200,10 @@ def build_diloco_state(
def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
state = initialize_state()
# Inner state must be broadcast across clients.
inner_state = drjax.broadcast(state)
# Pass mesh explicitly because jax.set_mesh() uses a different thread-local
# than pxla.thread_resources (which drjax reads), so drjax cannot find the
# mesh automatically when jax.set_mesh is used.
inner_state = drjax.broadcast(state, mesh=mesh)
# Outer state retains a single copy of the model parameters and optimizer state.
outer_params = state.params
outer_opt_state = outer_optimizer.init(outer_params)
Expand All @@ -211,6 +219,7 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
def build_diloco_train_step(
config: pyconfig.HyperParameters,
train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
mesh: jax.sharding.Mesh | None = None,
) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
"""Convert a local state and train step into DiLoCo-compatible versions.

Expand All @@ -234,7 +243,7 @@ def build_diloco_train_step(
def synchronize(state):
# Calculate the delta between the current replica's state and the global
# state (since last synchronization).
broadcast_outer_params = drjax.broadcast(state.params)
broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh)
model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
# Treat the average delta as the outer optimizer's gradient and apply to
# the global (outer) model params.
Expand All @@ -244,7 +253,7 @@ def synchronize(state):
# Replace inner model params with the new global model params.
# NOTE: inner optimizer state is retained despite the change in parameters,
# see section 6.1 in https://arxiv.org/pdf/2311.08105.
new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh)
return state.replace(
params=new_outer_params,
outer_opt_state=new_opt_state,
Expand All @@ -259,8 +268,8 @@ def typed_reduce_mean(in_tree):
@drjax.program(placements={"diloco": config.num_diloco_replicas})
def diloco_train_step(state, batch, prng):
# Broadcast the RNG across replicas.
broadcast_rng = drjax.broadcast(prng)
inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
broadcast_rng = drjax.broadcast(prng, mesh=mesh)
inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh)
avg_metrics = typed_reduce_mean(metrics)
state = state.replace(
inner_state=inner_state,
Expand Down
25 changes: 12 additions & 13 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,19 +490,18 @@ def train_loop(config, recorder, state=None):

params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
config,
model,
mesh,
state,
state_mesh_shardings,
train_step,
eval_step,
eval_data_iterator,
params_shardings,
)

with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
config,
model,
mesh,
state,
state_mesh_shardings,
train_step,
eval_step,
eval_data_iterator,
params_shardings,
)
shaped_batch = maxtext_utils.get_shaped_batch(config)
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
Expand Down
4 changes: 3 additions & 1 deletion src/maxtext/trainers/pre_train/train_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def jit_and_compile(
logical_axis_rules,
):
"""Jit, lower, and compile func."""
with jax.set_mesh(mesh), logical_axis_rules:
# Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
# which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
with jax.set_mesh(mesh), mesh, logical_axis_rules:
jitted = jax.jit(
func,
in_shardings=in_shardings,
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def jit_train_and_eval_step(
"""Returns a JIT-compiled train and eval step function."""
if config.enable_diloco:
train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
train_step = diloco.build_diloco_train_step(config, train_step_partial)
train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
data_sharding = sharding.get_input_data_sharding(config, mesh)
p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings)
p_eval_step = None
Expand Down Expand Up @@ -229,7 +229,7 @@ def setup_train_loop(config, recorder, devices=None):

if config.enable_diloco:
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state)
state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state, mesh=mesh)

# create state_mesh_shardings for the DilocoState
inner_state_shardings = diloco.add_diloco_to_sharding(state_mesh_shardings)
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/diloco_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,24 @@ def loss_fn(params, batch):
# synchronization).
chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params)

@pytest.mark.cpu_only
def test_diloco_qwen3_moe_two_slices(self):
temp_dir = gettempdir()
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_diloco_qwen3_moe.pickle")
train_compile_main(
(
None,
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
"compile_topology=tpu7x-16",
"compile_topology_num_slices=2",
"ici_fsdp_parallelism=-1",
"dcn_diloco_parallelism=2",
"enable_diloco=true",
"model_name=qwen3-30b-a3b",
)
)

@pytest.mark.tpu_only
def test_diloco_two_slices(self):
temp_dir = gettempdir()
Expand Down
Loading