Extend muon to support mask_callable weight dimension numbers #1643
dylantirandaz wants to merge 1 commit into google-deepmind:main
Conversation
Add support for weight_dimension_numbers trees where each leaf is a callable, resolved via _masking._mask_callable. This extends both scale_by_muon and scale_by_shape to handle this case by mapping each callable leaf over updates. Resolves TODO(rdyro) at _muon.py:478.
I'm not sure we actually want this functionality. Can you say what your use case is here?
This doesn't add new functionality beyond the existing top-level callable. It just allows weight_dimension_numbers to be a pytree of callables, where each leaf is resolved as fn(updates) to produce the final tree. My goal was just ergonomics/composability. I can drop it if you prefer to keep the API surface smaller.
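To make the semantics described above concrete, here is a minimal sketch in plain JAX. It is not the PR's actual code: `DimNums` is a hypothetical stand-in for a dimension-numbers spec, and `resolve_per_leaf_callables` is an illustrative helper name. It only shows what "each leaf is resolved as fn(updates)" means.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DimNums:
    """Hypothetical stand-in for a Muon dimension-numbers spec."""
    reduction_axis: int = 0
    output_axis: int = 1


def resolve_per_leaf_callables(dim_num_fns, updates):
    # Map over the tree of callables only. Every callable is handed the
    # FULL updates tree, not just the leaf at its own position.
    return jax.tree_util.tree_map(lambda fn: fn(updates), dim_num_fns)


updates = {"dense": jnp.ones((4, 8)), "bias": jnp.ones((8,))}

# Each callable has to index into the full tree itself.
dim_num_fns = {
    "dense": lambda u: DimNums(0, 1) if u["dense"].ndim == 2 else None,
    "bias": lambda u: None,  # None here means "no Muon spec for this leaf"
}

resolved = resolve_per_leaf_callables(dim_num_fns, updates)
```

Note that each callable needs to know where its own parameter lives inside `updates`, which is the coupling a per-leaf API would normally be expected to avoid.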
Hmm, I think this feature was mostly meant for supporting a tree of modules, like what might exist in Equinox. Since this is specifically for dimension numbers, I'm unsure whether it makes sense to have each module define its own dimension numbers, though maybe it does. Finally, I think the current implementation in this PR is incorrect: to map over updates and dimension numbers jointly, we'd need a joint tree map; we can't pass the full updates to each leaf callable. Which way would you like to take it? Having a specific application or use case in mind would probably really help justify this change. What do you think?
Okay, thanks. I agree the current implementation is wrong, since each callable is passed the full updates tree. I can rework this to use a joint tree map so that callables are resolved against their corresponding update leaf/subtree.
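A rough sketch of what such a joint tree map could look like in plain JAX follows. The names `DimNums`, `spec_for`, and `resolve_jointly` are hypothetical and used only for illustration; the eventual rework may differ. It relies on `jax.tree_util.tree_map` accepting additional trees that have the first tree as a prefix, so a callable leaf can be paired with either a single update leaf or a whole update subtree.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class DimNums:
    """Hypothetical stand-in for a Muon dimension-numbers spec."""
    reduction_axis: int = 0
    output_axis: int = 1


def spec_for(update):
    # Illustrative rule: only 2-D arrays get a spec.
    return DimNums() if update.ndim == 2 else None


def resolve_jointly(dim_num_fns, updates):
    # Joint map: the tree of callables drives the traversal, and each
    # callable receives only the update leaf/subtree at its own position.
    return jax.tree_util.tree_map(lambda fn, u: fn(u), dim_num_fns, updates)


updates = {
    "block": {"w": jnp.ones((4, 8)), "b": jnp.ones((8,))},
    "head": jnp.ones((8, 2)),
}

dim_num_fns = {
    # Paired with the whole "block" subtree of updates.
    "block": lambda sub: jax.tree_util.tree_map(spec_for, sub),
    # Paired with a single array leaf.
    "head": spec_for,
}

resolved = resolve_jointly(dim_num_fns, updates)
```

With this pairing, a callable no longer needs to know the path to its own parameter, which fits the module-tree use case mentioned earlier in the thread.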
Summary
- Add support for weight_dimension_numbers trees where each leaf is a callable, resolved via _masking._mask_callable
- Extend scale_by_muon and scale_by_shape to handle this case
- Resolves TODO(rdyro) at _muon.py:478
Test plan
- test_mask_callable_weight_dim_nums test (parameterized across frobenius/aol/schatten)
- bash test.sh