
Conversation

singleheart commented Feb 6, 2026

Description

Add the original GLU (Gated Linear Unit) activation function as described in
Dauphin et al. (2017) and referenced in
Shazeer (2020), "GLU Variants Improve Transformer".

GLU is defined as:

$$\text{GLU}(a, b) = \sigma(a) \odot b$$

where $\sigma$ is the sigmoid function and the input is split into two halves $a$ and $b$ along the last dimension.
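
For reference, here is a minimal PyTorch sketch of this definition (illustrative only; the kernels added by this PR are CUDA, and the helper name below is hypothetical):

import torch

def glu_reference(x: torch.Tensor) -> torch.Tensor:
    """Reference GLU per the definition above: sigmoid(a) * b,
    with a the first half and b the second half of the last dimension."""
    a, b = x.chunk(2, dim=-1)      # split the input into two halves
    return torch.sigmoid(a) * b    # gate b with sigmoid(a)

x = torch.randn(4, 8)              # last dimension must be even
y = glu_reference(x)               # shape (4, 4)

Note that torch.nn.functional.glu uses the opposite convention (a * sigmoid(b), gating with the second half), which appears to be why the PR documents the TE vs PyTorch GLU convention and the tests must account for the difference.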

Transformer Engine already supports several GLU variants (GEGLU, ReGLU, SReGLU, SwiGLU, etc.)
but was missing the original sigmoid-gated GLU. This PR fills that gap so that users can
simply pass activation="glu" to LayerNormMLP or TransformerLayer.
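
A hedged usage sketch (the sizes below are illustrative; only the activation="glu" argument is the new behavior added by this PR, and running it requires a CUDA-capable GPU):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="glu",  # new: original sigmoid-gated GLU
).cuda()
x = torch.randn(128, 1024, device="cuda")
y = mlp(x)             # output has the same shape as x

Because GLU is gated, the FC1 projection internally produces twice ffn_hidden_size features, which are then split into the gate and value halves.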

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • transformer_engine/common/activation/glu.cu (new file): CUDA kernels nvte_glu and nvte_dglu using existing sigmoid/dsigmoid primitives from math.h and the gated_act_fn/dgated_act_fn templates (the backward math is sketched just after this list).
  • transformer_engine/common/include/transformer_engine/activation.h: Added GLU to NVTE_Activation_Type enum; declared nvte_glu and nvte_dglu with doxygen documentation.
  • transformer_engine/common/CMakeLists.txt: Registered activation/glu.cu in both arch_specific_sources and fast_math build lists.
  • transformer_engine/pytorch/csrc/extensions/activation.cpp: Added glu() and dglu() C++ wrapper functions.
  • transformer_engine/pytorch/csrc/extensions.h: Declared glu and dglu.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposed tex.glu and tex.dglu to Python.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Added "glu" to _get_act_func_supported_list (all 3 recipe branches), FC1 output-doubling condition, ONNX export activation_map, and docstring.
  • transformer_engine/pytorch/ops/basic/activation.py: Added GLU operation class with forward (tex.glu) and backward (tex.dglu).
  • transformer_engine/pytorch/ops/basic/__init__.py: Exported GLU.
  • transformer_engine/pytorch/transformer.py: Updated TransformerLayer docstring to list 'glu' as a supported activation.
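
As a note on the nvte_dglu backward referenced in the first item above: for GLU(a, b) = sigmoid(a) * b, the input gradient follows directly from the product rule and the sigmoid derivative. A minimal PyTorch sketch of the math (illustrative only; the real kernel is a CUDA template and the helper name below is hypothetical):

import torch

def dglu_reference(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Gradient of GLU(a, b) = sigmoid(a) * b w.r.t. the concatenated input x = [a, b]."""
    a, b = x.chunk(2, dim=-1)
    sig = torch.sigmoid(a)
    da = grad_out * b * sig * (1.0 - sig)  # dsigmoid: sigmoid'(a) = sigmoid(a) * (1 - sigmoid(a))
    db = grad_out * sig                    # d/db of sigmoid(a) * b is sigmoid(a)
    return torch.cat([da, db], dim=-1)

# quick consistency check against autograd
x = torch.randn(4, 8, requires_grad=True)
a, b = x.chunk(2, dim=-1)
(torch.sigmoid(a) * b).backward(gradient=torch.ones(4, 4))
assert torch.allclose(x.grad, dglu_reference(torch.ones(4, 4), x.detach()), atol=1e-6)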

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds the original sigmoid-gated GLU activation across Transformer Engine’s C++/CUDA core, plus PyTorch and JAX bindings. It introduces new CUDA entrypoints (nvte_glu/nvte_dglu) implemented via the existing gated-activation templates, wires the new NVTE_Activation_Type::GLU through the JAX FFI dispatch, and exposes tex.glu/tex.dglu plus a new te_ops.GLU basic op in PyTorch. It also updates LayerNormMLP/TransformerLayer documentation and ONNX export activation mapping so users can select activation="glu".

Main issue found: the new PyTorch fusible-op test adds an unreachable activation == "sigmoid" branch even though sigmoid is not parametrized, which makes the test misleading. Existing prior-thread concerns about the GLU reference reshape/ordering should be addressed separately; I avoided duplicating those comments here.

Confidence Score: 4/5

  • This PR is close to safe to merge, with one clear test-quality issue to address and otherwise straightforward wiring changes.
  • Core CUDA implementation and PyTorch/JAX bindings follow established patterns for existing gated activations. The only definite problem found is an unreachable test branch added in the fusible-ops test (suggesting coverage that isn’t actually parameterized). Prior-thread issues about the GLU reference should still be resolved before relying on the new tests.
  • tests/pytorch/test_fusible_ops.py (activation reference/coverage logic)

Important Files Changed

  • tests/pytorch/test_fusible_ops.py: Adds 'glu' to the activation op tests, but the GLU reference uses reshape+flip+F.glu; a prior PR thread notes this reshape is invalid with the doubled last dimension and will raise at runtime, so the tests likely fail.
  • tests/pytorch/test_numerics.py: Extends the activation list to include 'glu' for numerics coverage; the change is straightforward and consistent with other activations.
  • tests/pytorch/test_sanity.py: Adds 'glu' to the sanity activation list; a simple list update consistent with existing patterns.
  • transformer_engine/common/CMakeLists.txt: Registers the new CUDA source activation/glu.cu in the arch-specific sources and fast-math list; build integration appears correct.
  • transformer_engine/common/activation/glu.cu: Adds nvte_glu/nvte_dglu kernels via the existing gated_act_fn templates using sigmoid/dsigmoid primitives; the implementation matches other gated activations.
  • transformer_engine/common/include/transformer_engine/activation.h: Adds GLU to NVTE_Activation_Type and declares nvte_glu/nvte_dglu with documentation; consistent with the existing activation API.
  • transformer_engine/jax/cpp_extensions/activation.py: Maps ("sigmoid", "linear") to NVTE_Activation_Type.GLU; correct if JAX activation tuples represent (gate_fn, linear), though the naming may not expose a 'glu' string directly.
  • transformer_engine/jax/csrc/extensions/activation.cpp: Adds NVTE_Activation_Type::GLU dispatch to nvte_glu/nvte_dglu in the forward/backward FFI; matches existing switch patterns.
  • transformer_engine/jax/csrc/extensions/pybind.cpp: Exposes GLU in the NVTE_Activation_Type pybind enum for the JAX extension; a simple enum addition.
  • transformer_engine/pytorch/csrc/extensions.h: Declares the glu/dglu pybind functions; consistent with other activation declarations.
  • transformer_engine/pytorch/csrc/extensions/activation.cpp: Implements glu/dglu wrappers via activation_helper/dactivation_helper; uses gate factor 2 for the forward pass, as expected for gated activations.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Adds tex.glu and tex.dglu bindings; consistent with existing activation bindings.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Adds 'glu' to supported activations, doubles the FC1 output for glu, and adds the ONNX export mapping; changes align with other gated activations.
  • transformer_engine/pytorch/ops/basic/__init__.py: Exports the GLU op from the basic ops package; trivial update.
  • transformer_engine/pytorch/ops/basic/activation.py: Adds a GLU operation wrapper calling tex.glu/tex.dglu and documents the TE vs PyTorch GLU convention; consistent with other activation op wrappers.
  • transformer_engine/pytorch/transformer.py: Updates the TransformerLayer docstring to include 'glu' as a supported activation; doc-only change.

Sequence Diagram

sequenceDiagram
  participant User as User (PyTorch/JAX)
  participant TE_PY as transformer_engine.pytorch (tex)
  participant TE_JAX as transformer_engine_jax
  participant ExtCPP as C++ extensions (activation.cpp)
  participant Core as TE Core (activation.h)
  participant CUDA as CUDA kernels (glu.cu)

  alt PyTorch path
    User->>TE_PY: tex.glu(input, quantizer)
    TE_PY->>ExtCPP: glu(input, quantizer)
    ExtCPP->>Core: nvte_glu(input_tensor, output_tensor, stream)
    Core->>CUDA: gated_act_fn<sigmoid>(input, output)
    CUDA-->>Core: output
    Core-->>ExtCPP: output
    ExtCPP-->>TE_PY: output (possibly quantized)
    TE_PY-->>User: output

    User->>TE_PY: tex.dglu(grad, fwd_input, quantizer)
    TE_PY->>ExtCPP: dglu(grad, input, quantizer)
    ExtCPP->>Core: nvte_dglu(grad_tensor, input_tensor, dx_tensor, stream)
    Core->>CUDA: dgated_act_fn<sigmoid, dsigmoid>(grad, input, dx)
    CUDA-->>Core: dx
    Core-->>ExtCPP: dx
    ExtCPP-->>TE_PY: dx
    TE_PY-->>User: dx
  else JAX FFI path
    User->>TE_JAX: act_lu(x, activation_type=("sigmoid","linear"))
    TE_JAX->>ExtCPP: ActLuFFI(act_enum=GLU)
    ExtCPP->>Core: nvte_glu(input_tensor, output_tensor, stream)
    Core->>CUDA: gated_act_fn<sigmoid>(input, output)
    CUDA-->>Core: output
    Core-->>ExtCPP: output
    ExtCPP-->>TE_JAX: output buffers
    TE_JAX-->>User: output
  end

greptile-apps bot (Contributor) left a comment

5 files reviewed, 3 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (3)

tests/pytorch/test_sanity.py
"glu" not added to test list - new activation won't be tested

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]

tests/pytorch/test_numerics.py
"glu" missing from test list

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]

tests/pytorch/test_fusible_ops.py
"glu" missing from test parameters - add glu to tuple and handle in test logic below (around line 1631)

singleheart and others added 2 commits February 6, 2026 20:15
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
for more information, see https://pre-commit.ci

Signed-off-by: Kim, Jin <jinn.kim@sk.com>
singleheart force-pushed the feature/add-sigmoid-glu branch from 8adb19a to f7c6385 on February 6, 2026 11:15
greptile-apps bot (Contributor) left a comment

5 files reviewed, no comments

timmoon10 and others added 2 commits February 6, 2026 18:40
Signed-off-by: Tim Moon <tmoon@nvidia.com>
A comment from @timmoon10 was marked as outdated.

timmoon10 previously approved these changes Feb 6, 2026

timmoon10 (Collaborator) left a comment:
LGTM, pending CI

If we're adding GLU, it would also be natural to have sigmoid. Not critical though.

greptile-apps bot (Contributor) left a comment

11 files reviewed, 4 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (3)

transformer_engine/jax/csrc/extensions/activation.cpp
Missing GLU dispatch

NVTE_Activation_Type::GLU was added to the shared activation enum, but the JAX FFI dispatch here doesn’t handle it. Calling act_lu/dact_lu with GLU will hit the default branch and raise NVTE_ERROR("Unsupported ActivationEnum").

Fix: add explicit case NVTE_Activation_Type::GLU: branches that call nvte_glu(...) and nvte_dglu(...) (the backward path is in the DActLuDBiasQuantizeFFI switch later in this file as well).


transformer_engine/jax/csrc/extensions/pybind.cpp
GLU not exported to Python

The pybind enum export for NVTE_Activation_Type doesn’t include GLU, so JAX Python can’t reference NVTE_Activation_Type.GLU even though it exists in the C++ header. This will block using GLU from Python-side activation selection.

Fix: add .value("GLU", NVTE_Activation_Type::GLU) to this enum binding.


transformer_engine/jax/cpp_extensions/activation.py
ActivationEnum missing GLU

ActivationEnum doesn’t have an entry mapping any activation tuple to NVTE_Activation_Type.GLU, so GLU can’t be selected through the existing string/tuple activation routing in JAX.

Fix: add the appropriate key (consistent with the existing conventions, e.g. how GEGLU/REGLU/SWIGLU are represented) mapping to NVTE_Activation_Type.GLU.

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
A comment from @timmoon10 was marked as outdated.

greptile-apps bot (Contributor) left a comment

3 files reviewed, 2 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
glu is missing from the make_op dictionary, causing a KeyError when the test runs

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]

Comment on lines +1632 to +1633

    elif activation == "sigmoid":
        y_ref = torch.nn.functional.sigmoid(x_ref)
A Member commented:
Sigmoid is not an option in the test; is this leftover code?

A Collaborator commented:
It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.

ptrendx (Member) commented Feb 7, 2026

Could we also add the test for regular LayerNormMLP module in the test_numerics.py?
Also @jberchtold-nvidia FYI - we will also want to expose it in JAX.

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
greptile-apps bot (Contributor) left a comment

11 files reviewed, 2 comments

greptile-apps bot (Contributor) commented Feb 7, 2026

Additional Comments (2)

transformer_engine/jax/csrc/extensions/pybind.cpp
Missing GLU enum export

NVTE_Activation_Type gained a new GLU value in transformer_engine/common/include/transformer_engine/activation.h, but the JAX pybind enum here doesn’t export it. This will break JAX-side code that expects the enum values to match the core header (e.g., passing NVTE_Activation_Type.GLU from Python will fail / be unavailable).

Also update the JAX activation dispatch switch to handle NVTE_Activation_Type::GLU (see transformer_engine/jax/csrc/extensions/activation.cpp:105-143).


transformer_engine/jax/csrc/extensions/activation.cpp
GLU not handled in switch

The activation dispatch switch (act_type) doesn’t include a case NVTE_Activation_Type::GLU, even though GLU was added to the shared NVTE_Activation_Type enum. If JAX passes the new enum value, this currently falls into default: and raises NVTE_ERROR("Unsupported ActivationEnum").

Add a GLU case that calls nvte_glu(...), and ensure the backward/quantized paths (other switches later in this file) are also updated consistently.

timmoon10 (Collaborator) commented:
/te-ci

greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

greptile-apps bot (Contributor) commented Feb 7, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
missing glu in test's make_op dictionary - test will fail with KeyError when running the glu case

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]

Signed-off-by: Kim, Jin <jinn.kim@sk.com>
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
greptile-apps bot (Contributor) left a comment

5 files reviewed, no comments

singleheart (Author) replied:
> Could we also add the test for regular LayerNormMLP module in the test_numerics.py? Also @jberchtold-nvidia FYI - we will also want to expose it in JAX.

Done. The latest commits address both points:

Tests:

  • Added GLU to all_activations in test_numerics.py (covers LayerNormMLP tests)
  • Added GLU to all_activations in test_sanity.py
  • Fixed missing glu=te_ops.GLU entry in make_op dict in test_fusible_ops.py

JAX support:

  • jax/csrc/extensions/activation.cpp: Added NVTE_Activation_Type::GLU cases for both forward (nvte_glu) and backward (nvte_dglu) dispatch
  • jax/csrc/extensions/pybind.cpp: Exported GLU enum value to Python
  • jax/cpp_extensions/activation.py: Added ("sigmoid", "linear"): NVTE_Activation_Type.GLU to ActivationEnum

ptrendx (Member) commented Feb 9, 2026

/te-ci

greptile-apps bot (Contributor) left a comment

16 files reviewed, 1 comment

Comment on lines +1632 to +1633

    elif activation == "sigmoid":
        y_ref = torch.nn.functional.sigmoid(x_ref)
A Contributor commented:

Unreachable sigmoid branch

test_activation never parametrizes activation="sigmoid" (see the @pytest.mark.parametrize("activation", ...) list above), but this new elif activation == "sigmoid": branch was added anyway, so it can’t execute. This makes the test misleading/fragile (it looks like sigmoid is covered when it isn’t). Either add "sigmoid" to the parametrization or remove this branch.
