
refactor(pt): fully refactor of HybridMuon optimizer#5275

Open
OutisLi wants to merge 4 commits into deepmodeling:master from OutisLi:pr/muon

Conversation

@OutisLi
Collaborator

@OutisLi OutisLi commented Mar 1, 2026

  1. Refactor to name-based parameter routing
  2. Add a slice mode to the HybridMuon optimizer
  3. Add Magma-lite damping on the Muon path

Summary by CodeRabbit

  • New Features

    • HybridMuon gains routing modes (slice, 2d, flat), name-aware routing for biases/Adam variants, and a magma_muon option for Magma-lite damping. Optimizer now accepts named parameters; deprecated 2D-only options removed.
  • Documentation

    • Updated optimizer docs to describe new routing modes, magma_muon and flash_muon options, and adjusted lr_adjust default.
  • Tests

    • Expanded tests for routing modes, Magma damping, and state compatibility; some legacy tests consolidated.

Copilot AI review requested due to automatic review settings March 1, 2026 02:25
@github-actions github-actions bot added the Python label Mar 1, 2026
@dosubot dosubot bot added the enhancement label Mar 1, 2026
@coderabbitai
Contributor

coderabbitai bot commented Mar 1, 2026

📝 Walkthrough

Reworks HybridMuonOptimizer routing to name-based decisions and a muon_mode ("2d"/"flat"/"slice"), adds optional Magma-lite damping (magma_muon), introduces batched Newton–Schulz orthogonalization and new helper APIs, and updates training wiring, arg parsing, and tests for the new routing and Magma behavior.

Changes

Cohort / File(s) Summary
Hybrid Muon Core Implementation
deepmd/pt/optimizer/hybrid_muon.py
Replaced muon_2d_only/min_2d_dim with name-based routing and muon_mode; added named_parameters, magma_muon; new helpers get_adam_route, get_effective_shape, get_matrix_view_shape; added Magma-lite damping (_compute_magma_scale*), batched Newton–Schulz orth (_batched_newton_schulz_orth), and routing/state changes. Updated init and step typing.
Training Integration
deepmd/pt/train/training.py
Propagates new HybridMuon options (muon_mode, named_parameters, magma_muon, flash_muon) into optimizer construction; removed old muon_2d_only/min_2d_dim usage.
Configuration & Argument Validation
deepmd/utils/argcheck.py
Optimizer registration updated: removed muon_2d_only/min_2d_dim, added muon_mode, flash_muon, magma_muon; adjusted lr_adjust default and docs to reflect new routing semantics.
Test Suite
source/tests/pt/test_hybrid_muon.py
Updated tests to cover name/mode routing and Magma behavior: added MAGMA_MIN_SCALE import, removed obsolete tests, and added extensive tests for slice/2D/flat routing, per-slice Magma scoring/scales, and state-dict compatibility.

Sequence Diagram

sequenceDiagram
    actor User
    participant Optimizer as HybridMuonOptimizer
    participant Router as Routing Logic
    participant Shape as Shape Analysis
    participant Magma as Magma Scaler
    participant Step as Optimizer Step

    User->>Optimizer: step(closure)
    Optimizer->>Router: get_adam_route(param_name)
    Router-->>Optimizer: route (Adam / AdamW / Muon)

    alt Muon route
        Optimizer->>Shape: get_effective_shape -> get_matrix_view_shape(muon_mode)
        Shape-->>Optimizer: matrix/view dims
        alt magma_muon enabled
            Optimizer->>Magma: _compute_magma_scales_for_bucket(...)
            Magma-->>Optimizer: damping scales
        end
        Optimizer->>Step: _batched_newton_schulz_orth -> apply muon updates
        Step-->>Optimizer: updated params/state
    else Adam/AdamW route
        Optimizer->>Step: apply Adam / AdamW update
        Step-->>Optimizer: updated params/state
    end

    Optimizer-->>User: return loss
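The walkthrough mentions a batched Newton-Schulz orthogonalization step on the Muon path. A minimal NumPy sketch of the idea, using the quintic iteration coefficients popularized by the original Muon work (whether this PR uses these exact coefficients or step count is an assumption):

```python
import numpy as np


def batched_newton_schulz_orth(G: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Sketch of batched Newton-Schulz orthogonalization.

    G has shape (batch, m, n); each slice is driven toward an orthogonal
    matrix with roughly the same "direction" as the input. Five steps of
    the quintic iteration land singular values near 1 without an SVD.
    """
    if G.ndim != 3:
        raise ValueError("batched Newton-Schulz expects a 3D tensor (batch, m, n)")
    a, b, c = 3.4445, -4.7750, 2.0315
    # normalize each slice so the iteration is in its convergent regime
    X = G / (np.linalg.norm(G, axis=(1, 2), keepdims=True) + eps)
    transposed = X.shape[1] > X.shape[2]
    if transposed:  # iterate on the wide orientation for a smaller Gram matrix
        X = X.transpose(0, 2, 1)
    for _ in range(steps):
        A = X @ X.transpose(0, 2, 1)
        X = a * X + (b * A + c * (A @ A)) @ X
    if transposed:
        X = X.transpose(0, 2, 1)
    return X
```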

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

breaking change

Suggested reviewers

  • wanghan-iapcm
  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: coverage is 59.46%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: the title accurately reflects the main objective of the changeset, a comprehensive refactor of the HybridMuon optimizer with name-based routing, slice mode, and Magma-lite damping enhancements.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)

3758-3768: Validate muon_mode values at argcheck time.

muon_mode is free-form str here, so typos pass schema normalization and fail later during optimizer construction. Consider constraining accepted values to {"2d", "flat", "slice"} in this layer for earlier, clearer errors.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/utils/argcheck.py` around lines 3758 - 3768, The muon_mode argument in
the argcheck schema (the "muon_mode" param definition) is currently an
unconstrained str which lets typos slip through; change the schema to restrict
allowed values to the set {"2d", "flat", "slice"} (e.g. use an enum/choices
validator or an explicit check) so validation fails early with a clear message
referencing muon_mode when an invalid value is provided.
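For illustration, the early validation this comment asks for can be as simple as an explicit membership check run at schema time, before the optimizer is constructed; the names below are hypothetical:

```python
VALID_MUON_MODES = ("2d", "flat", "slice")


def check_muon_mode(value: str) -> str:
    """Hypothetical early validator for muon_mode (names are illustrative).

    Raising here surfaces typos at config-normalization time instead of
    deep inside optimizer construction.
    """
    if value not in VALID_MUON_MODES:
        raise ValueError(
            f"muon_mode must be one of {VALID_MUON_MODES}, got {value!r}"
        )
    return value
```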
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 343-345: Define specific exception classes (e.g.,
InvalidTensorShapeError(ValueError) and InvalidMuonModeError(ValueError)) near
the top of the module with the full explanatory messages as their default
docstring/message, then replace the three inline multi-line ValueError raises
with simple raises of those classes: replace the shape check in
batched_newton_schulz (the current raise ValueError(... "Batched Newton-Schulz
expects a 3D tensor...")) with raise InvalidTensorShapeError, and replace both
muon_mode validation raises (the f-string multi-line and the single-line check)
with raise InvalidMuonModeError; run ruff check/format before committing.

In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 358-361: The assertion comparing optimizer state uses exact float
equality which is fragile on CUDA; update the torch.allclose call for
model1.adam_scale vs model2.adam_scale to use a small nonzero tolerance (e.g.
atol=1e-6 and/or rtol=1e-6) instead of atol=0.0, rtol=0.0 so the test checks
near-equality while remaining stable; locate the comparison around
model1.adam_scale in the test_hybrid_muon.py and replace the zero tolerances
with a tight positive tolerance.


ℹ️ Review info

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0a966b and 52b027f.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the PyTorch HybridMuonOptimizer to use name-based routing, adds a new muon_mode routing scheme (including per-slice Muon for higher-rank tensors), and introduces optional “Magma-lite” damping applied only on the Muon update path. It also updates training/config plumbing and expands tests to cover the new routing and damping behavior.

Changes:

  • Replace muon_2d_only / min_2d_dim routing with muon_mode (2d / flat / slice) and parameter-name-based routing rules.
  • Add magma_muon option implementing per-block momentum/gradient alignment scoring and damping on Muon updates.
  • Update training arg schema + trainer optimizer construction; expand unit tests for slice-mode routing and Magma damping.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
deepmd/pt/optimizer/hybrid_muon.py Implements muon_mode routing, name-based Adam/AdamW routing, batched NS for slice mode, and Magma-lite damping.
deepmd/pt/train/training.py Wires new optimizer args (muon_mode, magma_muon) and passes named parameters for name-based routing.
deepmd/utils/argcheck.py Updates the training config schema/docs for HybridMuon to use muon_mode and adds magma_muon.
source/tests/pt/test_hybrid_muon.py Removes outdated tests and adds new coverage for slice routing, 2d routing behavior, and Magma damping state/range.


@codecov

codecov bot commented Mar 1, 2026

Codecov Report

❌ Patch coverage is 79.51220% with 42 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.29%. Comparing base (dabb0ca) to head (a2c52c6).

Files with missing lines:
  • deepmd/pt/optimizer/hybrid_muon.py: patch 79.41%, 42 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5275      +/-   ##
==========================================
+ Coverage   82.28%   82.29%   +0.01%     
==========================================
  Files         773      773              
  Lines       77331    77411      +80     
  Branches     3659     3659              
==========================================
+ Hits        63631    63709      +78     
- Misses      12529    12531       +2     
  Partials     1171     1171              


@OutisLi OutisLi marked this pull request as draft March 1, 2026 03:52
@OutisLi OutisLi marked this pull request as ready for review March 8, 2026 07:05
@dosubot dosubot bot added the new feature label Mar 8, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
source/tests/pt/test_hybrid_muon.py (1)

358-361: ⚠️ Potential issue | 🟡 Minor

Avoid exact float equality for the Adam-path invariance assertion.

atol=0.0, rtol=0.0 is still brittle on CUDA. A very small tolerance keeps the intent while avoiding flaky failures.

💡 Suggested test tweak
         self.assertTrue(
-            torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
+            torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-6)
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/test_hybrid_muon.py` around lines 358 - 361, The test
currently asserts exact equality on model1.adam_scale vs model2.adam_scale using
atol=0.0, rtol=0.0 which is brittle on CUDA; update the assertion in
test_hybrid_muon.py to allow a tiny tolerance (e.g. atol=1e-6 or rtol=1e-6) when
comparing torch.allclose(model1.adam_scale, model2.adam_scale) so the Adam-path
invariance intent remains but avoids flaky failures on GPU.
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)

2950-2972: Update the HybridMuon option docs to match the new name-based routes.

The help text here still describes adam_beta1 / adam_beta2 as 1D-only and weight_decay as Muon-only, but the optimizer now applies those settings to explicit adam_ / adamw_ routes too. Right now the generated config docs under-document the new behavior for higher-rank parameters.

💡 Suggested doc fix
         Argument(
             "adam_beta1",
             float,
             optional=True,
             default=0.9,
             doc=doc_only_pt_supported
-            + "Adam beta1 coefficient for 1D parameters (biases, norms).",
+            + "Adam beta1 coefficient for Adam-routed parameters "
+            "(1D params and explicit `adam_` / `adamw_` routes).",
         ),
         Argument(
             "adam_beta2",
             float,
             optional=True,
             default=0.95,
             doc=doc_only_pt_supported
-            + "Adam beta2 coefficient for 1D parameters (biases, norms).",
+            + "Adam beta2 coefficient for Adam-routed parameters "
+            "(1D params and explicit `adam_` / `adamw_` routes).",
         ),
         Argument(
             "weight_decay",
             float,
             optional=True,
             default=0.001,
             doc=doc_only_pt_supported
-            + "Weight decay coefficient. Applied only to Muon-routed parameters",
+            + "Weight decay coefficient. Applied to Muon-routed parameters "
+            "and `adamw_`-routed parameters.",
         ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/utils/argcheck.py` around lines 2950 - 2972, The docs for the Argument
entries adam_beta1, adam_beta2, and weight_decay are stale: update their doc
strings so they no longer claim the settings apply only to 1D parameters or only
to Muon-routed params; instead state that these values are applied to explicit
name-based routes (e.g., parameters routed by prefixes like "adam_" and
"adamw_") as well as the prior special cases. Modify the doc text concatenated
with doc_only_pt_supported in the Argument(...) calls for "adam_beta1",
"adam_beta2", and "weight_decay" to mention name-based routes (adam_/adamw_) and
that the optimizer also applies these settings to higher-rank parameters when
routed by name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 1027-1031: The paired loops using zip over
adam_no_decay_exp_avgs/adam_no_decay_grads_fp32 and
adam_no_decay_exp_avg_sqs/grad_sq should use strict=True to ensure lengths
remain aligned; update the two zip(...) calls in the function that computes
exponential moving averages (the blocks that call ea.lerp_(...) and
eas.lerp_(...)) to zip(..., strict=True) so Ruff B905 is satisfied and any
length drift raises immediately.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5f87feaa-db3d-471d-b82f-b99922b5aab4

📥 Commits

Reviewing files that changed from the base of the PR and between 52b027f and c13ca6c.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py

OutisLi added 4 commits March 11, 2026 11:29
- Implement block-wise momentum-gradient alignment with EMA smoothing
  and soft scaling [0.1, 1.0] on Muon updates (magma_muon option)
- Fix AdamW weight decay to use adam_lr instead of base lr
- Wire magma_muon through training config and argcheck
- Clean up redundant optimizer tests
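The commit message above describes Magma-lite damping as block-wise momentum-gradient alignment with EMA smoothing and soft scaling into [0.1, 1.0]. A minimal sketch of that recipe, with hypothetical names and an assumed cosine-based alignment score (the actual scoring and mapping in _compute_magma_scale* may differ):

```python
import numpy as np

MAGMA_MIN_SCALE = 0.1  # lower bound stated in the commit message


def magma_scale(momentum: np.ndarray, grad: np.ndarray,
                ema_state: float, ema_beta: float = 0.9,
                eps: float = 1e-12) -> tuple[float, float]:
    """Hypothetical sketch of Magma-lite damping for one block.

    Scores momentum-gradient alignment via cosine similarity, smooths it
    with an EMA, and maps the smoothed score into [MAGMA_MIN_SCALE, 1.0].
    Returns (scale, new_ema_state).
    """
    cos = float(np.sum(momentum * grad) /
                (np.linalg.norm(momentum) * np.linalg.norm(grad) + eps))
    new_ema = ema_beta * ema_state + (1.0 - ema_beta) * cos
    # soft map: aligned (cos = 1) -> 1.0, anti-aligned (cos = -1) -> MAGMA_MIN_SCALE
    scale = MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * (new_ema + 1.0) / 2.0
    return scale, new_ema
```

The damping scale would then multiply the Muon update for that block, shrinking steps whose momentum disagrees with the fresh gradient.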
@OutisLi OutisLi changed the title refactor(pt): full refactor of HybridMuon optimizer refactor(pt): fully refactor of HybridMuon optimizer Mar 11, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)

2924-3025: ⚠️ Potential issue | 🟠 Major

Keep a migration path for the removed HybridMuon keys.

normalize(..., strict=True) in this module rejects unknown optimizer fields, so dropping muon_2d_only / min_2d_dim here turns older HybridMuon configs into hard schema failures before training starts. Please translate them to muon_mode or raise a targeted deprecation error instead of relying on the generic unknown-key failure.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/utils/argcheck.py` around lines 2924 - 3025, The optimizer schema drop
of legacy keys muon_2d_only and min_2d_dim causes strict normalize(...,
strict=True) to reject older configs; update the HybridMuon argument handling in
optimizer_hybrid_muon to accept these deprecated keys, translate them into the
new muon_mode semantics (e.g., muon_2d_only=True -> muon_mode='2d'; use
min_2d_dim as a threshold that maps to choosing 'flat' vs 'slice' behavior or
document as ignored) before validation, and emit a clear deprecation warning (or
raise a targeted DeprecationError) instead of letting normalize raise an
unknown-key error so existing configs continue to work while users migrate.
🧹 Nitpick comments (1)
deepmd/pt/optimizer/hybrid_muon.py (1)

576-576: Pass named_parameters in the public example.

_param_name_map stays empty unless named_parameters is provided, so the new adam_ / adamw_ routing rules are disabled in the example as written. That makes the advertised name-based behavior easy to miss for direct HybridMuonOptimizer users.

📘 Suggested doc fix
-    >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=5e-4)
+    >>> optimizer = HybridMuonOptimizer(
+    ...     model.parameters(),
+    ...     lr=5e-4,
+    ...     named_parameters=tuple(model.named_parameters()),
+    ... )

Also applies to: 617-621

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/optimizer/hybrid_muon.py` at line 576, The example constructs
HybridMuonOptimizer with model.parameters(), so _param_name_map remains empty
and the name-based routing for adam_/adamw_ never activates; update the example
to pass model.named_parameters() to HybridMuonOptimizer (and likewise in the
other example at lines 617-621) so the optimizer can populate _param_name_map
and enable the adam_/adamw_ name-based routing behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 905-930: The routing of parameters between Adam/AdamW and Muon can
change with muon_mode and names; update load_state_dict() to detect when a saved
parameter's route (computed via get_adam_route(param_name),
get_effective_shape(p.shape) and get_matrix_view_shape(effective_shape,
muon_mode) using the same _param_name_map and muon_mode) differs from the route
implied by the incoming state for that parameter (the presence of Adam/AdamW
keys vs Muon keys) and either (A) migrate compatible optimizer state into the
new structure (map saved Adam/AdamW moment tensors into the Muon per-matrix
buffers or vice-versa using param id/name to locate entries and preserving
moments), or (B) fail fast with a clear exception listing the affected
param_names so a user can reinitialize checkpoints; implement one of these
behaviors in load_state_dict(), using the existing grouping logic that produces
adam_no_decay, adam_decay and Muon matrix entries to make the comparison and
migration decisions.

In `@deepmd/utils/argcheck.py`:
- Around line 2974-2982: The default lr_adjust=0.0 in the arg definition
silently changes runtime behavior; change the default to preserve prior behavior
by making lr_adjust optional/None (or explicitly keep the previous default
value) and update the normalization/forwarding logic so training code and
HybridMuonOptimizer only change behavior if lr_adjust is explicitly set;
specifically modify the lr_adjust parameter definition in argcheck (the
lr_adjust symbol) to default to None (or the previous numeric default) and
adjust the code paths in the normalization/forwarding logic and in
HybridMuonOptimizer usage so they check for None/absence and retain legacy
scaling/Adam behavior unless the user explicitly supplies lr_adjust, and update
the docstring to call out the opt-in nature and deprecation path.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 8b903a49-a213-42c5-8ce2-794af9f1d3fb

📥 Commits

Reviewing files that changed from the base of the PR and between c13ca6c and a2c52c6.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py
