Skip to content

Conversation

@Shuang-cnt
Copy link
Collaborator

@Shuang-cnt Shuang-cnt commented Jan 21, 2026

Description

This change builds upon PR#2866 and PR#2926 to add functionality for printing logical axes

Tests

Command:
python -m MaxText.train_compile MaxText/configs/base.yml compile_topology=v5p-1024 compile_topology_num_slices=1 model_name=deepseek3-671b per_device_batch_size=1 ici_tensor_parallelism=8 ici_expert_parallelism=8 log_config=false debug_sharding=true

Output example:

Log for a real train: https://paste.googleplex.com/6519526529302528

Weight with Logic axes:

I0123 00:00:29.309755 133441853451392 maxtext_utils.py:1269]  params/params/decoder/dense_layers/mlp/wi_0/kernel
    Shape:     float32[7168,3,18432]
    Logical:   PartitionSpec('embed', 'dense_layers', 'mlp')
    Physical:  (('fsdp', 'expert'), None, 'tensor')
I0123 00:00:29.309807 133441853451392 maxtext_utils.py:1269]  params/params/decoder/dense_layers/mlp/wi_1/kernel
    Shape:     float32[7168,3,18432]
    Logical:   PartitionSpec('embed', 'dense_layers', 'mlp')
    Physical:  (('fsdp', 'expert'), None, 'tensor')
I0123 00:00:29.309876 133441853451392 maxtext_utils.py:1269]  params/params/decoder/dense_layers/mlp/wo/kernel
    Shape:     float32[18432,3,7168]
    Logical:   PartitionSpec('mlp', 'dense_layers', 'embed')
    Physical:  ('tensor', None, ('fsdp', 'expert'))

Activation Logical Axes:

I0126 22:03:30.305064 128637779627136 deepseek.py:131] Logical:  bfloat16[512,2048,7168]..................................... ('activation_batch', 'activation_norm_length', 'activation_embed')
I0126 22:03:30.305159 128637779627136 deepseek.py:131] bfloat16[512,2048,7168]......................................................... (('fsdp', 'expert'), None, 'tensor').
I0126 22:03:30.309315 128637779627136 attention_mla.py:736] Logical:  bfloat16[512,2048,7168]..................................... ('activation_batch', 'activation_length_no_exp', 'activation_embed')
I0126 22:03:30.309363 128637779627136 attention_mla.py:736] bfloat16[512,2048,7168]......................................................... (('fsdp', 'expert'), None, 'tensor').
I0126 22:03:30.312696 128637779627136 attention_mla.py:538] Logical:  bfloat16[512,2048,128,128].................................. ('activation_kv_batch', 'activation_length_no_exp', 'activation_kv_heads', 'activation_kv_head_dim')
I0126 22:03:30.312744 128637779627136 attention_mla.py:538] bfloat16[512,2048,128,128]...................................................... (('fsdp', 'expert'), None, 'tensor', None).

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 31.42857% with 24 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/maxtext_utils.py 35.00% 13 Missing ⚠️
src/MaxText/sharding.py 14.28% 5 Missing and 1 partial ⚠️
src/MaxText/train_compile.py 60.00% 2 Missing ⚠️
src/MaxText/train_utils.py 0.00% 2 Missing ⚠️
src/MaxText/model_creation_utils.py 0.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@Shuang-cnt Shuang-cnt force-pushed the user/sharony/logicaxes branch from 6433043 to 6591621 Compare January 22, 2026 02:48
Copy link
Collaborator

@richjames0 richjames0 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm although note my pending comment

Copy link
Collaborator

@NuojCheng NuojCheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! Made some comments and suggestions. It would be nice to add a real train test in the PR description.

@Shuang-cnt
Copy link
Collaborator Author

Thank you for the PR! Made some comments and suggestions. It would be nice to add a real train test in the PR description.

Thanks for the suggestion. Add a real train log. Please check.

@Shuang-cnt Shuang-cnt force-pushed the user/sharony/logicaxes branch from be0bb6a to d84d4ca Compare January 28, 2026 17:23
@copybara-service copybara-service bot merged commit fcb87c0 into AI-Hypercomputer:main Jan 28, 2026
21 of 22 checks passed
@Shuang-cnt Shuang-cnt deleted the user/sharony/logicaxes branch January 28, 2026 18:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants