refactor(pt): full refactor of HybridMuon optimizer #5275

OutisLi wants to merge 4 commits into deepmodeling:master
Conversation
📝 Walkthrough

Reworks HybridMuonOptimizer routing to name-based decisions and a muon_mode ("2d"/"flat"/"slice"), adds optional Magma-lite damping (magma_muon), introduces batched Newton–Schulz orthogonalization and new helper APIs, and updates training wiring, arg parsing, and tests for the new routing and Magma behavior.
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Optimizer as HybridMuonOptimizer
    participant Router as Routing Logic
    participant Shape as Shape Analysis
    participant Magma as Magma Scaler
    participant Step as Optimizer Step
    User->>Optimizer: step(closure)
    Optimizer->>Router: get_adam_route(param_name)
    Router-->>Optimizer: route (Adam / AdamW / Muon)
    alt Muon route
        Optimizer->>Shape: get_effective_shape -> get_matrix_view_shape(muon_mode)
        Shape-->>Optimizer: matrix/view dims
        alt magma_muon enabled
            Optimizer->>Magma: _compute_magma_scales_for_bucket(...)
            Magma-->>Optimizer: damping scales
        end
        Optimizer->>Step: _batched_newton_schulz_orth -> apply muon updates
        Step-->>Optimizer: updated params/state
    else Adam/AdamW route
        Optimizer->>Step: apply Adam / AdamW update
        Step-->>Optimizer: updated params/state
    end
    Optimizer-->>User: return loss
```
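The `_batched_newton_schulz_orth` step in the diagram refers to Newton–Schulz orthogonalization, the core of Muon-style updates. As a point of reference, here is a minimal, dependency-free sketch of the classical cubic Newton–Schulz iteration; the PR's batched implementation and any tuned (e.g. quintic) coefficients are not reproduced here, so treat this only as an illustration of the idea.

```python
# Classical cubic Newton-Schulz iteration: drives the singular values of a
# matrix toward 1, i.e. orthogonalizes it. Pure-Python matrices (lists of
# rows) keep the sketch dependency-free; the PR's batched variant and its
# coefficients are not reproduced here.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def newton_schulz_orth(x, steps=12):
    """Iterate Y <- 0.5 * Y @ (3*I - Y^T @ Y) after Frobenius normalization.

    Normalizing by the Frobenius norm puts the spectral norm below 1, which
    guarantees convergence of the cubic iteration to the orthogonal factor.
    """
    fro = sum(v * v for row in x for v in row) ** 0.5
    y = [[v / fro for v in row] for row in x]
    n = len(y[0])
    for _ in range(steps):
        yty = matmul(transpose(y), y)
        inner = [[(3.0 if i == j else 0.0) - yty[i][j] for j in range(n)]
                 for i in range(n)]
        y = [[0.5 * v for v in row] for row in matmul(y, inner)]
    return y

if __name__ == "__main__":
    y = newton_schulz_orth([[2.0, 0.0], [0.0, 1.0]])
    print(matmul(transpose(y), y))  # close to the 2x2 identity
```

After a dozen iterations the product `Y^T Y` is close to the identity, which is exactly the property the Muon update wants from its orthogonalized momentum.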
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
3758-3768: Validate `muon_mode` values at argcheck time.

`muon_mode` is a free-form `str` here, so typos pass schema normalization and fail later during optimizer construction. Consider constraining the accepted values to `{"2d", "flat", "slice"}` in this layer for earlier, clearer errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 3758 - 3768, The muon_mode argument in the argcheck schema (the "muon_mode" param definition) is currently an unconstrained str which lets typos slip through; change the schema to restrict allowed values to the set {"2d", "flat", "slice"} (e.g. use an enum/choices validator or an explicit check) so validation fails early with a clear message referencing muon_mode when an invalid value is provided.
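The eager check the reviewer asks for can be sketched as follows. The allowed set comes from the review comment; the actual dargs-based schema wiring in `deepmd/utils/argcheck.py` is not reproduced, so the helper name and call site are illustrative.

```python
# Sketch of eager muon_mode validation: reject unknown values at
# config-check time instead of at optimizer construction. The allowed set
# follows the review; the dargs schema hook itself is not shown here.

ALLOWED_MUON_MODES = frozenset({"2d", "flat", "slice"})

def check_muon_mode(value: str) -> str:
    """Return value unchanged, or raise a clear error naming muon_mode."""
    if value not in ALLOWED_MUON_MODES:
        raise ValueError(
            f"muon_mode must be one of {sorted(ALLOWED_MUON_MODES)}, "
            f"got {value!r}"
        )
    return value
```

A typo such as `"2D"` then fails during config normalization with a message that names `muon_mode`, rather than surfacing later as an optimizer-construction error.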
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 343-345: Define specific exception classes (e.g.,
InvalidTensorShapeError(ValueError) and InvalidMuonModeError(ValueError)) near
the top of the module with the full explanatory messages as their default
docstring/message, then replace the three inline multi-line ValueError raises
with simple raises of those classes: replace the shape check in
batched_newton_schulz (the current raise ValueError(... "Batched Newton-Schulz
expects a 3D tensor...")) with raise InvalidTensorShapeError, and replace both
muon_mode validation raises (the f-string multi-line and the single-line check)
with raise InvalidMuonModeError; run ruff check/format before committing.
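The refactor suggested in this comment can be sketched as below. The class names `InvalidTensorShapeError` and `InvalidMuonModeError` are the reviewer's suggestions; the default message text is illustrative, not copied from the module.

```python
# Sketch of the suggested refactor: module-level exception classes carrying
# their explanatory messages, so call sites collapse to one-line raises.
# Class names follow the review; message wording is illustrative.

class InvalidTensorShapeError(ValueError):
    def __init__(self, message=None):
        super().__init__(
            message
            or "Batched Newton-Schulz expects a 3D tensor of shape "
            "(batch, rows, cols)."
        )

class InvalidMuonModeError(ValueError):
    def __init__(self, mode=None):
        detail = f", got {mode!r}" if mode is not None else ""
        super().__init__(
            f"muon_mode must be one of '2d', 'flat', 'slice'{detail}"
        )

# Call sites then become:
#     raise InvalidTensorShapeError
#     raise InvalidMuonModeError(mode)
```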
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The assertion comparing optimizer state uses exact float
equality which is fragile on CUDA; update the torch.allclose call for
model1.adam_scale vs model2.adam_scale to use a small nonzero tolerance (e.g.
atol=1e-6 and/or rtol=1e-6) instead of atol=0.0, rtol=0.0 so the test checks
near-equality while remaining stable; locate the comparison around
model1.adam_scale in the test_hybrid_muon.py and replace the zero tolerances
with a tight positive tolerance.
---
Nitpick comments:
In `@deepmd/utils/argcheck.py`:
- Around line 3758-3768: The muon_mode argument in the argcheck schema (the
"muon_mode" param definition) is currently an unconstrained str which lets typos
slip through; change the schema to restrict allowed values to the set {"2d",
"flat", "slice"} (e.g. use an enum/choices validator or an explicit check) so
validation fails early with a clear message referencing muon_mode when an
invalid value is provided.
ℹ️ Review info
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
Pull request overview
This PR refactors the PyTorch HybridMuonOptimizer to use name-based routing, adds a new muon_mode routing scheme (including per-slice Muon for higher-rank tensors), and introduces optional “Magma-lite” damping applied only on the Muon update path. It also updates training/config plumbing and expands tests to cover the new routing and damping behavior.
Changes:
- Replace `muon_2d_only`/`min_2d_dim` routing with `muon_mode` (2d/flat/slice) and parameter-name-based routing rules.
- Add `magma_muon` option implementing per-block momentum/gradient alignment scoring and damping on Muon updates.
- Update training arg schema + trainer optimizer construction; expand unit tests for slice-mode routing and Magma damping.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| deepmd/pt/optimizer/hybrid_muon.py | Implements muon_mode routing, name-based Adam/AdamW routing, batched NS for slice mode, and Magma-lite damping. |
| deepmd/pt/train/training.py | Wires new optimizer args (muon_mode, magma_muon) and passes named parameters for name-based routing. |
| deepmd/utils/argcheck.py | Updates the training config schema/docs for HybridMuon to use muon_mode and adds magma_muon. |
| source/tests/pt/test_hybrid_muon.py | Removes outdated tests and adds new coverage for slice routing, 2d routing behavior, and Magma damping state/range. |
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```
@@ Coverage Diff @@
##           master    #5275      +/-   ##
==========================================
+ Coverage   82.28%   82.29%    +0.01%
==========================================
  Files         773      773
  Lines       77331    77411       +80
  Branches     3659     3659
==========================================
+ Hits        63631    63709       +78
- Misses      12529    12531        +2
  Partials     1171     1171
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
♻️ Duplicate comments (1)
source/tests/pt/test_hybrid_muon.py (1)
358-361: ⚠️ Potential issue | 🟡 Minor

Avoid exact float equality for the Adam-path invariance assertion.

`atol=0.0, rtol=0.0` is still brittle on CUDA. A very small tolerance keeps the intent while avoiding flaky failures.

💡 Suggested test tweak

```diff
 self.assertTrue(
-    torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
+    torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6)
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_hybrid_muon.py` around lines 358 - 361, The test currently asserts exact equality on model1.adam_scale vs model2.adam_scale using atol=0.0, rtol=0.0 which is brittle on CUDA; update the assertion in test_hybrid_muon.py to allow a tiny tolerance (e.g. atol=1e-6 or rtol=1e-6) when comparing torch.allclose(model1.adam_scale, model2.adam_scale) so the Adam-path invariance intent remains but avoids flaky failures on GPU.
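Why zero tolerances amount to bit-exact comparison can be shown without torch. `torch.allclose` accepts `|a - b| <= atol + rtol * |b|`, and the stdlib's `math.isclose` uses the similar `|a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)`; in both, zero tolerances degenerate to `a == b`:

```python
# With zero tolerances, "close" collapses to exact equality, which a one-ULP
# discrepancy (common across CPU/CUDA kernels) already violates. Shown with
# stdlib math.isclose to keep the sketch torch-free.
import math

a = 0.1 + 0.2          # 0.30000000000000004 in binary floating point
b = 0.3

# Zero tolerances: fails on the one-ULP discrepancy.
exact = math.isclose(a, b, rel_tol=0.0, abs_tol=0.0)

# Tiny positive tolerances: accepts near-equality, as the review suggests.
near = math.isclose(a, b, rel_tol=1e-6, abs_tol=1e-7)

print(exact, near)  # False True
```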
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
2950-2972: Update the HybridMuon option docs to match the new name-based routes.

The help text here still describes `adam_beta1`/`adam_beta2` as 1D-only and `weight_decay` as Muon-only, but the optimizer now applies those settings to explicit `adam_`/`adamw_` routes too. Right now the generated config docs under-document the new behavior for higher-rank parameters.

💡 Suggested doc fix

```diff
 Argument(
     "adam_beta1",
     float,
     optional=True,
     default=0.9,
     doc=doc_only_pt_supported
-    + "Adam beta1 coefficient for 1D parameters (biases, norms).",
+    + "Adam beta1 coefficient for Adam-routed parameters "
+    + "(1D params and explicit `adam_` / `adamw_` routes).",
 ),
 Argument(
     "adam_beta2",
     float,
     optional=True,
     default=0.95,
     doc=doc_only_pt_supported
-    + "Adam beta2 coefficient for 1D parameters (biases, norms).",
+    + "Adam beta2 coefficient for Adam-routed parameters "
+    + "(1D params and explicit `adam_` / `adamw_` routes).",
 ),
 Argument(
     "weight_decay",
     float,
     optional=True,
     default=0.001,
     doc=doc_only_pt_supported
-    + "Weight decay coefficient. Applied only to Muon-routed parameters",
+    + "Weight decay coefficient. Applied to Muon-routed parameters "
+    + "and `adamw_`-routed parameters.",
 ),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 2950 - 2972, The docs for the Argument entries adam_beta1, adam_beta2, and weight_decay are stale: update their doc strings so they no longer claim the settings apply only to 1D parameters or only to Muon-routed params; instead state that these values are applied to explicit name-based routes (e.g., parameters routed by prefixes like "adam_" and "adamw_") as well as the prior special cases. Modify the doc text concatenated with doc_only_pt_supported in the Argument(...) calls for "adam_beta1", "adam_beta2", and "weight_decay" to mention name-based routes (adam_/adamw_) and that the optimizer also applies these settings to higher-rank parameters when routed by name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 1027-1031: The paired loops using zip over
adam_no_decay_exp_avgs/adam_no_decay_grads_fp32 and
adam_no_decay_exp_avg_sqs/grad_sq should use strict=True to ensure lengths
remain aligned; update the two zip(...) calls in the function that computes
exponential moving averages (the blocks that call ea.lerp_(...) and
eas.lerp_(...)) to zip(..., strict=True) so Ruff B905 is satisfied and any
length drift raises immediately.
---
Duplicate comments:
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The test currently asserts exact equality on
model1.adam_scale vs model2.adam_scale using atol=0.0, rtol=0.0 which is brittle
on CUDA; update the assertion in test_hybrid_muon.py to allow a tiny tolerance
(e.g. atol=1e-6 or rtol=1e-6) when comparing torch.allclose(model1.adam_scale,
model2.adam_scale) so the Adam-path invariance intent remains but avoids flaky
failures on GPU.
---
Nitpick comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2950-2972: The docs for the Argument entries adam_beta1,
adam_beta2, and weight_decay are stale: update their doc strings so they no
longer claim the settings apply only to 1D parameters or only to Muon-routed
params; instead state that these values are applied to explicit name-based
routes (e.g., parameters routed by prefixes like "adam_" and "adamw_") as well
as the prior special cases. Modify the doc text concatenated with
doc_only_pt_supported in the Argument(...) calls for "adam_beta1", "adam_beta2",
and "weight_decay" to mention name-based routes (adam_/adamw_) and that the
optimizer also applies these settings to higher-rank parameters when routed by
name.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 5f87feaa-db3d-471d-b82f-b99922b5aab4
📒 Files selected for processing (4)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
- Implement block-wise momentum-gradient alignment with EMA smoothing and soft scaling [0.1, 1.0] on Muon updates (magma_muon option)
- Fix AdamW weight decay to use adam_lr instead of base lr
- Wire magma_muon through training config and argcheck
- Clean up redundant optimizer tests
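One plausible reading of the magma_muon description above is sketched below: score momentum/gradient agreement per block with cosine alignment, smooth it with an EMA, and map the result into a [0.1, 1.0] damping scale. The affine mapping and decay constant are illustrative assumptions, not the PR's exact formula.

```python
# Hedged sketch of the magma_muon idea: cosine alignment between the
# momentum buffer and the current gradient, EMA-smoothed, then mapped into
# a [0.1, 1.0] scale applied to the Muon update. Mapping and decay are
# illustrative; the PR's exact formula is not reproduced here.
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0

def magma_scale(momentum, grad, ema_state, decay=0.9):
    """Return (scale in [0.1, 1.0], updated EMA of the alignment)."""
    align = cosine(momentum, grad)      # in [-1, 1]
    ema = decay * ema_state + (1.0 - decay) * align
    scale = 0.55 + 0.45 * ema           # affine map: [-1, 1] -> [0.1, 1.0]
    return scale, ema
```

Well-aligned blocks keep (close to) the full step, while persistently anti-aligned blocks are damped toward the 0.1 floor.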
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)
2924-3025: ⚠️ Potential issue | 🟠 Major

Keep a migration path for the removed HybridMuon keys.

`normalize(..., strict=True)` in this module rejects unknown optimizer fields, so dropping `muon_2d_only`/`min_2d_dim` here turns older HybridMuon configs into hard schema failures before training starts. Please translate them to `muon_mode` or raise a targeted deprecation error instead of relying on the generic unknown-key failure.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/utils/argcheck.py` around lines 2924 - 3025, The optimizer schema drop of legacy keys muon_2d_only and min_2d_dim causes strict normalize(..., strict=True) to reject older configs; update the HybridMuon argument handling in optimizer_hybrid_muon to accept these deprecated keys, translate them into the new muon_mode semantics (e.g., muon_2d_only=True -> muon_mode='2d'; use min_2d_dim as a threshold that maps to choosing 'flat' vs 'slice' behavior or document as ignored) before validation, and emit a clear deprecation warning (or raise a targeted DeprecationError) instead of letting normalize raise an unknown-key error so existing configs continue to work while users migrate.
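A migration shim along the lines this comment asks for might look as follows. The `muon_2d_only` → `"2d"` mapping follows the review's suggestion; treating `min_2d_dim` as ignored is one of the two options the review offers, picked here for illustration.

```python
# Sketch of a pre-validation shim: translate removed HybridMuon keys to the
# new muon_mode semantics and warn, so strict normalization no longer
# hard-fails old configs. Mapping choices follow the review's suggestions.
import warnings

def migrate_hybrid_muon_config(cfg: dict) -> dict:
    cfg = dict(cfg)  # do not mutate the caller's config
    if cfg.pop("muon_2d_only", False):
        warnings.warn(
            "muon_2d_only is deprecated; use muon_mode='2d' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        cfg.setdefault("muon_mode", "2d")
    if "min_2d_dim" in cfg:
        warnings.warn(
            "min_2d_dim is deprecated and ignored; routing is now "
            "controlled by muon_mode.",
            DeprecationWarning,
            stacklevel=2,
        )
        cfg.pop("min_2d_dim")
    return cfg
```

Running the shim before `normalize(..., strict=True)` lets old configs keep working while users migrate, instead of failing with a generic unknown-key error.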
🧹 Nitpick comments (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
576-576: Pass `named_parameters` in the public example.

`_param_name_map` stays empty unless `named_parameters` is provided, so the new `adam_`/`adamw_` routing rules are disabled in the example as written. That makes the advertised name-based behavior easy to miss for direct `HybridMuonOptimizer` users.

📘 Suggested doc fix

```diff
- >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=5e-4)
+ >>> optimizer = HybridMuonOptimizer(
+ ...     model.parameters(),
+ ...     lr=5e-4,
+ ...     named_parameters=tuple(model.named_parameters()),
+ ... )
```

Also applies to: 617-621
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/optimizer/hybrid_muon.py` at line 576, The example constructs HybridMuonOptimizer with model.parameters(), so _param_name_map remains empty and the name-based routing for adam_/adamw_ never activates; update the example to pass model.named_parameters() to HybridMuonOptimizer (and likewise in the other example at lines 617-621) so the optimizer can populate _param_name_map and enable the adam_/adamw_ name-based routing behavior.
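The name-based routing this thread keeps referring to can be sketched as a simple prefix check on the parameter's leaf name. The helper name mirrors the review's `get_adam_route`; the real optimizer's matching rules and name map may differ, so treat the prefix convention here as an assumption.

```python
# Sketch of name-based routing: parameters whose leaf name carries an
# "adam_" or "adamw_" prefix are routed away from Muon. Helper name follows
# the review; the actual matching rules in hybrid_muon.py may differ.
from typing import Optional

def get_adam_route(param_name: Optional[str]) -> Optional[str]:
    if param_name is None:
        return None  # no name available -> fall through to shape-based routing
    leaf = param_name.rsplit(".", 1)[-1]
    if leaf.startswith("adamw_"):
        return "adamw"
    if leaf.startswith("adam_"):
        return "adam"
    return None  # default: Muon handles it (subject to muon_mode/shape)

# The name map is why named_parameters matters: built from
#     {id(p): name for name, p in model.named_parameters()}
# it is empty when only model.parameters() is passed, so these routes
# never activate.
```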
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 905-930: The routing of parameters between Adam/AdamW and Muon can
change with muon_mode and names; update load_state_dict() to detect when a saved
parameter's route (computed via get_adam_route(param_name),
get_effective_shape(p.shape) and get_matrix_view_shape(effective_shape,
muon_mode) using the same _param_name_map and muon_mode) differs from the route
implied by the incoming state for that parameter (the presence of Adam/AdamW
keys vs Muon keys) and either (A) migrate compatible optimizer state into the
new structure (map saved Adam/AdamW moment tensors into the Muon per-matrix
buffers or vice-versa using param id/name to locate entries and preserving
moments), or (B) fail fast with a clear exception listing the affected
param_names so a user can reinitialize checkpoints; implement one of these
behaviors in load_state_dict(), using the existing grouping logic that produces
adam_no_decay, adam_decay and Muon matrix entries to make the comparison and
migration decisions.
In `@deepmd/utils/argcheck.py`:
- Around line 2974-2982: The default lr_adjust=0.0 in the arg definition
silently changes runtime behavior; change the default to preserve prior behavior
by making lr_adjust optional/None (or explicitly keep the previous default
value) and update the normalization/forwarding logic so training code and
HybridMuonOptimizer only change behavior if lr_adjust is explicitly set;
specifically modify the lr_adjust parameter definition in argcheck (the
lr_adjust symbol) to default to None (or the previous numeric default) and
adjust the code paths in the normalization/forwarding logic and in
HybridMuonOptimizer usage so they check for None/absence and retain legacy
scaling/Adam behavior unless the user explicitly supplies lr_adjust, and update
the docstring to call out the opt-in nature and deprecation path.
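The opt-in default pattern recommended for `lr_adjust` can be sketched in a few lines. The function and parameter names are illustrative, not the PR's API; the point is that `None` preserves legacy behavior and only an explicit value changes the scaling.

```python
# Sketch of the opt-in default: lr_adjust=None keeps legacy behavior, and
# the adjustment only applies when the user explicitly supplies a value.
# Names here are illustrative, not the PR's actual API.
from typing import Optional

def effective_adam_lr(base_lr: float, lr_adjust: Optional[float] = None) -> float:
    if lr_adjust is None:
        return base_lr          # legacy behavior preserved by default
    return base_lr * lr_adjust  # explicit opt-in changes the scaling
```

Compared with a silent `default=0.0`, this makes the behavior change visible in the config and gives a clean deprecation path.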
---
Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2924-3025: The optimizer schema drop of legacy keys muon_2d_only
and min_2d_dim causes strict normalize(..., strict=True) to reject older
configs; update the HybridMuon argument handling in optimizer_hybrid_muon to
accept these deprecated keys, translate them into the new muon_mode semantics
(e.g., muon_2d_only=True -> muon_mode='2d'; use min_2d_dim as a threshold that
maps to choosing 'flat' vs 'slice' behavior or document as ignored) before
validation, and emit a clear deprecation warning (or raise a targeted
DeprecationError) instead of letting normalize raise an unknown-key error so
existing configs continue to work while users migrate.
---
Nitpick comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Line 576: The example constructs HybridMuonOptimizer with model.parameters(),
so _param_name_map remains empty and the name-based routing for adam_/adamw_
never activates; update the example to pass model.named_parameters() to
HybridMuonOptimizer (and likewise in the other example at lines 617-621) so the
optimizer can populate _param_name_map and enable the adam_/adamw_ name-based
routing behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 8b903a49-a213-42c5-8ce2-794af9f1d3fb
📒 Files selected for processing (4)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
Summary by CodeRabbit

- New Features
- Documentation
- Tests