From feab8abe3b13fc523e9d6f3af876de98b26ca2c8 Mon Sep 17 00:00:00 2001 From: "Kim, Jin" Date: Fri, 6 Feb 2026 18:01:47 +0900 Subject: [PATCH 1/8] Add sigmoid GLU Signed-off-by: Kim, Jin --- transformer_engine/common/CMakeLists.txt | 2 ++ transformer_engine/common/activation/glu.cu | 24 ++++++++++++++ .../include/transformer_engine/activation.h | 27 +++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 5 +++ .../pytorch/csrc/extensions/activation.cpp | 8 +++++ .../pytorch/csrc/extensions/pybind.cpp | 6 ++++ .../pytorch/module/layernorm_mlp.py | 8 +++-- .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/activation.py | 33 +++++++++++++++++++ transformer_engine/pytorch/transformer.py | 2 +- 10 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 transformer_engine/common/activation/glu.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index efe958f844..ae6ddb1714 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_sources list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu cast/cast.cu @@ -352,6 +353,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu) endif() diff --git a/transformer_engine/common/activation/glu.cu b/transformer_engine/common/activation/glu.cu new file mode 100644 index 0000000000..45a6670672 --- /dev/null +++ b/transformer_engine/common/activation/glu.cu @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_glu); + using namespace transformer_engine; + Empty e = {}; + gated_act_fn>(input, output, e, stream); +} + +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_dglu); + using namespace transformer_engine; + Empty e = {}; + dgated_act_fn, dsigmoid>(grad, input, output, e, + stream); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 55cd44d9de..9c7f7407c2 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -31,6 +31,7 @@ extern "C" { enum class NVTE_Activation_Type { GELU, GEGLU, + GLU, SILU, SWIGLU, RELU, @@ -152,6 +153,32 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GLU (Gated Linear Unit) activation of the input. 
+ * GLU(a,b) = sigmoid(a) * b + * See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083) + * and "GLU Variants Improve Transformer" (arXiv:2002.05202). + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes sigmoid(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Computes the GLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..adbf62d1ac 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -163,6 +163,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = st * Activations **************************************************************************************************/ +/* GLU (sigmoid gate) */ +py::object glu(const at::Tensor &input, py::handle quantizer); + +py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + /* GELU and variants*/ py::object gelu(const at::Tensor &input, py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 9ea14e1af0..99b9c1fefa 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua return dactivation_helper(grad, input, quantizer); } +py::object glu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + py::object geglu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer, 2); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..99a45190cf 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); + /* GLU (sigmoid 
gate) */ + m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"), + py::arg("quantizer")); /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); @@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + /* Backward of GLU */ + m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..5337429733 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -98,6 +98,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, tex.dbias_drelu), @@ -136,6 +138,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -1665,7 +1668,7 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : dict, default = None Additional parameters for the activation function. 
@@ -1884,7 +1887,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: + if self.activation in ["geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -2308,6 +2311,7 @@ def _clamped_swiglu(x, limit, alpha): activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") * x.chunk(2, -1)[1], diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359c..e3a5a2587e 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -7,6 +7,7 @@ from .activation import ( GELU, GEGLU, + GLU, QGELU, QGEGLU, ReLU, diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9d54e12dba..1d13a09803 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -20,6 +20,7 @@ __all__ = [ "GELU", "GEGLU", + "GLU", "QGELU", "QGEGLU", "ReLU", @@ -164,6 +165,38 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dgelu(*args, **kwargs) +class GLU(_ActivationOperation): + r"""Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GLU}(a,b) = \sigma(a) * b + + where :math:`\sigma` is the sigmoid function. + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `Language Modeling with Gated Convolutional Networks`__ + and `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.glu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dglu(*args, **kwargs) + + class GEGLU(_ActivationOperation): r"""Gaussian Error Gated Linear Unit diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index fdb3869199..cf7ce5e1a4 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -184,7 +184,7 @@ class TransformerLayer(torch.nn.Module): if set to ``False``, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : Optional[dict], default = None Additional parameters for the activation function. 
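[Editorial note, not part of the patch series] For reference, the semantics wired up in this first patch can be summarized in a short PyTorch sketch. This is an illustrative sketch only: glu_ref and dglu_ref are hypothetical helper names, and the concatenated [da, db] layout of the backward is an assumption inferred from the dgated_act_fn<sigmoid, dsigmoid> instantiation above. It also demonstrates the convention difference called out in the GLU op docstring: Transformer Engine gates with the first half of the last dimension, while torch.nn.functional.glu gates with the second half.

import torch

def glu_ref(x: torch.Tensor) -> torch.Tensor:
    # Transformer Engine convention: gate = first half, linear unit = second half.
    a, b = x.chunk(2, dim=-1)
    return torch.sigmoid(a) * b

def dglu_ref(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Gradient of glu_ref w.r.t. its full input, given the incoming output gradient.
    # The [da, db] concatenation is assumed to match nvte_dglu's [N, 2*H] output layout.
    a, b = x.chunk(2, dim=-1)
    s = torch.sigmoid(a)
    da = grad * b * s * (1 - s)  # d/da of sigmoid(a) * b
    db = grad * s                # d/db of sigmoid(a) * b
    return torch.cat([da, db], dim=-1)

# Cross-check against PyTorch's GLU, which gates with the second half,
# so the two halves must be swapped before calling it.
x = torch.randn(4, 16)
a, b = x.chunk(2, dim=-1)
torch.testing.assert_close(glu_ref(x), torch.nn.functional.glu(torch.cat([b, a], dim=-1), dim=-1))

[End of editorial note]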
From f7c6385a9f777816c9ac5e9e0c95098488148bd2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 09:08:17 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Kim, Jin --- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5337429733..fa6dc2901d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1887,7 +1887,15 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: + if self.activation in [ + "geglu", + "glu", + "qgeglu", + "reglu", + "sreglu", + "swiglu", + "clamped_swiglu", + ]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition From c986894072f92d3881915ec2d528888a341e8f10 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 6 Feb 2026 18:40:47 +0000 Subject: [PATCH 3/8] Add test for GLU op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a23de29e02..9eda47e199 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1557,7 +1557,10 @@ def test_make_extra_output( @pytest.mark.parametrize( "activation", - ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), + ( + "gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", + "glu", "srelu", "sreglu", "silu", "swiglu", + ), ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @@ -1577,7 +1580,7 @@ def test_activation( # Tensor dimensions in_shape = list(out_shape) - if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): + if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"): in_shape[-1] *= 2 # Skip invalid configurations @@ -1617,6 +1620,13 @@ def test_activation( elif activation == "reglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "sigmoid": + y_ref = torch.nn.functional.sigmoid(x_ref) + elif activation == "glu": + x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1]) + x = x.flip(-2) # PyTorch GLU swaps gate and linear unit + x = x.reshape(in_shape) + y_ref = torch.nn.glu(x) elif activation == "srelu": y_ref = torch.nn.functional.relu(x_ref) ** 2 elif activation == "sreglu": From 4d0cfcd57936bc0629a79eaad4622ad473b5ba31 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:47:07 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9eda47e199..6c1cfd66ee 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1558,8 +1558,17 @@ def test_make_extra_output( @pytest.mark.parametrize( "activation", ( - "gelu", "geglu", 
"qgelu", "qgeglu", "relu", "reglu", - "glu", "srelu", "sreglu", "silu", "swiglu", + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "glu", + "srelu", + "sreglu", + "silu", + "swiglu", ), ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) From 19da2744918490214024ad7efe74262e5f3b0315 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 6 Feb 2026 11:15:40 -0800 Subject: [PATCH 5/8] Fix incorrect reshape Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 6c1cfd66ee..7fd1b9afe5 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1632,7 +1632,7 @@ def test_activation( elif activation == "sigmoid": y_ref = torch.nn.functional.sigmoid(x_ref) elif activation == "glu": - x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1]) + x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2) x = x.flip(-2) # PyTorch GLU swaps gate and linear unit x = x.reshape(in_shape) y_ref = torch.nn.glu(x) From e970a7faec3c1f5b2e4b6469bfcc0df136fa9f5d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:28:03 -0800 Subject: [PATCH 6/8] Apply suggestion from @timmoon10 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7fd1b9afe5..1fc16e1ab4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1635,7 +1635,7 @@ def test_activation( x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2) x = x.flip(-2) # PyTorch GLU swaps gate and linear unit x = x.reshape(in_shape) - y_ref = torch.nn.glu(x) + y_ref = torch.nn.functional.glu(x) elif activation == "srelu": y_ref = torch.nn.functional.relu(x_ref) ** 2 elif activation == "sreglu": From d7364587510ee89bcfef4f346200105be5b17e60 Mon Sep 17 00:00:00 2001 From: "Kim, Jin" Date: Mon, 9 Feb 2026 20:46:18 +0900 Subject: [PATCH 7/8] Add omitted tests for GLU op Signed-off-by: Kim, Jin --- tests/pytorch/test_fusible_ops.py | 1 + tests/pytorch/test_numerics.py | 1 + tests/pytorch/test_sanity.py | 1 + 3 files changed, 3 insertions(+) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1fc16e1ab4..7659e488c4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1655,6 +1655,7 @@ def test_activation( make_op = dict( gelu=te_ops.GELU, geglu=te_ops.GEGLU, + glu=te_ops.GLU, qgelu=te_ops.QGELU, qgeglu=te_ops.QGEGLU, relu=te_ops.ReLU, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..4fc610ddb9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -89,6 +89,7 @@ all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9d24c1a8e..69bf811d2e 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -113,6 +113,7 @@ def nvfp4_vanilla(): all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", From 3b3172bc5de0ea8219eaddd84de7adadcf1640d0 Mon Sep 17 00:00:00 2001 From: "Kim, Jin" Date: Mon, 9 Feb 2026 20:50:18 +0900 Subject: 
[PATCH 8/8] Add GLU activation type support in JAX extension Signed-off-by: Kim, Jin --- transformer_engine/jax/cpp_extensions/activation.py | 1 + transformer_engine/jax/csrc/extensions/activation.cpp | 6 ++++++ transformer_engine/jax/csrc/extensions/pybind.cpp | 1 + 3 files changed, 8 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 573603ef3a..8c0edae97e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -44,6 +44,7 @@ ActivationEnum = { ("gelu",): NVTE_Activation_Type.GELU, ("gelu", "linear"): NVTE_Activation_Type.GEGLU, + ("sigmoid", "linear"): NVTE_Activation_Type.GLU, ("silu",): NVTE_Activation_Type.SILU, ("silu", "linear"): NVTE_Activation_Type.SWIGLU, ("relu",): NVTE_Activation_Type.RELU, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 6c5a976344..ce5828d6f3 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::GEGLU: nvte_geglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_glu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SILU: nvte_silu(input_tensor.data(), output_tensor.data(), stream); break; @@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::GEGLU: nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a5986404c9..bd4b8fe2c2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -150,6 +150,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) .value("GELU", NVTE_Activation_Type::GELU) .value("GEGLU", NVTE_Activation_Type::GEGLU) + .value("GLU", NVTE_Activation_Type::GLU) .value("SILU", NVTE_Activation_Type::SILU) .value("SWIGLU", NVTE_Activation_Type::SWIGLU) .value("RELU", NVTE_Activation_Type::RELU)
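[Editorial note, not part of the patch series] The JAX patch only registers the new NVTE_Activation_Type.GLU value and its FFI dispatch; the activation itself is selected through the ("sigmoid", "linear") tuple added to ActivationEnum. Below is a minimal jax.numpy sketch of what that pair is expected to compute, following the nvte_glu contract documented in the first patch (input of shape [N, 2*H], output of shape [N, H]). The helper name glu_ref is hypothetical and used only for illustration.

import jax
import jax.numpy as jnp

def glu_ref(x: jnp.ndarray) -> jnp.ndarray:
    # nvte_glu contract: sigmoid of the first half gates the second half.
    a, b = jnp.split(x, 2, axis=-1)
    return jax.nn.sigmoid(a) * b

# The backward kernel (nvte_dglu) corresponds to the VJP of this reference.
x = jnp.ones((4, 16), dtype=jnp.bfloat16)
y, vjp = jax.vjp(glu_ref, x)
(dx,) = vjp(jnp.ones_like(y))

[End of editorial note]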