Skip to content

Extend muon to support mask_callable weight dimension numbers#1643

Open
dylantirandaz wants to merge 1 commit intogoogle-deepmind:mainfrom
dylantirandaz:muon-mask-callable-support
Open

Extend muon to support mask_callable weight dimension numbers#1643
dylantirandaz wants to merge 1 commit intogoogle-deepmind:mainfrom
dylantirandaz:muon-mask-callable-support

Conversation

@dylantirandaz
Copy link
Copy Markdown

Summary

  • Add support for weight_dimension_numbers trees where each leaf is a callable, resolved via _masking._mask_callable
  • Extends both scale_by_muon and scale_by_shape to handle this case
  • Resolves TODO(rdyro) at _muon.py:478

Test plan

  • Run new test_mask_callable_weight_dim_nums test (parameterized across frobenius/aol/schatten)
  • Verify existing muon tests still pass
  • Run bash test.sh

Add support for weight_dimension_numbers trees where each leaf is a
callable, resolved via _masking._mask_callable. This extends both
scale_by_muon and scale_by_shape to handle this case by mapping
each callable leaf over updates.

Resolves TODO(rdyro) at _muon.py:478.
@rdyro
Copy link
Copy Markdown
Collaborator

rdyro commented Mar 23, 2026

I'm not sure we actually want this functionality, can you say what your use case here is?

@dylantirandaz
Copy link
Copy Markdown
Author

Doesn't add new functionality beyond the existing top level callable. It just allows weight_dimension_numbers to be a pytree of callables, where each leaf is resolved as fn(updates)to produce the final tree. My goal was just ergonomics/composability. I can drop it if you prefer to keep the API surface smaller

@rdyro
Copy link
Copy Markdown
Collaborator

rdyro commented Mar 23, 2026

Hmm, I think this feature was mostly meant for supporting tree of modules like what might exist in equinox.

Since this is specifically for dimension numbers, I'm unsure if it makes sense to have each module define its own dimension number; maybe actually.

Finally, I think the current implementation in this PR is incorrect, for mapping over updates and dimension numbers jointly, we'd need a joint tree map, we can't pass full updates to each leaf callable.

Which way would you like to take it? Having a specific application / use case in mind would probably really help justify this change. What do you think?

@dylantirandaz
Copy link
Copy Markdown
Author

Okay thanks I agree the current implementation is wrong since each callable is the full updates tree. I can rework this to use a joint tree map so callables are resolved against their corresponding update leaf/subtree.
A use case I have in mind is heterogenous models where different sub modules naturally infer their own MuonDimensionNumbers from local parameter shapes/layouts, instead of requiring one global resolver. I will update the PR with that approach and clarify the use case

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants