Add sigmoid GLU #2656
base: main
Conversation
Greptile Overview
Greptile Summary: This PR adds the original sigmoid-gated GLU activation across Transformer Engine's C++/CUDA core, plus PyTorch and JAX bindings. It introduces new CUDA entrypoints (nvte_glu and nvte_dglu). Main issue found: the new PyTorch fusible-op test adds an unreachable sigmoid branch.
Confidence Score: 4/5
Sequence Diagram
```mermaid
sequenceDiagram
participant User as User (PyTorch/JAX)
participant TE_PY as transformer_engine.pytorch (tex)
participant TE_JAX as transformer_engine_jax
participant ExtCPP as C++ extensions (activation.cpp)
participant Core as TE Core (activation.h)
participant CUDA as CUDA kernels (glu.cu)
alt PyTorch path
User->>TE_PY: tex.glu(input, quantizer)
TE_PY->>ExtCPP: glu(input, quantizer)
ExtCPP->>Core: nvte_glu(input_tensor, output_tensor, stream)
Core->>CUDA: gated_act_fn<sigmoid>(input, output)
CUDA-->>Core: output
Core-->>ExtCPP: output
ExtCPP-->>TE_PY: output (possibly quantized)
TE_PY-->>User: output
User->>TE_PY: tex.dglu(grad, fwd_input, quantizer)
TE_PY->>ExtCPP: dglu(grad, input, quantizer)
ExtCPP->>Core: nvte_dglu(grad_tensor, input_tensor, dx_tensor, stream)
Core->>CUDA: dgated_act_fn<sigmoid, dsigmoid>(grad, input, dx)
CUDA-->>Core: dx
Core-->>ExtCPP: dx
ExtCPP-->>TE_PY: dx
TE_PY-->>User: dx
else JAX FFI path
User->>TE_JAX: act_lu(x, activation_type=("sigmoid","linear"))
TE_JAX->>ExtCPP: ActLuFFI(act_enum=GLU)
ExtCPP->>Core: nvte_glu(input_tensor, output_tensor, stream)
Core->>CUDA: gated_act_fn<sigmoid>(input, output)
CUDA-->>Core: output
Core-->>ExtCPP: output
ExtCPP-->>TE_JAX: output buffers
TE_JAX-->>User: output
end
```
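To make the extension-level flow in the diagram concrete, here is a minimal sketch of calling the new entry points directly. The call signatures mirror the diagram above; the module name transformer_engine_torch and passing None for the quantizer (unquantized execution) are assumptions, not something this PR spells out.

```python
import torch
import transformer_engine_torch as tex  # compiled extension module; name assumed

# Gated activations expect a doubled last dimension: the two halves a and b.
x = torch.randn(32, 2 * 128, device="cuda", dtype=torch.bfloat16)

# Forward, per the diagram: tex.glu(input, quantizer). None is assumed to mean
# "no quantization"; the result has half the last dimension of the input.
y = tex.glu(x, None)

# Backward, per the diagram: tex.dglu(grad, fwd_input, quantizer) -> dx with
# the same shape as the forward input.
grad = torch.randn_like(y)
dx = tex.dglu(grad, x, None)
assert dx.shape == x.shape
```

In normal use none of this is called directly: per the PR description further down, users pass activation="glu" to LayerNormMLP or TransformerLayer, and the JAX path selects the same kernels through the ("sigmoid", "linear") activation-type tuple shown in the diagram.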
5 files reviewed, 3 comments
Additional Comments (3)
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
8adb19a to f7c6385
5 files reviewed, no comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
LGTM, pending CI
If we're adding GLU, it would also be natural to have sigmoid. Not critical though.
11 files reviewed, 4 comments
Additional Comments (3)
Fix: add explicit
The pybind enum export for
Fix: add
Fix: add the appropriate key (consistent with the existing conventions, e.g. how GEGLU/REGLU/SWIGLU are represented) mapping to
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 files reviewed, 2 comments
Additional Comments (1)
elif activation == "sigmoid":
    y_ref = torch.nn.functional.sigmoid(x_ref)
Sigmoid is not an option in the test; is this leftover code?
It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.
Could we also add a test for the regular LayerNormMLP module in test_numerics.py?
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
11 files reviewed, 2 comments
Additional Comments (2)
Also update the JAX activation dispatch switch to handle
The activation dispatch
Add a
/te-ci
3 files reviewed, 1 comment
Additional Comments (1)
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
5 files reviewed, no comments
Done. The latest commits address both points:
Tests:
JAX support:
/te-ci
16 files reviewed, 1 comment
elif activation == "sigmoid":
    y_ref = torch.nn.functional.sigmoid(x_ref)
Unreachable sigmoid branch
test_activation never parametrizes activation="sigmoid" (see the @pytest.mark.parametrize("activation", ...) list above), but this new elif activation == "sigmoid": branch was added anyway, so it can’t execute. This makes the test misleading/fragile (it looks like sigmoid is covered when it isn’t). Either add "sigmoid" to the parametrization or remove this branch.
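For illustration, a trimmed-down sketch of the first suggested fix; the parameter values and test body are placeholders rather than the actual contents of the fusible-op test.

```python
import pytest
import torch

# Hypothetical, simplified parametrization: adding "sigmoid" makes the
# reference branch quoted above reachable.
@pytest.mark.parametrize("activation", ("relu", "gelu", "sigmoid"))
def test_activation(activation):
    x_ref = torch.randn(16, 16)
    if activation == "relu":
        y_ref = torch.nn.functional.relu(x_ref)
    elif activation == "gelu":
        y_ref = torch.nn.functional.gelu(x_ref)
    elif activation == "sigmoid":
        y_ref = torch.nn.functional.sigmoid(x_ref)
    assert y_ref.shape == x_ref.shape
```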
Description
Add the original GLU (Gated Linear Unit) activation function as described in Dauphin et al. (2017) and referenced in Shazeer (2020), "GLU Variants Improve Transformer".
GLU is defined as:

$$\mathrm{GLU}([a, b]) = a \odot \sigma(b)$$

where $\sigma$ is the sigmoid function and the input is split into two halves $a$ and $b$ along the last dimension.
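As a worked illustration of the formula above, the following PyTorch sketch implements the forward and backward math and checks them against torch.nn.functional.glu and autograd. It is only a reference for the definition, not the Transformer Engine kernel; which half receives the sigmoid inside TE's gated-activation templates is an implementation detail not restated here, so the a ⊙ σ(b) split below simply follows the formula as written.

```python
import torch

def glu_ref(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves (a, b) and gate the linear
    # half with the sigmoid of the other: GLU([a, b]) = a * sigmoid(b).
    a, b = x.chunk(2, dim=-1)
    return a * torch.sigmoid(b)

def dglu_ref(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Gradient w.r.t. the full input, concatenated back to [..., 2*d]:
    #   d/da [a * sigmoid(b)] = sigmoid(b)
    #   d/db [a * sigmoid(b)] = a * sigmoid(b) * (1 - sigmoid(b))
    a, b = x.chunk(2, dim=-1)
    s = torch.sigmoid(b)
    return torch.cat([grad * s, grad * a * s * (1 - s)], dim=-1)

# Sanity check against PyTorch's built-in GLU and autograd.
x = torch.randn(4, 8, requires_grad=True)
y = glu_ref(x)
torch.testing.assert_close(y, torch.nn.functional.glu(x, dim=-1))
grad = torch.randn_like(y)
y.backward(grad)
torch.testing.assert_close(x.grad, dglu_ref(grad, x.detach()))
```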
Transformer Engine already supports several GLU variants (GEGLU, ReGLU, SReGLU, SwiGLU, etc.)
but was missing the original sigmoid-gated GLU. This PR fills that gap so that users can
simply pass
activation="glu"toLayerNormMLPorTransformerLayer.Type of change
Changes
- transformer_engine/common/activation/glu.cu (new file): CUDA kernels nvte_glu and nvte_dglu using existing sigmoid/dsigmoid primitives from math.h and the gated_act_fn/dgated_act_fn templates.
- transformer_engine/common/include/transformer_engine/activation.h: Added GLU to the NVTE_Activation_Type enum; declared nvte_glu and nvte_dglu with doxygen documentation.
- transformer_engine/common/CMakeLists.txt: Registered activation/glu.cu in both the arch_specific_sources and fast_math build lists.
- transformer_engine/pytorch/csrc/extensions/activation.cpp: Added glu() and dglu() C++ wrapper functions.
- transformer_engine/pytorch/csrc/extensions.h: Declared glu and dglu.
- transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposed tex.glu and tex.dglu to Python.
- transformer_engine/pytorch/module/layernorm_mlp.py: Added "glu" to _get_act_func_supported_list (all 3 recipe branches), the FC1 output-doubling condition, the ONNX export activation_map, and the docstring.
- transformer_engine/pytorch/ops/basic/activation.py: Added a GLU operation class with forward (tex.glu) and backward (tex.dglu).
- transformer_engine/pytorch/ops/basic/__init__.py: Exported GLU.
- transformer_engine/pytorch/transformer.py: Updated the TransformerLayer docstring to list 'glu' as a supported activation.
Checklist: