-
Notifications
You must be signed in to change notification settings - Fork 633
Add sigmoid GLU #2656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add sigmoid GLU #2656
Changes from all commits
feab8ab
f7c6385
c986894
4d0cfcd
19da274
e970a7f
681624a
d736458
3b3172b
d5dac7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1557,7 +1557,19 @@ def test_make_extra_output( | |
|
|
||
| @pytest.mark.parametrize( | ||
| "activation", | ||
| ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), | ||
| ( | ||
| "gelu", | ||
| "geglu", | ||
| "qgelu", | ||
| "qgeglu", | ||
| "relu", | ||
| "reglu", | ||
| "glu", | ||
| "srelu", | ||
| "sreglu", | ||
| "silu", | ||
| "swiglu", | ||
| ), | ||
| ) | ||
| @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) | ||
| @pytest.mark.parametrize("dtype", _dtypes) | ||
|
|
@@ -1577,7 +1589,7 @@ def test_activation( | |
|
|
||
| # Tensor dimensions | ||
| in_shape = list(out_shape) | ||
| if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): | ||
| if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"): | ||
| in_shape[-1] *= 2 | ||
|
|
||
| # Skip invalid configurations | ||
|
|
@@ -1617,6 +1629,13 @@ def test_activation( | |
| elif activation == "reglu": | ||
| x1, x2 = x_ref.chunk(2, dim=-1) | ||
| y_ref = torch.nn.functional.relu(x1) * x2 | ||
| elif activation == "sigmoid": | ||
| y_ref = torch.nn.functional.sigmoid(x_ref) | ||
|
Comment on lines
+1632
to
+1633
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unreachable sigmoid branch
|
||
| elif activation == "glu": | ||
| x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2) | ||
| x = x.flip(-2) # PyTorch GLU swaps gate and linear unit | ||
| x = x.reshape(in_shape) | ||
| y_ref = torch.nn.functional.glu(x) | ||
| elif activation == "srelu": | ||
timmoon10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| y_ref = torch.nn.functional.relu(x_ref) ** 2 | ||
| elif activation == "sreglu": | ||
|
|
@@ -1636,6 +1655,7 @@ def test_activation( | |
| make_op = dict( | ||
| gelu=te_ops.GELU, | ||
| geglu=te_ops.GEGLU, | ||
| glu=te_ops.GLU, | ||
| qgelu=te_ops.QGELU, | ||
| qgeglu=te_ops.QGEGLU, | ||
| relu=te_ops.ReLU, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,6 +89,7 @@ | |
| all_activations = [ | ||
| "gelu", | ||
| "geglu", | ||
| "glu", | ||
| "qgelu", | ||
| "qgeglu", | ||
| "relu", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,6 +113,7 @@ def nvfp4_vanilla(): | |
| all_activations = [ | ||
| "gelu", | ||
| "geglu", | ||
| "glu", | ||
| "qgelu", | ||
| "qgeglu", | ||
| "relu", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include "../util/math.h" | ||
| #include "./activation_template.h" | ||
|
|
||
| void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_glu); | ||
| using namespace transformer_engine; | ||
| Empty e = {}; | ||
| gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream); | ||
| } | ||
|
|
||
| void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, | ||
| cudaStream_t stream) { | ||
| NVTE_API_CALL(nvte_dglu); | ||
| using namespace transformer_engine; | ||
| Empty e = {}; | ||
| dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e, | ||
| stream); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from .activation import ( | ||
| GELU, | ||
| GEGLU, | ||
| GLU, | ||
| QGELU, | ||
| QGEGLU, | ||
| ReLU, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sigmoid is not an option in the test, is this a leftover code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.