Skip to content

Conversation

@yaox12
Copy link
Member

@yaox12 yaox12 commented Jan 29, 2026

Description

  • Added a new score func sqrtsoftplus
  • Add tests
  • All tests are passing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

yaox12 and others added 5 commits January 29, 2026 10:28
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 self-assigned this Feb 6, 2026
@yaox12 yaox12 marked this pull request as ready for review February 6, 2026 06:41
@yaox12 yaox12 added the MoE label Feb 6, 2026
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR extends the fused MoE router to support a new score function, sqrtsoftplus, end-to-end across the CUDA kernels, C++/pybind extension layer, and the Python autograd wrappers.

Key integration points:

  • CUDA: Adds score_function==2 branches in fused_topk_with_score_function and fused_score_for_moe_aux_loss kernels, including normalization behavior for topk>1 and a dedicated sqrtsoftplus backward.
  • C++ extension: Extends the score_function_map to include sqrtsoftplus and updates validation to allow expert_bias with sigmoid and sqrtsoftplus.
  • Python: Updates router wrappers/docstrings and expands unit tests to compare fused vs PyTorch reference behavior for sqrtsoftplus.

Issue to address before merge:

  • transformer_engine/common/fused_router/utils.h has a misleading comment that says intermediate_output stores sqrtsoftplus output; in the new implementation, forward stores original logits for sqrtsoftplus and backward recomputes the activation output from those logits. This should be corrected to avoid future misuse.

Confidence Score: 4/5

  • This PR is mostly safe to merge, with one small correctness-of-documentation issue to fix in utils.h comments.
  • Kernel integration and tests cover the new sqrtsoftplus path, and validation/plumbing appear consistent across Python/C++/CUDA. The only verified issue is a misleading comment about what intermediate_output contains for sqrtsoftplus, which could cause future maintenance errors.
  • transformer_engine/common/fused_router/utils.h

Important Files Changed

Filename Overview
tests/pytorch/test_fused_router.py Refactors PyTorch reference helper and adds sqrtsoftplus routing/aux-loss coverage in tests.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Adds sqrtsoftplus (score_function==2) path in forward/backward kernels, including normalization and activation backward.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Extends fused topk router kernels to support sqrtsoftplus and allow expert_bias with sigmoid/sqrtsoftplus.
transformer_engine/common/fused_router/utils.h Implements sqrtsoftplus activation and backward helpers; contains a misleading comment about what intermediate_output stores.
transformer_engine/common/include/transformer_engine/fused_router.h Updates public API docs to include sqrtsoftplus (score_function==2) support.
transformer_engine/pytorch/csrc/extensions/pybind.cpp Updates pybind docstrings to reflect generalized score-function router APIs.
transformer_engine/pytorch/csrc/extensions/router.cpp Adds sqrtsoftplus to score_function_map and extends input validation for expert_bias/score_function combinations.
transformer_engine/pytorch/router.py Updates Python router wrappers/docstrings to allow sqrtsoftplus score_function.

Sequence Diagram

sequenceDiagram
    participant User as Python user
    participant RouterPy as transformer_engine/pytorch/router.py
    participant Tex as transformer_engine_torch (pybind)
    participant RouterCpp as pytorch/csrc/extensions/router.cpp
    participant Nvte as common/include/.../fused_router.h (nvte_*)
    participant CUDA as fused_router CUDA kernels

    User->>RouterPy: fused_topk_with_score_function(logits, topk, score_function="sqrtsoftplus", expert_bias?)
    RouterPy->>Tex: fused_topk_with_score_function_fwd(...)
    Tex->>RouterCpp: fused_topk_with_score_function_fwd(...)
    RouterCpp->>Nvte: nvte_fused_topk_with_score_function_forward(score_function=2)
    Nvte->>CUDA: forward kernel
    CUDA-->>CUDA: save intermediate_output = original logits (sqrtsoftplus)
    CUDA-->>CUDA: scores = sqrt(softplus(logits))
    CUDA-->>CUDA: (optional) add expert_bias for routing, revert after topk
    CUDA-->>CUDA: (topk>1) normalize selected scores by sum
    CUDA-->>Nvte: probs, routing_map, intermediate_output
    Nvte-->>RouterCpp: tensors
    RouterCpp-->>Tex: tensors
    Tex-->>RouterPy: probs, routing_map, intermediate_output

    User->>RouterPy: backward(grad_probs)
    RouterPy->>Tex: fused_topk_with_score_function_bwd(...)
    Tex->>RouterCpp: fused_topk_with_score_function_bwd(...)
    RouterCpp->>Nvte: nvte_fused_topk_with_score_function_backward(score_function=2)
    Nvte->>CUDA: backward kernel
    CUDA-->>CUDA: recompute sqrtsoftplus output from stored logits
    CUDA-->>CUDA: (topk>1) normalization backward
    CUDA-->>CUDA: apply sqrtsoftplus backward using stored logits
    CUDA-->>Nvte: grad_logits
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_router.h
Expert bias API mismatch

The header still documents expert_bias as "Only used at the sigmoid case" (line 27), but this PR enables expert_bias for sqrtsoftplus end-to-end (kernels handle score_function==2, and router.cpp allows it). Please update the API docs here (and any other public docs) to reflect that expert_bias is supported for sqrtsoftplus too, otherwise external callers will get contradictory guidance.

yaox12 and others added 2 commits February 6, 2026 06:55
Signed-off-by: Xin Yao <xiny@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Signed-off-by: Xin Yao <xiny@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines +115 to +117
// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x)))
// We store the sqrtsoftplus output (y) in intermediate_output for backward
template <typename DataType>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misleading intermediate_output comment

The comment says "We store the sqrtsoftplus output (y) in intermediate_output for backward", but the forward kernels store the original logits into intermediate_output for score_function==2 (e.g., fused_topk_with_score_function.cu stores scores[i] before applying sqrtsoftplus, and backward reads those logits as logits_buf). This mismatch is user-visible for anyone debugging and can lead to incorrect future edits. Please update the comment to reflect that intermediate_output holds original logits for sqrtsoftplus (and activation output is recomputed in backward).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant