10 changes: 8 additions & 2 deletions tests/pytorch/test_fusible_ops.py
@@ -1809,6 +1809,7 @@ def test_interleaved_swiglu(self):
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_clamped_swiglu(
self,
*,
@@ -1819,6 +1820,7 @@ def test_clamped_swiglu(
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
glu_linear_offset: float,
limit: float = 0.75,
alpha: float = 1.702,
):
@@ -1861,7 +1863,7 @@ def test_clamped_swiglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y_ref = out_glu * (x_linear + 1)
y_ref = out_glu * (x_linear + glu_linear_offset)
y_ref.backward(dy_ref)

# Implementation with fusible operation
@@ -1872,6 +1874,7 @@ def test_clamped_swiglu(
te_ops.ClampedSwiGLU(
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
glu_interleave_size=glu_interleave_size,
),
te_ops.Quantize(forward=quantize_forward, backward=False),
@@ -2492,6 +2495,7 @@ def test_interleaved_scaled_swiglu(self):
@pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128)))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("scales_requires_grad", (False, True))
@pytest.mark.parametrize("glu_linear_offset", (1.0, 0.0))
def test_scaled_clamped_qgeglu(
self,
*,
@@ -2501,6 +2505,7 @@ def test_scaled_clamped_qgeglu(
device: torch.device = "cuda",
input_requires_grad: bool,
scales_requires_grad: bool,
glu_linear_offset: float,
limit: float = 7.0,
alpha: float = 1.702,
) -> None:
@@ -2545,7 +2550,7 @@ def test_scaled_clamped_qgeglu(
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
y = out_glu * (x_linear + glu_linear_offset)
y_ref = scales_ref.unsqueeze(-1) * y
if input_requires_grad or scales_requires_grad:
y_ref.backward(dy_ref)
@@ -2554,6 +2559,7 @@ def test_scaled_clamped_qgeglu(
glu_interleave_size=glu_interleave_size,
limit=limit,
alpha=alpha,
glu_linear_offset=glu_linear_offset,
)
y_test = op(x_test, scales_test)
if input_requires_grad or scales_requires_grad:
8 changes: 4 additions & 4 deletions transformer_engine/common/activation/swiglu.cu
@@ -85,18 +85,18 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
}

void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream) {
float glu_linear_offset, cudaStream_t stream) {
Reviewer comment (Collaborator):
Can we define new APIs named nvte_clamped_swiglu_v2 and nvte_clamped_dswiglu_v2
and deprecate this API here to not break backward compatibility?
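
For illustration, a minimal sketch of what the suggested `_v2` entry points could look like next to the existing prototypes; the `_v2` names and the deprecation attribute are assumptions for this sketch, not part of this PR:

```c
/* Hypothetical header sketch -- illustrative only, not part of this PR. */
void nvte_clamped_swiglu_v2(const NVTETensor input, NVTETensor output, float limit, float alpha,
                            float glu_linear_offset, cudaStream_t stream);

void nvte_clamped_dswiglu_v2(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                             float limit, float alpha, float glu_linear_offset,
                             cudaStream_t stream);

/* The existing entry points keep their original signatures and ABI, so binaries built
 * against the old header continue to work; new code is steered to the _v2 variants. */
__attribute__((deprecated("use nvte_clamped_swiglu_v2")))
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream);

__attribute__((deprecated("use nvte_clamped_dswiglu_v2")))
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                          float limit, float alpha, cudaStream_t stream);
```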

NVTE_API_CALL(nvte_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}

void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream) {
float limit, float alpha, float glu_linear_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}
5 changes: 2 additions & 3 deletions transformer_engine/common/cast/fp8/gated_fp8.cuh
@@ -169,9 +169,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}

if constexpr (IS_BWD) {
10 changes: 4 additions & 6 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
@@ -245,9 +245,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
@@ -510,9 +509,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;
gate_elt = min(max(-p.limit, gate_elt), p.limit) + p.glu_linear_offset;
}
if constexpr (IS_BWD) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
@@ -336,10 +336,11 @@ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
* It computes Act(input[N, :H]) x input[N, H:]
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Reviewer comment (Contributor) on lines +339 to 341:
P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.
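
A compatibility shim along those lines could keep the old symbols intact. This sketch assumes `_v2` entry points like the ones outlined above exist and is illustrative only:

```c
/* Hypothetical source-side shim -- not part of this PR. Assumes the _v2 entry points
 * taking the extra glu_linear_offset argument exist. */
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream) {
  /* Old callers keep the exact ABI they were compiled against and get the previous
   * behaviour, i.e. the formerly hard-coded linear offset of +1. */
  nvte_clamped_swiglu_v2(input, output, limit, alpha, /*glu_linear_offset=*/1.0f, stream);
}

void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                          float limit, float alpha, cudaStream_t stream) {
  nvte_clamped_dswiglu_v2(grad, input, output, limit, alpha, /*glu_linear_offset=*/1.0f, stream);
}
```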

void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream);
float glu_linear_offset, cudaStream_t stream);

/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
@@ -413,10 +414,11 @@ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor outp
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] limit Clipping limits for gate and pre-activation.
* \param[in] alpha Scaling factor for the sigmoid function used in the activation.
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream);
float limit, float alpha, float glu_linear_offset, cudaStream_t stream);

/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
3 changes: 2 additions & 1 deletion transformer_engine/common/util/math.h
@@ -13,7 +13,8 @@ struct Empty {};

struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f; // Default value for QuickGELU
float alpha = 1.702f; // Default value for QuickGELU
float glu_linear_offset = 1.0f; // Offset added to the linear (gate) component after clamping
};

template <typename OType, typename IType>
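
For reference, a standalone sketch of the per-element forward computation these parameters configure, mirroring the clamp-and-offset logic in the kernel diffs above and the Python reference in tests/pytorch/test_fusible_ops.py; the helper function name is illustrative:

```cpp
#include <algorithm>
#include <cmath>

// Restated from transformer_engine/common/util/math.h so the sketch is self-contained.
struct ClampedSwiGLUParam {
  float limit;
  float alpha = 1.702f;
  float glu_linear_offset = 1.0f;
};

// Clamp the pre-activation from above, the gate to [-limit, limit], then multiply
// act * sigmoid(alpha * act) by (gate + glu_linear_offset).
inline float clamped_swiglu_ref(float x_act, float x_gate, const ClampedSwiGLUParam& p) {
  const float act = std::min(x_act, p.limit);
  const float silu = act / (1.0f + std::exp(-p.alpha * act));  // act * sigmoid(alpha * act)
  const float gate = std::min(std::max(-p.limit, x_gate), p.limit) + p.glu_linear_offset;
  return silu * gate;
}
```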
8 changes: 3 additions & 5 deletions transformer_engine/common/util/vectorized_pointwise.h
@@ -434,9 +434,8 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
val2 = std::min(std::max(-limit, val2), limit) + p.glu_linear_offset;
}
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
@@ -542,10 +541,9 @@ __launch_bounds__(unary_kernel_threads) __global__
bool dgate_in = true;

if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
dgate_in = gate_in <= limit && gate_in >= -limit;
gate_in = std::min(std::max(-limit, gate_in), limit) + p.glu_linear_offset;
}

ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
15 changes: 9 additions & 6 deletions transformer_engine/jax/cpp_extensions/activation.py
@@ -64,14 +64,15 @@ class ClampedSwigluParams:

limit: float = 7.0
alpha: float = 1.702
glu_linear_offset: float = 1.0

def __hash__(self):
"""Custom hash function to ensure dataclass is hashable for jax jit to work.

Returns:
int: Hash value of the dataclass instance.
"""
return hash((self.limit, self.alpha))
return hash((self.limit, self.alpha, self.glu_linear_offset))

def to_ffi_lowering_dict(self):
"""Convert the activation parameters to a dictionary format for FFI lowering.
@@ -80,7 +81,11 @@ def to_ffi_lowering_dict(self):
dict: A dictionary representation of the activation parameters consumable by
XLA FFI bindings for activation functions.
"""
return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}
return {
"limit": np.float32(self.limit),
"alpha": np.float32(self.alpha),
"glu_linear_offset": np.float32(self.glu_linear_offset),
}


@dataclass(frozen=True)
@@ -121,11 +126,9 @@ def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "clamped_linear":
# This function is used for ClampedSwiGLU
# used in GPT OSS where the gates are not only clamped
# but also shifted by +1
limit = act_params.clamped_swiglu.limit
return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
offset = act_params.clamped_swiglu.glu_linear_offset
return lambda x: jnp.clip(x, min=-limit, max=limit) + offset
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
4 changes: 3 additions & 1 deletion transformer_engine/jax/csrc/extensions.h
@@ -39,6 +39,7 @@ namespace jax {
struct ClampedSwigluConfig {
float limit;
float alpha;
float glu_linear_offset;
};

struct ActivationConfig {
@@ -208,7 +209,8 @@ pybind11::tuple GetTopkWorkspaceSizes(int batch_size, int seq_len, int k);

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig,
::xla::ffi::StructMember<float>("limit"),
::xla::ffi::StructMember<float>("alpha"));
::xla::ffi::StructMember<float>("alpha"),
::xla::ffi::StructMember<float>("glu_linear_offset"));

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
transformer_engine::jax::ActivationConfig,
6 changes: 4 additions & 2 deletions transformer_engine/jax/csrc/extensions/activation.cpp
@@ -23,6 +23,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;

auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -138,7 +139,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha,
stream);
swiglu_glu_linear_offset, stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
@@ -271,6 +272,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// parameters for clamped swiglu used in GPT OSS
auto swiglu_limit = act_params.clamped_swiglu.limit;
auto swiglu_alpha = act_params.clamped_swiglu.alpha;
auto swiglu_glu_linear_offset = act_params.clamped_swiglu.glu_linear_offset;

auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
@@ -447,7 +449,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
break;
case NVTE_Activation_Type::CLAMPED_SWIGLU:
nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(),
swiglu_limit, swiglu_alpha, stream);
swiglu_limit, swiglu_alpha, swiglu_glu_linear_offset, stream);
break;
default:
NVTE_ERROR("Unsupported ActivationEnum");
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -274,10 +274,11 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha,
float glu_linear_offset);

py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
float limit, float alpha);
float limit, float alpha, float glu_linear_offset);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
11 changes: 7 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -328,13 +328,16 @@ py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle q
}

/* clamped functions */
py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha,
float glu_linear_offset) {
return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha,
glu_linear_offset);
}

py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
float limit, float alpha) {
return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
float limit, float alpha, float glu_linear_offset) {
return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha,
glu_linear_offset);
}

} // namespace pytorch
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -183,7 +183,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
"SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f, py::arg("glu_linear_offset") = 1.0f);
/* Backward of GLU */
m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
@@ -212,7 +212,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("fwd_input"), py::arg("quantizer"));
m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
"Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f,
py::arg("glu_linear_offset") = 1.0f);
/* DBias + DAct fusions*/
m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
15 changes: 7 additions & 8 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1798,7 +1798,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation_params : dict, default = None
Additional parameters for the activation function.
At the moment, only used for ``'clamped_swiglu'`` activation which
supports ``'limit'`` and ``'alpha'`` parameters.
supports ``'limit'``, ``'alpha'``, and ``'glu_linear_offset'`` parameters.
init_method : Callable, default = None
used for initializing FC1 weights in the following way: ``init_method(weight)``.
When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
@@ -2451,17 +2451,16 @@ def onnx_forward(

fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
act_params = self.activation_params or {}
# Default params for clamped_swiglu in Transformer Engine
clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get(
"alpha", 1.702
)
clamped_swiglu_limit = act_params.get("limit", 7.0)
clamped_swiglu_alpha = act_params.get("alpha", 1.702)
clamped_swiglu_offset = act_params.get("glu_linear_offset", 1.0)

def _clamped_swiglu(x, limit, alpha):
def _clamped_swiglu(x, limit, alpha, offset):
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
y = out_glu * (x_linear + offset)
return y

activation_map = {
@@ -2479,7 +2478,7 @@ def _clamped_swiglu(x, limit, alpha):
"silu": torch.nn.functional.silu,
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"clamped_swiglu": lambda x: _clamped_swiglu(
x, clamped_swiglu_limit, clamped_swiglu_alpha
x, clamped_swiglu_limit, clamped_swiglu_alpha, clamped_swiglu_offset
),
}
if self.activation not in activation_map:
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/_common.py
@@ -258,6 +258,7 @@ def fuse_grouped_mlp_ops(
matches_pattern = False
elif isinstance(window[1], ScaledClampedQGeGLU) and (
abs(window[1]._clamped.alpha - 1.702) > 0.001
or abs(window[1]._clamped.glu_linear_offset - 1.0) > 0.001
):
matches_pattern = False
elif window[0].num_groups != window[2].num_groups: