Skip to content

Fix DiLoCo training compatibility issues#3471

Draft
khatwanimohit wants to merge 1 commit intomainfrom
mohit/diloco_fixes
Draft

Fix DiLoCo training compatibility issues#3471
khatwanimohit wants to merge 1 commit intomainfrom
mohit/diloco_fixes

Conversation

@khatwanimohit
Copy link
Collaborator

Description

Summary:

  • Fix mesh discovery for drjax: Pass mesh explicitly to all drjax.broadcast/drjax.map_fn calls since jax.set_mesh() uses a different thread-local than pxla.thread_resources (which drjax reads). Also add with mesh: alongside jax.set_mesh() in train and train_compile
    entrypoints.
  • Fix data loading order: Move jax.device_put after reshape_first_axis_with_diloco in DataLoader.load_next_batch, and add the diloco reshape to RampUpDataLoader as well. Guard sharding access in reshape_for_diloco for arrays without a .sharding attribute.
  • Fix outer optimizer state sharding: Shard the SGD momentum trace like model params instead of fully replicating it (replicated sharding caused issues).
  • Auto-resolve dcn_diloco_parallelism=-1: Infer the diloco DCN parallelism from num_slices and other DCN axes, matching the existing convention for dcn_data_parallelism.
  • Add use_tokamax_gmm + enable_diloco incompatibility check: Raise a clear error since tokamax's GroupSizes vmap_rule breaks under drjax's jax.vmap batching.
  • Add diloco mesh axis to the deepseek3-671b-2dfsdp config.

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 41.66667% with 14 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/diloco/diloco.py 54.54% 5 Missing ⚠️
src/maxtext/trainers/pre_train/train.py 16.66% 4 Missing and 1 partial ⚠️
src/maxtext/common/data_loader.py 50.00% 1 Missing and 1 partial ⚠️
src/maxtext/utils/train_utils.py 0.00% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant