From 86b9199a526d52eabf86f296727174578d1d110d Mon Sep 17 00:00:00 2001 From: Hongxiao Bai Date: Wed, 29 Apr 2026 00:01:52 +0000 Subject: [PATCH 1/2] swiglu offset Signed-off-by: Hongxiao Bai --- tests/pytorch/test_fusible_ops.py | 10 ++++++++-- transformer_engine/common/activation/swiglu.cu | 8 ++++---- transformer_engine/common/cast/fp8/gated_fp8.cuh | 5 ++--- .../common/cast/mxfp8/gated_mxfp8.cuh | 10 ++++------ .../include/transformer_engine/activation.h | 6 ++++-- transformer_engine/common/util/math.h | 3 ++- .../common/util/vectorized_pointwise.h | 8 +++----- .../jax/cpp_extensions/activation.py | 15 +++++++++------ transformer_engine/jax/csrc/extensions.h | 4 +++- .../jax/csrc/extensions/activation.cpp | 6 ++++-- transformer_engine/pytorch/csrc/extensions.h | 5 +++-- .../pytorch/csrc/extensions/activation.cpp | 11 +++++++---- .../pytorch/csrc/extensions/pybind.cpp | 5 +++-- .../pytorch/module/layernorm_mlp.py | 15 +++++++-------- transformer_engine/pytorch/ops/basic/swiglu.py | 16 ++++++++++++++-- transformer_engine/pytorch/transformer.py | 5 +++-- 16 files changed, 80 insertions(+), 52 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index c73f560565..8a69d5a077 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1795,6 +1795,7 @@ def test_interleaved_swiglu(self): @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantize_forward", (False, True)) @pytest.mark.parametrize("quantize_backward", (False, True)) + @pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0)) def test_clamped_swiglu( self, *, @@ -1805,6 +1806,7 @@ def test_clamped_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + glu_linear_offset: float, limit: float = 0.75, alpha: float = 1.702, ): @@ -1847,7 +1849,7 @@ def test_clamped_swiglu( x_glu = x_glu.clamp(min=None, max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) out_glu = x_glu * torch.sigmoid(alpha * x_glu) - y_ref = out_glu * (x_linear + 1) + y_ref = out_glu * (x_linear + glu_linear_offset) y_ref.backward(dy_ref) # Implementation with fusible operation @@ -1858,6 +1860,7 @@ def test_clamped_swiglu( te_ops.ClampedSwiGLU( limit=limit, alpha=alpha, + glu_linear_offset=glu_linear_offset, glu_interleave_size=glu_interleave_size, ), te_ops.Quantize(forward=quantize_forward, backward=False), @@ -2240,6 +2243,7 @@ def test_interleaved_scaled_swiglu(self): @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("scales_requires_grad", (False, True)) + @pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0)) def test_scaled_clamped_qgeglu( self, *, @@ -2249,6 +2253,7 @@ def test_scaled_clamped_qgeglu( device: torch.device = "cuda", input_requires_grad: bool, scales_requires_grad: bool, + glu_linear_offset: float, limit: float = 7.0, alpha: float = 1.702, ) -> None: @@ -2293,7 +2298,7 @@ def test_scaled_clamped_qgeglu( x_glu = x_glu.clamp(min=None, max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) out_glu = x_glu * torch.sigmoid(alpha * x_glu) - y = out_glu * (x_linear + 1) + y = out_glu * (x_linear + glu_linear_offset) y_ref = scales_ref.unsqueeze(-1) * y if input_requires_grad or scales_requires_grad: y_ref.backward(dy_ref) @@ -2302,6 +2307,7 @@ def test_scaled_clamped_qgeglu( glu_interleave_size=glu_interleave_size, limit=limit, alpha=alpha, + glu_linear_offset=glu_linear_offset, ) y_test = 
op(x_test, scales_test) if input_requires_grad or scales_requires_grad: diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 12478af4cf..6ea0bd49f5 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -85,18 +85,18 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp } void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, - cudaStream_t stream) { + float glu_linear_offset, cudaStream_t stream) { NVTE_API_CALL(nvte_clamped_swiglu); using namespace transformer_engine; - ClampedSwiGLUParam param = {limit, alpha}; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; gated_act_fn>(input, output, param, stream); } void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, float alpha, cudaStream_t stream) { + float limit, float alpha, float glu_linear_offset, cudaStream_t stream) { NVTE_API_CALL(nvte_clamped_dswiglu); using namespace transformer_engine; - ClampedSwiGLUParam param = {limit, alpha}; + ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset}; dgated_act_fn, clamped_dsilu>( grad, input, output, param, stream); } diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index 6123d7130b..522a9add8f 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; + gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset; } if constexpr (IS_BWD) { diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 49169a4e14..83b5a49cae 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_gate_elt; bool dgate_elt = true; // gating is ideally an identity function if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; + gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset; } if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); @@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float after_gate_elt; bool dgate_elt = true; if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; + gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset; } if constexpr 
(IS_BWD) { float grad_elt = static_cast(in_grad.data.elt[e]); diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 854f52c203..8ef20d4dc0 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -336,10 +336,11 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) * It computes Act(input[N, :H]) x input[N, H:] * \param[in] limit Clipping limits for gate and pre-activation. * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0). * \param[in] stream CUDA stream used for the operation. */ void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, - cudaStream_t stream); + float glu_linear_offset, cudaStream_t stream); /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, @@ -413,10 +414,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp * \param[in,out] output Outgoing gradient of shape [N, H * 2]. * \param[in] limit Clipping limits for gate and pre-activation. * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0). * \param[in] stream CUDA stream used for the operation. */ void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, - float limit, float alpha, cudaStream_t stream); + float limit, float alpha, float glu_linear_offset, cudaStream_t stream); /*! \brief Computes the gated ReLU activation gradient. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 05fe2f5398..64b6fa2d48 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -13,7 +13,8 @@ struct Empty {}; struct ClampedSwiGLUParam { float limit; - float alpha = 1.702f; // Default value for QuickGELU + float alpha = 1.702f; // Default value for QuickGELU + float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping }; template diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 0aa2df7d26..7707c68a08 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType val2 = static_cast(loader1.separate()[i]); if constexpr (std::is_same::value) { - // Clamp the gated value and add 1 at the end ComputeType limit = p.limit; - val2 = std::min(std::max(-limit, val2), limit) + 1; + val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset; } ComputeType temp = static_cast(Activation(val, p) * val2); if (requires_amax) { @@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__ bool dgate_in = true; if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values const ComputeType limit = p.limit; - dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp - gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + dgate_in = gate_in <= limit && gate_in >= -limit; + gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset; } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 8c0edae97e..5058192c3f 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -64,6 +64,7 @@ class ClampedSwigluParams: limit: float = 7.0 alpha: float = 1.702 + glu_linear_offset: float = 1.0 def __hash__(self): """Custom hash function to ensure dataclass is hashable for jax jit to work. @@ -71,7 +72,7 @@ def __hash__(self): Returns: int: Hash value of the dataclass instance. """ - return hash((self.limit, self.alpha)) + return hash((self.limit, self.alpha, self.glu_linear_offset)) def to_ffi_lowering_dict(self): """Convert the activation parameters to a dictionary format for FFI lowering. @@ -80,7 +81,11 @@ def to_ffi_lowering_dict(self): dict: A dictionary representation of the activation parameters consumable by XLA FFI bindings for activation functions. 
""" - return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + return { + "limit": np.float32(self.limit), + "alpha": np.float32(self.alpha), + "glu_linear_offset": np.float32(self.glu_linear_offset), + } @dataclass(frozen=True) @@ -121,11 +126,9 @@ def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): if fn_or_string == "linear": return lambda x: x if fn_or_string == "clamped_linear": - # This function is used for ClampedSwiGLU - # used in GPT OSS where the gates are not only clamped - # but also shifted by +1 limit = act_params.clamped_swiglu.limit - return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 + offset = act_params.clamped_swiglu.glu_linear_offset + return lambda x: jnp.clip(x, min=-limit, max=limit) + offset if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ecfedc8a2..416b18ada0 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -39,6 +39,7 @@ namespace jax { struct ClampedSwigluConfig { float limit; float alpha; + float glu_linear_offset; }; struct ActivationConfig { @@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, ::xla::ffi::StructMember("limit"), - ::xla::ffi::StructMember("alpha")); + ::xla::ffi::StructMember("alpha"), + ::xla::ffi::StructMember("glu_linear_offset")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::ActivationConfig, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index ce5828d6f3..0b5e0c9566 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -23,6 +23,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -138,7 +139,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal break; case NVTE_Activation_Type::CLAMPED_SWIGLU: nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, - stream); + swiglu_glu_linear_offset, stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -271,6 +272,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -447,7 +449,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, break; case NVTE_Activation_Type::CLAMPED_SWIGLU: nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - swiglu_limit, 
swiglu_alpha, stream); + swiglu_limit, swiglu_alpha, swiglu_glu_linear_offset, stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..93faee559d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -274,10 +274,11 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha, + float glu_linear_offset); py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, - float limit, float alpha); + float limit, float alpha, float glu_linear_offset); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..f66ca223bb 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -328,13 +328,16 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q } /* clamped functions */ -py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { - return activation_helper(input, quantizer, 2, limit, alpha); +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha, + float glu_linear_offset) { + return activation_helper(input, quantizer, 2, limit, alpha, + glu_linear_offset); } py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, - float limit, float alpha) { - return dactivation_helper(grad, input, quantizer, limit, alpha); + float limit, float alpha, float glu_linear_offset) { + return dactivation_helper(grad, input, quantizer, limit, alpha, + glu_linear_offset); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..e8cced726f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -183,7 +183,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), - py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f, py::arg("glu_linear_offset") = 1.0f); /* Backward of GLU */ m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -212,7 +212,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), - py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + py::arg("quantizer"), py::arg("limit") = 7.0f, 
py::arg("alpha") = 1.702f, + py::arg("glu_linear_offset") = 1.0f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4fa7eb2856..427858a142 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1798,7 +1798,7 @@ class LayerNormMLP(TransformerEngineBaseModule): activation_params : dict, default = None Additional parameters for the activation function. At the moment, only used for ``'clamped_swiglu'`` activation which - supports ``'limit'`` and ``'alpha'`` parameters. + supports ``'limit'``, ``'alpha'``, and ``'glu_linear_offset'`` parameters. init_method : Callable, default = None used for initializing FC1 weights in the following way: ``init_method(weight)``. When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``. @@ -2451,17 +2451,16 @@ def onnx_forward( fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32 act_params = self.activation_params or {} - # Default params for clamped_swiglu in Transformer Engine - clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get( - "alpha", 1.702 - ) + clamped_swiglu_limit = act_params.get("limit", 7.0) + clamped_swiglu_alpha = act_params.get("alpha", 1.702) + clamped_swiglu_offset = act_params.get("glu_linear_offset", 1.0) - def _clamped_swiglu(x, limit, alpha): + def _clamped_swiglu(x, limit, alpha, offset): x_glu, x_linear = x.chunk(2, dim=-1) x_glu = x_glu.clamp(min=None, max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) out_glu = x_glu * torch.sigmoid(alpha * x_glu) - y = out_glu * (x_linear + 1) + y = out_glu * (x_linear + offset) return y activation_map = { @@ -2479,7 +2478,7 @@ def _clamped_swiglu(x, limit, alpha): "silu": torch.nn.functional.silu, "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "clamped_swiglu": lambda x: _clamped_swiglu( - x, clamped_swiglu_limit, clamped_swiglu_alpha + x, clamped_swiglu_limit, clamped_swiglu_alpha, clamped_swiglu_offset ), } if self.activation not in activation_map: diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 9c0bc86bc1..19080f769e 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -208,6 +208,9 @@ class ClampedSwiGLU(BasicOperation): The clamp limit. alpha : float The scaling factor for the sigmoid function used in the activation. + glu_linear_offset : float + Offset added to the linear (gate) component after clamping. + Set to ``0.0`` to disable the offset. cache_quantized_input : bool, default = ``False`` Quantize input tensor when caching for use in the backward pass. 
glu_interleave_size : int, optional @@ -222,12 +225,14 @@ def __init__( *, limit: float = 7.0, alpha: float = 1.702, + glu_linear_offset: float = 1.0, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, ): super().__init__() self.limit: float = limit self.alpha: float = alpha + self.glu_linear_offset: float = glu_linear_offset self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size @@ -236,12 +241,13 @@ def _tex_clamped_swiglu_forward( swiglu_in: torch.Tensor, next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: - """Call :func:`tex.clamped_swiglu` with this op's ``limit`` / ``alpha``.""" + """Call :func:`tex.clamped_swiglu` with this op's ``limit`` / ``alpha`` / ``glu_linear_offset``.""" return tex.clamped_swiglu( swiglu_in, next_op_input_quantizer, self.limit, self.alpha, + self.glu_linear_offset, ) def _tex_clamped_dswiglu( @@ -250,13 +256,14 @@ def _tex_clamped_dswiglu( swiglu_in: torch.Tensor, quantizer: Optional[Quantizer], ) -> torch.Tensor: - """Call :func:`tex.clamped_dswiglu` with this op's ``limit`` / ``alpha``.""" + """Call :func:`tex.clamped_dswiglu` with this op's ``limit`` / ``alpha`` / ``glu_linear_offset``.""" return tex.clamped_dswiglu( dy, swiglu_in, quantizer, self.limit, self.alpha, + self.glu_linear_offset, ) def op_forward( @@ -557,6 +564,9 @@ class ScaledClampedQGeGLU(_ScaledGLU): Clamp limit (see :class:`ClampedSwiGLU`). alpha : float, default ``1.702`` Sigmoid scale (see :class:`ClampedSwiGLU`). + glu_linear_offset : float, default ``1.0`` + Offset added to the linear component after clamping + (see :class:`ClampedSwiGLU`). """ @@ -566,11 +576,13 @@ def __init__( *, limit: float = 7.0, alpha: float = 1.702, + glu_linear_offset: float = 1.0, ) -> None: super().__init__(glu_interleave_size) self._clamped: ClampedSwiGLU = ClampedSwiGLU( limit=limit, alpha=alpha, + glu_linear_offset=glu_linear_offset, ) def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 4b96ccf739..d377e5f3b3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -189,8 +189,9 @@ class TransformerLayer(torch.nn.Module): activation_params : Optional[dict], default = None Additional parameters for the activation function. At the moment, only used for ``'clamped_swiglu'`` activation which - supports ``'limit'`` and ``'alpha'`` parameters. You can set these as - ``activation_params={'limit': 7.0, 'alpha': 1.702}``. + supports ``'limit'``, ``'alpha'``, and ``'glu_linear_offset'`` parameters. + You can set these as + ``activation_params={'limit': 7.0, 'alpha': 1.702, 'glu_linear_offset': 1.0}``. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. 
It is the user's responsibility to ensure all parameters are moved to the GPU before running the From 1eab899349ae7d86b2e520a1da3dc8d377eed2f0 Mon Sep 17 00:00:00 2001 From: Hongxiao Bai Date: Wed, 29 Apr 2026 01:00:38 +0000 Subject: [PATCH 2/2] fix fusion pattern check Signed-off-by: Hongxiao Bai --- transformer_engine/pytorch/ops/_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index e21915a5a6..baf1b00504 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -179,6 +179,7 @@ def fuse_grouped_mlp_ops( matches_pattern = False elif isinstance(window[1], ScaledClampedQGeGLU) and ( abs(window[1]._clamped.alpha - 1.702) > 0.001 + or abs(window[1]._clamped.glu_linear_offset - 1.0) > 0.001 or not _nvidia_cudnn_frontend_supports_scaled_clamped_qgeglu() ): matches_pattern = False
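For reference, a minimal PyTorch sketch of the clamped-SwiGLU semantics these kernels parameterize, mirroring the reference implementation in tests/pytorch/test_fusible_ops.py. The limit / alpha / glu_linear_offset defaults are assumed from the pybind bindings in this patch; this is an illustrative sketch, not the fused kernel itself.

import torch

def clamped_swiglu_ref(x: torch.Tensor,
                       limit: float = 7.0,
                       alpha: float = 1.702,
                       glu_linear_offset: float = 1.0) -> torch.Tensor:
    # Split the last dimension into the activation half and the linear (gate) half.
    x_glu, x_linear = x.chunk(2, dim=-1)
    # The activation half is clamped from above only; the linear half on both sides.
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    # Sigmoid-gated activation; alpha = 1.702 corresponds to QuickGELU.
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    # glu_linear_offset generalizes the previously hard-coded "+ 1";
    # setting it to 0.0 disables the shift on the linear half.
    return out_glu * (x_linear + glu_linear_offset)

In the backward kernels the clamp on the linear half contributes an indicator derivative (the dgate_elt terms above), so autograd through this sketch is what the updated tests compare nvte_clamped_dswiglu against for both offset values (1.0 and 0.0).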