diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 7691582f97..8920b370f9 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -1809,6 +1809,7 @@ def test_interleaved_swiglu(self):
     @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("quantize_forward", (False, True))
     @pytest.mark.parametrize("quantize_backward", (False, True))
+    @pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
     def test_clamped_swiglu(
         self,
         *,
@@ -1819,6 +1820,7 @@ def test_clamped_swiglu(
         quantization: Optional[str],
         quantize_forward: bool,
         quantize_backward: bool,
+        glu_linear_offset: float,
         limit: float = 0.75,
         alpha: float = 1.702,
     ):
@@ -1861,7 +1863,7 @@ def test_clamped_swiglu(
         x_glu = x_glu.clamp(min=None, max=limit)
         x_linear = x_linear.clamp(min=-limit, max=limit)
         out_glu = x_glu * torch.sigmoid(alpha * x_glu)
-        y_ref = out_glu * (x_linear + 1)
+        y_ref = out_glu * (x_linear + glu_linear_offset)
         y_ref.backward(dy_ref)

         # Implementation with fusible operation
@@ -1872,6 +1874,7 @@ def test_clamped_swiglu(
             te_ops.ClampedSwiGLU(
                 limit=limit,
                 alpha=alpha,
+                glu_linear_offset=glu_linear_offset,
                 glu_interleave_size=glu_interleave_size,
             ),
             te_ops.Quantize(forward=quantize_forward, backward=False),
@@ -2492,6 +2495,7 @@ def test_interleaved_scaled_swiglu(self):
     @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
     @pytest.mark.parametrize("input_requires_grad", (False, True))
     @pytest.mark.parametrize("scales_requires_grad", (False, True))
+    @pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
     def test_scaled_clamped_qgeglu(
         self,
         *,
@@ -2501,6 +2505,7 @@ def test_scaled_clamped_qgeglu(
         device: torch.device = "cuda",
         input_requires_grad: bool,
         scales_requires_grad: bool,
+        glu_linear_offset: float,
         limit: float = 7.0,
         alpha: float = 1.702,
     ) -> None:
@@ -2545,7 +2550,7 @@ def test_scaled_clamped_qgeglu(
         x_glu = x_glu.clamp(min=None, max=limit)
         x_linear = x_linear.clamp(min=-limit, max=limit)
         out_glu = x_glu * torch.sigmoid(alpha * x_glu)
-        y = out_glu * (x_linear + 1)
+        y = out_glu * (x_linear + glu_linear_offset)
         y_ref = scales_ref.unsqueeze(-1) * y
         if input_requires_grad or scales_requires_grad:
             y_ref.backward(dy_ref)
@@ -2554,6 +2559,7 @@ def test_scaled_clamped_qgeglu(
             glu_interleave_size=glu_interleave_size,
             limit=limit,
             alpha=alpha,
+            glu_linear_offset=glu_linear_offset,
         )
         y_test = op(x_test, scales_test)
         if input_requires_grad or scales_requires_grad:
diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu
index 12478af4cf..6ea0bd49f5 100644
--- a/transformer_engine/common/activation/swiglu.cu
+++ b/transformer_engine/common/activation/swiglu.cu
@@ -85,18 +85,18 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
 }

 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream) {
+                         float glu_linear_offset, cudaStream_t stream) {
   NVTE_API_CALL(nvte_clamped_swiglu);
   using namespace transformer_engine;
-  ClampedSwiGLUParam param = {limit, alpha};
+  ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
   gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
 }

 void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
-                          float limit, float alpha, cudaStream_t stream) {
+                          float limit, float alpha, float glu_linear_offset, cudaStream_t stream) {
   NVTE_API_CALL(nvte_clamped_dswiglu);
   using namespace transformer_engine;
-  ClampedSwiGLUParam param = {limit, alpha};
+  ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
   dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
       grad, input, output, param, stream);
 }
diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh
index 6123d7130b..522a9add8f 100644
--- a/transformer_engine/common/cast/fp8/gated_fp8.cuh
+++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh
@@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
         bool dgate_elt = true;  // gating is ideally an identity function
         if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
-          // In case of GPT OSS, clamp the activation and gate values
-          dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
-          gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
+          dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
+          gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
         }

         if constexpr (IS_BWD) {
diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
index 49169a4e14..83b5a49cae 100644
--- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
+++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
@@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       float after_gate_elt;
       bool dgate_elt = true;  // gating is ideally an identity function
       if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
-        // In case of GPT OSS, clamp the activation and gate values
-        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
-        gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
+        gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
       }
       if constexpr (IS_BWD) {
         float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
@@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       float after_gate_elt;
       bool dgate_elt = true;
       if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
-        // In case of GPT OSS, clamp the activation and gate values
-        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
-        gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
+        gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
       }
       if constexpr (IS_BWD) {
         float grad_elt = static_cast<float>(in_grad.data.elt[e]);
diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h
index 854f52c203..8ef20d4dc0 100644
--- a/transformer_engine/common/include/transformer_engine/activation.h
+++ b/transformer_engine/common/include/transformer_engine/activation.h
@@ -336,10 +336,11 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
  *                             It computes Act(input[N, :H]) x input[N, H:]
  *  \param[in]     limit       Clipping limits for gate and pre-activation.
  *  \param[in]     alpha       Scaling factor for the sigmoid function used in the activation.
+ *  \param[in]     glu_linear_offset  Offset added to the linear component after clamping (default 1.0).
  *  \param[in]     stream      CUDA stream used for the operation.
  */
 void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
-                         cudaStream_t stream);
+                         float glu_linear_offset, cudaStream_t stream);

 /*! \brief Computes the gated ReLU activation of the input.
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
@@ -413,10 +414,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
  *  \param[in,out] output      Outgoing gradient of shape [N, H * 2].
  *  \param[in]     limit       Clipping limits for gate and pre-activation.
  *  \param[in]     alpha       Scaling factor for the sigmoid function used in the activation.
+ *  \param[in]     glu_linear_offset  Offset added to the linear component after clamping (default 1.0).
  *  \param[in]     stream      CUDA stream used for the operation.
  */
 void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
-                          float limit, float alpha, cudaStream_t stream);
+                          float limit, float alpha, float glu_linear_offset, cudaStream_t stream);

 /*! \brief Computes the gated ReLU activation gradient.
  * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h
index 05fe2f5398..64b6fa2d48 100644
--- a/transformer_engine/common/util/math.h
+++ b/transformer_engine/common/util/math.h
@@ -13,7 +13,8 @@ struct Empty {};

 struct ClampedSwiGLUParam {
   float limit;
-  float alpha = 1.702f;  // Default value for QuickGELU
+  float alpha = 1.702f;            // Default value for QuickGELU
+  float glu_linear_offset = 1.0f;  // Offset added to the linear (gate) component after clamping
 };

 template <typename OType, typename IType>
diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h
index 0aa2df7d26..7707c68a08 100644
--- a/transformer_engine/common/util/vectorized_pointwise.h
+++ b/transformer_engine/common/util/vectorized_pointwise.h
@@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
       ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);

       if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
-        // Clamp the gated value and add 1 at the end
         ComputeType limit = p.limit;
-        val2 = std::min(std::max(-limit, val2), limit) + 1;
+        val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
       }
       ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
       if (requires_amax) {
@@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
       bool dgate_in = true;

       if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
-        // In case of GPT OSS, clamp the activation and gate values
         const ComputeType limit = p.limit;
-        dgate_in = gate_in <= limit && gate_in >= -limit;  // Derivative of clamp
-        gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
+        dgate_in = gate_in <= limit && gate_in >= -limit;
+        gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
       }

       ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py
index 8c0edae97e..5058192c3f 100644
--- a/transformer_engine/jax/cpp_extensions/activation.py
+++ b/transformer_engine/jax/cpp_extensions/activation.py
@@ -64,6 +64,7 @@ class ClampedSwigluParams:

     limit: float = 7.0
     alpha: float = 1.702
+    glu_linear_offset: float = 1.0

     def __hash__(self):
         """Custom hash function to ensure dataclass is hashable for jax jit to work.
@@ -71,7 +72,7 @@ def __hash__(self):
         Returns:
             int: Hash value of the dataclass instance.
         """
-        return hash((self.limit, self.alpha))
+        return hash((self.limit, self.alpha, self.glu_linear_offset))

     def to_ffi_lowering_dict(self):
         """Convert the activation parameters to a dictionary format for FFI lowering.
@@ -80,7 +81,11 @@ def to_ffi_lowering_dict(self):
            dict: A dictionary representation of the activation parameters
            consumable by XLA FFI bindings for activation functions.
        """
-        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}
+        return {
+            "limit": np.float32(self.limit),
+            "alpha": np.float32(self.alpha),
+            "glu_linear_offset": np.float32(self.glu_linear_offset),
+        }


 @dataclass(frozen=True)
@@ -121,11 +126,9 @@ def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
     if fn_or_string == "linear":
         return lambda x: x
     if fn_or_string == "clamped_linear":
-        # This function is used for ClampedSwiGLU
-        # used in GPT OSS where the gates are not only clamped
-        # but also shifted by +1
         limit = act_params.clamped_swiglu.limit
-        return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
+        offset = act_params.clamped_swiglu.glu_linear_offset
+        return lambda x: jnp.clip(x, min=-limit, max=limit) + offset
     if fn_or_string == "quick_gelu":
         return lambda x: jax.nn.sigmoid(1.702 * x) * x
     if fn_or_string == "squared_relu":
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index 2ecfedc8a2..416b18ada0 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -39,6 +39,7 @@ namespace jax {
 struct ClampedSwigluConfig {
   float limit;
   float alpha;
+  float glu_linear_offset;
 };

 struct ActivationConfig {
@@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);

 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
                                       ::xla::ffi::StructMember<float>("limit"),
-                                      ::xla::ffi::StructMember<float>("alpha"));
+                                      ::xla::ffi::StructMember<float>("alpha"),
+                                      ::xla::ffi::StructMember<float>("glu_linear_offset"));

 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
     transformer_engine::jax::ActivationConfig,
diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp
index ce5828d6f3..0b5e0c9566 100644
--- a/transformer_engine/jax/csrc/extensions/activation.cpp
+++ b/transformer_engine/jax/csrc/extensions/activation.cpp
@@ -23,6 +23,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
   // parameters for clamped swiglu used in GPT OSS
   auto swiglu_limit = act_params.clamped_swiglu.limit;
   auto swiglu_alpha = act_params.clamped_swiglu.alpha;
+  auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;

   auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
   auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -138,7 +139,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
       break;
     case NVTE_Activation_Type::CLAMPED_SWIGLU:
       nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
-                          stream);
+                          swiglu_glu_linear_offset, stream);
       break;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
@@ -271,6 +272,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
   // parameters for clamped swiglu used in GPT OSS
   auto swiglu_limit = act_params.clamped_swiglu.limit;
   auto swiglu_alpha = act_params.clamped_swiglu.alpha;
+  auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;

   auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
   auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -447,7 +449,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
       break;
     case NVTE_Activation_Type::CLAMPED_SWIGLU:
       nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
-                           swiglu_limit, swiglu_alpha, stream);
+                           swiglu_limit, swiglu_alpha, swiglu_glu_linear_offset, stream);
       break;
     default:
       NVTE_ERROR("Unsupported ActivationEnum");
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 4a2ea7412b..93faee559d 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -274,10 +274,11 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);

 py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
+py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha,
+                          float glu_linear_offset);

 py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
-                           float limit, float alpha);
+                           float limit, float alpha, float glu_linear_offset);

 /***************************************************************************************************
  * LayerNorm
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp
index 2df3b66553..f66ca223bb 100644
--- a/transformer_engine/pytorch/csrc/extensions/activation.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -328,13 +328,16 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q
 }

 /* clamped functions */
-py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
-  return activation_helper<nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
+py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha,
+                          float glu_linear_offset) {
+  return activation_helper<nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha,
+                                                glu_linear_offset);
 }

 py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
-                           float limit, float alpha) {
-  return dactivation_helper<nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
+                           float limit, float alpha, float glu_linear_offset) {
+  return dactivation_helper<nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha,
+                                                  glu_linear_offset);
 }

 }  // namespace pytorch
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index eb7576d905..e8cced726f 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -183,7 +183,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"));
   m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
         "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
-        py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
+        py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f, py::arg("glu_linear_offset") = 1.0f);
   /* Backward of GLU */
   m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
@@ -212,7 +212,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("fwd_input"), py::arg("quantizer"));
   m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
         "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
-        py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
+        py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f,
+        py::arg("glu_linear_offset") = 1.0f);
   /* DBias + DAct fusions*/
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
         py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 4fa7eb2856..427858a142 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1798,7 +1798,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
     activation_params : dict, default = None
                        Additional parameters for the activation function.
                        At the moment, only used for ``'clamped_swiglu'`` activation which
-                       supports ``'limit'`` and ``'alpha'`` parameters.
+                       supports ``'limit'``, ``'alpha'``, and ``'glu_linear_offset'`` parameters.
     init_method : Callable, default = None
                  used for initializing FC1 weights in the following way: ``init_method(weight)``.
                  When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
@@ -2451,17 +2451,16 @@ def onnx_forward(
         fc1_out = fc1_out.to(torch.float32)  # activation is computed in fp32

         act_params = self.activation_params or {}
-        # Default params for clamped_swiglu in Transformer Engine
-        clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get(
-            "alpha", 1.702
-        )
+        clamped_swiglu_limit = act_params.get("limit", 7.0)
+        clamped_swiglu_alpha = act_params.get("alpha", 1.702)
+        clamped_swiglu_offset = act_params.get("glu_linear_offset", 1.0)

-        def _clamped_swiglu(x, limit, alpha):
+        def _clamped_swiglu(x, limit, alpha, offset):
             x_glu, x_linear = x.chunk(2, dim=-1)
             x_glu = x_glu.clamp(min=None, max=limit)
             x_linear = x_linear.clamp(min=-limit, max=limit)
             out_glu = x_glu * torch.sigmoid(alpha * x_glu)
-            y = out_glu * (x_linear + 1)
+            y = out_glu * (x_linear + offset)
             return y

         activation_map = {
@@ -2479,7 +2478,7 @@ def _clamped_swiglu(x, limit, alpha):
             "silu": torch.nn.functional.silu,
             "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
             "clamped_swiglu": lambda x: _clamped_swiglu(
-                x, clamped_swiglu_limit, clamped_swiglu_alpha
+                x, clamped_swiglu_limit, clamped_swiglu_alpha, clamped_swiglu_offset
             ),
         }
         if self.activation not in activation_map:
diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py
index 9325d87ae7..cf73cd7154 100644
--- a/transformer_engine/pytorch/ops/_common.py
+++ b/transformer_engine/pytorch/ops/_common.py
@@ -258,6 +258,7 @@ def fuse_grouped_mlp_ops(
             matches_pattern = False
         elif isinstance(window[1], ScaledClampedQGeGLU) and (
             abs(window[1]._clamped.alpha - 1.702) > 0.001
+            or abs(window[1]._clamped.glu_linear_offset - 1.0) > 0.001
         ):
             matches_pattern = False
         elif window[0].num_groups != window[2].num_groups:
diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py
index 9c0bc86bc1..19080f769e 100644
--- a/transformer_engine/pytorch/ops/basic/swiglu.py
+++ b/transformer_engine/pytorch/ops/basic/swiglu.py
@@ -208,6 +208,9 @@ class ClampedSwiGLU(BasicOperation):
         The clamp limit.
     alpha : float
         The scaling factor for the sigmoid function used in the activation.
+    glu_linear_offset : float
+        Offset added to the linear (gate) component after clamping.
+        Set to ``0.0`` to disable the offset.
     cache_quantized_input : bool, default = ``False``
         Quantize input tensor when caching for use in the backward pass.
     glu_interleave_size : int, optional
@@ -222,12 +225,14 @@ def __init__(
         self,
         *,
         limit: float = 7.0,
         alpha: float = 1.702,
+        glu_linear_offset: float = 1.0,
         cache_quantized_input: bool = False,
         glu_interleave_size: Optional[int] = None,
     ):
         super().__init__()
         self.limit: float = limit
         self.alpha: float = alpha
+        self.glu_linear_offset: float = glu_linear_offset
         self.cache_quantized_input: bool = cache_quantized_input
         self.glu_interleave_size: Optional[int] = glu_interleave_size
@@ -236,12 +241,13 @@ def _tex_clamped_swiglu_forward(
         swiglu_in: torch.Tensor,
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
-        """Call :func:`tex.clamped_swiglu` with this op's ``limit`` / ``alpha``."""
+        """Call :func:`tex.clamped_swiglu` with this op's ``limit`` / ``alpha`` / ``glu_linear_offset``."""
         return tex.clamped_swiglu(
             swiglu_in,
             next_op_input_quantizer,
             self.limit,
             self.alpha,
+            self.glu_linear_offset,
         )

     def _tex_clamped_dswiglu(
@@ -250,13 +256,14 @@ def _tex_clamped_dswiglu(
         dy: torch.Tensor,
         swiglu_in: torch.Tensor,
         quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
-        """Call :func:`tex.clamped_dswiglu` with this op's ``limit`` / ``alpha``."""
+        """Call :func:`tex.clamped_dswiglu` with this op's ``limit`` / ``alpha`` / ``glu_linear_offset``."""
         return tex.clamped_dswiglu(
             dy,
             swiglu_in,
             quantizer,
             self.limit,
             self.alpha,
+            self.glu_linear_offset,
         )

     def op_forward(
@@ -557,6 +564,9 @@ class ScaledClampedQGeGLU(_ScaledGLU):
         Clamp limit (see :class:`ClampedSwiGLU`).
     alpha : float, default ``1.702``
         Sigmoid scale (see :class:`ClampedSwiGLU`).
+    glu_linear_offset : float, default ``1.0``
+        Offset added to the linear component after clamping
+        (see :class:`ClampedSwiGLU`).
     """

@@ -566,11 +576,13 @@ def __init__(
         self,
         *,
         limit: float = 7.0,
         alpha: float = 1.702,
+        glu_linear_offset: float = 1.0,
     ) -> None:
         super().__init__(glu_interleave_size)
         self._clamped: ClampedSwiGLU = ClampedSwiGLU(
             limit=limit,
             alpha=alpha,
+            glu_linear_offset=glu_linear_offset,
         )

     def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 4b96ccf739..d377e5f3b3 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -189,8 +189,9 @@ class TransformerLayer(torch.nn.Module):
     activation_params : Optional[dict], default = None
                         Additional parameters for the activation function.
                         At the moment, only used for ``'clamped_swiglu'`` activation which
-                        supports ``'limit'`` and ``'alpha'`` parameters. You can set these as
-                        ``activation_params={'limit': 7.0, 'alpha': 1.702}``.
+                        supports ``'limit'``, ``'alpha'``, and ``'glu_linear_offset'`` parameters.
+                        You can set these as
+                        ``activation_params={'limit': 7.0, 'alpha': 1.702, 'glu_linear_offset': 1.0}``.
     device : Union[torch.device, str], default = "cuda"
              The device on which the parameters of the model will be allocated. It is the user's
              responsibility to ensure all parameters are moved to the GPU before running the
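
For readers tracking the math: every path touched above (CUDA kernels, JAX FFI, PyTorch bindings) computes the same activation, with the previously hard-coded `+ 1` on the clamped linear branch generalized to `glu_linear_offset`. Below is a minimal PyTorch sketch of the reference computation mirrored by the updated tests; `clamped_swiglu_ref` is a hypothetical helper name, not part of this diff:

```python
import torch
import transformer_engine.pytorch.ops as te_ops


def clamped_swiglu_ref(
    x: torch.Tensor,
    limit: float = 7.0,
    alpha: float = 1.702,
    glu_linear_offset: float = 1.0,
) -> torch.Tensor:
    """Reference clamped SwiGLU, matching y_ref in the tests above."""
    x_glu, x_linear = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(min=None, max=limit)          # clamp activation branch from above
    x_linear = x_linear.clamp(min=-limit, max=limit)  # clamp linear branch on both sides
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + glu_linear_offset)   # offset replaces the fixed "+ 1"


# Fusible-op equivalent; glu_linear_offset=1.0 reproduces the old GPT-OSS
# behavior, while 0.0 disables the shift entirely.
op = te_ops.ClampedSwiGLU(limit=7.0, alpha=1.702, glu_linear_offset=0.0)
```

Since the new parameter defaults to `1.0` in the C API documentation, the pybind layer, the JAX dataclass, and the fusible ops, existing callers see no behavior change.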