-
Notifications
You must be signed in to change notification settings - Fork 632
[Common][PyTorch] Add a new score func sqrtsoftplus to the fused router
#2633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Xin Yao <xiny@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Greptile OverviewGreptile SummaryThis PR extends the fused MoE router to support a new score function, Key integration points:
Issue to address before merge:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User as Python user
participant RouterPy as transformer_engine/pytorch/router.py
participant Tex as transformer_engine_torch (pybind)
participant RouterCpp as pytorch/csrc/extensions/router.cpp
participant Nvte as common/include/.../fused_router.h (nvte_*)
participant CUDA as fused_router CUDA kernels
User->>RouterPy: fused_topk_with_score_function(logits, topk, score_function="sqrtsoftplus", expert_bias?)
RouterPy->>Tex: fused_topk_with_score_function_fwd(...)
Tex->>RouterCpp: fused_topk_with_score_function_fwd(...)
RouterCpp->>Nvte: nvte_fused_topk_with_score_function_forward(score_function=2)
Nvte->>CUDA: forward kernel
CUDA-->>CUDA: save intermediate_output = original logits (sqrtsoftplus)
CUDA-->>CUDA: scores = sqrt(softplus(logits))
CUDA-->>CUDA: (optional) add expert_bias for routing, revert after topk
CUDA-->>CUDA: (topk>1) normalize selected scores by sum
CUDA-->>Nvte: probs, routing_map, intermediate_output
Nvte-->>RouterCpp: tensors
RouterCpp-->>Tex: tensors
Tex-->>RouterPy: probs, routing_map, intermediate_output
User->>RouterPy: backward(grad_probs)
RouterPy->>Tex: fused_topk_with_score_function_bwd(...)
Tex->>RouterCpp: fused_topk_with_score_function_bwd(...)
RouterCpp->>Nvte: nvte_fused_topk_with_score_function_backward(score_function=2)
Nvte->>CUDA: backward kernel
CUDA-->>CUDA: recompute sqrtsoftplus output from stored logits
CUDA-->>CUDA: (topk>1) normalization backward
CUDA-->>CUDA: apply sqrtsoftplus backward using stored logits
CUDA-->>Nvte: grad_logits
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
7 files reviewed, 2 comments
Additional Comments (1)
The header still documents |
Signed-off-by: Xin Yao <xiny@nvidia.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4 files reviewed, 1 comment
Signed-off-by: Xin Yao <xiny@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
8 files reviewed, 1 comment
| // sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) | ||
| // We store the sqrtsoftplus output (y) in intermediate_output for backward | ||
| template <typename DataType> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Misleading intermediate_output comment
The comment says "We store the sqrtsoftplus output (y) in intermediate_output for backward", but the forward kernels store the original logits into intermediate_output for score_function==2 (e.g., fused_topk_with_score_function.cu stores scores[i] before applying sqrtsoftplus, and backward reads those logits as logits_buf). This mismatch is user-visible for anyone debugging and can lead to incorrect future edits. Please update the comment to reflect that intermediate_output holds original logits for sqrtsoftplus (and activation output is recomputed in backward).
Description
sqrtsoftplusType of change
Changes
Please list the changes introduced in this PR:
Checklist: