diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index a23de29e02..af224276b9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,7 @@ from collections.abc import Iterable import io import math +import random from typing import Optional import pytest @@ -172,6 +173,29 @@ def make_reference_and_test_tensors( return ref, test +def assert_close( + a: Optional[torch.Tensor], + b: Optional[torch.Tensor], + *, + rtol: float, + atol: float, +) -> None: + """Assert that two tensors are close.""" + if a is None and b is None: + return + assert a is not None + assert b is not None + a = a.detach() + b = b.detach() + if isinstance(a, QuantizedTensor): + a = a.dequantize() + if isinstance(b, QuantizedTensor): + b = b.dequantize() + a = a.to(dtype=torch.float64, device="cpu") + b = b.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + class TestSequentialContainer: """Tests for sequential container""" @@ -1680,6 +1704,7 @@ def test_swiglu( quantization: Optional[str], quantize_forward: bool, quantize_backward: bool, + glu_interleave_size: Optional[int] = None, ): # Tensor dimensions @@ -1706,7 +1731,17 @@ def test_swiglu( ) # Plain PyTorch implementation - x1, x2 = x_ref.chunk(2, dim=-1) + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + *in_shape[:-1], + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(-3, -2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 y_ref.backward(dy_ref) @@ -1714,7 +1749,7 @@ def test_swiglu( recipe = make_recipe(quantization) forward = te_ops.Sequential( te_ops.Quantize(forward=False, backward=quantize_backward), - te_ops.SwiGLU(), + te_ops.SwiGLU(glu_interleave_size=glu_interleave_size), te_ops.Quantize(forward=quantize_forward, backward=False), ) with te.autocast(enabled=quantized_compute, recipe=recipe): @@ -1727,10 +1762,18 @@ def test_swiglu( tols = quantization_tols(quantization) # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + + def test_interleaved_swiglu(self): + self.test_swiglu( + out_shape=(32, 192), + dtype=torch.float32, + quantization=None, + quantize_forward=False, + quantize_backward=False, + glu_interleave_size=32, + ) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1924,6 +1967,231 @@ def test_dropout( abs(z_score) < 2.5758 ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("weight_requires_grad", (False, True)) + def test_grouped_linear( + self, + *, + group_size: int = 4, + bias: bool, + weight_shape: tuple[int, int] = (128, 128), + split_alignment: int = 128, + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_compute: bool, + 
quantized_weight: bool, + input_requires_grad: bool, + weight_requires_grad: bool, + ) -> None: + """Grouped GEMM""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = (split_sizes.sum().item(), in_features) + out_shape = (in_shape[0], out_features) + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not specified") + if quantization is not None and not (quantized_compute or quantized_weight): + pytest.skip("Quantization scheme is not used") + if quantization is not None and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + ws_ref, ws_test = [], [] + bs_ref, bs_test = [], [] + for _ in range(group_size): + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + b_ref, b_test = None, None + if bias: + b_ref, b_test = make_reference_and_test_tensors( + out_features, + test_dtype=dtype, + test_device=device, + requires_grad=weight_requires_grad, + ) + ws_ref.append(w_ref) + ws_test.append(w_test) + bs_ref.append(b_ref) + bs_test.append(b_test) + + # Plain PyTorch implementation + xs_ref = torch.split(x_ref, split_sizes.tolist()) + ys_ref = [] + for x, w, b in zip(xs_ref, ws_ref, bs_ref): + ys_ref.append(torch.nn.functional.linear(x, w, bias=b)) + y_ref = torch.cat(ys_ref) + if input_requires_grad or weight_requires_grad: + y_ref.backward(dy_ref) + + # Construct fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for group_idx in range(group_size): + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if bias: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + del ws_test, bs_test + for param in op.parameters(): + param.requires_grad_(requires_grad=weight_requires_grad) + + # Forward and backward pass with op + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = op(x_test, split_sizes) + if input_requires_grad or weight_requires_grad: + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, 
**tols) + else: + assert x_test.grad is None + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") + if weight_requires_grad: + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + else: + assert w_test.grad is None + if bias: + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + + @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + @pytest.mark.parametrize("input_requires_grad", (False, True)) + @pytest.mark.parametrize("scales_requires_grad", (False, True)) + def test_scaled_swiglu( + self, + *, + in_shape: Iterable[int], + glu_interleave_size: Optional[int], + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + input_requires_grad: bool, + scales_requires_grad: bool, + ) -> None: + """Multiply two tensors""" + + # Tensor dims + out_shape = list(in_shape) + out_shape[-1] //= 2 + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=input_requires_grad, + ) + scales_ref, scales_test = make_reference_and_test_tensors( + in_shape[:-1], + test_dtype=dtype, + test_device=device, + requires_grad=scales_requires_grad, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x = x_ref + if glu_interleave_size is not None: + x = x.reshape( + -1, + in_shape[-1] // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(in_shape) + x1, x2 = x.chunk(2, dim=-1) + y = torch.nn.functional.silu(x1) * x2 + y_ref = scales_ref.unsqueeze(-1) * y + if input_requires_grad or scales_requires_grad: + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size) + y_test = op(x_test, scales_test) + if input_requires_grad or scales_requires_grad: + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + if input_requires_grad: + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + else: + assert x_test.grad is None + if scales_requires_grad: + ds_test = scales_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(ds_test, scales_ref.grad, **tols) + else: + assert scales_test.grad is None + class TestFusedOps: """Tests for fused operations""" @@ -2931,6 +3199,192 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("glu_interleave_size", (None, 32)) + def test_grouped_mlp( + self, + *, + group_size: int = 4, + bias: bool, + hidden_size: int = 256, + dtype: torch.dtype, + quantization: Optional[str], + device: torch.device = "cuda", + split_alignment: int = 256, + 
glu_interleave_size: Optional[int], + ) -> None: + """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + + # Split sizes + split_sizes = [split_alignment * i for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + + # Make input shape + in_shape = (split_sizes.sum().item(), hidden_size) + out_shape = in_shape + + # Skip invalid configurations + with_quantization = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + if with_quantization and dtype not in (torch.bfloat16, torch.float16): + pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + probs_ref, probs_test = make_reference_and_test_tensors( + (in_shape[0],), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref, fc1_ws_test = [], [] + fc1_bs_ref, fc1_bs_test = [], [] + fc2_ws_ref, fc2_ws_test = [], [] + fc2_bs_ref, fc2_bs_test = [], [] + for _ in range(group_size): + fc1_w_ref, fc1_w_test = make_reference_and_test_tensors( + (2 * hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( + (hidden_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + fc1_b_ref, fc1_b_test = None, None + fc2_b_ref, fc2_b_test = None, None + if bias: + fc1_b_ref, fc1_b_test = make_reference_and_test_tensors( + (2 * hidden_size,), + test_dtype=dtype, + test_device=device, + ) + fc2_b_ref, fc2_b_test = make_reference_and_test_tensors( + (hidden_size,), + test_dtype=dtype, + test_device=device, + ) + fc1_ws_ref.append(fc1_w_ref) + fc1_bs_ref.append(fc1_b_ref) + fc1_ws_test.append(fc1_w_test) + fc1_bs_test.append(fc1_b_test) + fc2_ws_ref.append(fc2_w_ref) + fc2_bs_ref.append(fc2_b_ref) + fc2_ws_test.append(fc2_w_test) + fc2_bs_test.append(fc2_b_test) + with torch.no_grad(): + for t in fc1_ws_ref + fc1_ws_test + fc2_ws_ref + fc2_ws_test: + t -= 0.5 + t *= 1 / 2 + for t in (x_ref, x_test, dy_ref, dy_test): + t -= 0.5 + t *= 1 / 2 + if bias: + for t in fc1_bs_ref + fc1_bs_test + fc2_bs_ref + fc2_bs_test: + t -= 0.5 + + # Reference implementation + xs = torch.split(x_ref, split_sizes.tolist()) + probs = torch.split(probs_ref, split_sizes.tolist()) + ys = [] + for group_idx in range(group_size): + x = xs[group_idx] + x = torch.nn.functional.linear(x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx]) + if glu_interleave_size is not None: + x = x.reshape( + -1, + 2 * hidden_size // (2 * glu_interleave_size), + 2, + glu_interleave_size, + ) + x = x.transpose(1, 2) + x = x.reshape(-1, 2 * hidden_size) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + x = x * probs[group_idx].unsqueeze(-1) + x = torch.nn.functional.linear(x, fc2_ws_ref[group_idx], bias=fc2_bs_ref[group_idx]) + ys.append(x) + y_ref = torch.cat(ys) + y_ref.backward(dy_ref) + + # Construct operations + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=with_quantization, recipe=recipe): + fc1 = te_ops.GroupedLinear( + group_size, + hidden_size, + 2 * hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + fc2 = te_ops.GroupedLinear( + 
group_size, + hidden_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + ) + module = te_ops.Sequential( + fc1, + te_ops.ScaledSwiGLU(glu_interleave_size=glu_interleave_size), + fc2, + ) + + # Copy weights + with torch.no_grad(): + for group_idx in range(group_size): + getattr(fc1, f"weight{group_idx}").copy_(fc1_ws_test[group_idx]) + getattr(fc2, f"weight{group_idx}").copy_(fc2_ws_test[group_idx]) + if bias: + getattr(fc1, f"bias{group_idx}").copy_(fc1_bs_test[group_idx]) + getattr(fc2, f"bias{group_idx}").copy_(fc2_bs_test[group_idx]) + del fc1_ws_test, fc1_bs_test, fc2_ws_test, fc2_bs_test + + # Fuse ops and perform forward and backward pass + with te.autocast(enabled=with_quantization, recipe=recipe): + y_test = module(x_test, split_sizes, probs_test, split_sizes) + y_test.backward(dy_test) + + # Loose tols for sanity checking + tols = {"rtol": 0.25, "atol": 0.5} + if quantization == "nvfp4": + tols = {"rtol": 0.5, "atol": 1} + + # Check values + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + assert_close(probs_test.grad, probs_ref.grad, **tols) + for group_idx in range(group_size): + assert_close( + getattr(fc2, f"weight{group_idx}").grad, + fc2_ws_ref[group_idx].grad, + **tols, + ) + assert_close( + getattr(fc1, f"weight{group_idx}").grad, + fc1_ws_ref[group_idx].grad, + **tols, + ) + class TestCustomOps: """Test with ops that are defined externally""" diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 88b58a353a..e39a53bd38 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -125,14 +125,19 @@ def forward( return torch.cat(tensors, dim=dim) data_ptr += tensor.size(dim) * data_ptr_stride + # Out-of-place concatenation when view tensors have different storage + # Note: This works around an edge case with the split_quantize + # function, which might allocate a buffer and construct + # subviews. However, in order to reduce CPU overheads, these + # views are configured manually outside of PyTorch. PyTorch + # doesn't know these views share the same memory, and it + # blocks us from reconstructing the full tensor because it + # thinks we are accessing out-of-bounds memory. 
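As an aside, the no-copy reconstruction described in the comment above can be illustrated with a minimal standalone sketch (plain PyTorch, not part of this patch):

import torch

# Minimal sketch: two manually configured views of one buffer are re-joined
# with as_strided instead of torch.cat, so no copy is made.
buf = torch.arange(12.0)
a = buf.as_strided((2, 3), (3, 1), storage_offset=0)  # "split" covering rows 0-1
b = buf.as_strided((2, 3), (3, 1), storage_offset=6)  # "split" covering rows 2-3
full = a.as_strided((4, 3), (3, 1))  # no-op "concatenation" along dim 0
assert full.data_ptr() == buf.data_ptr()  # same memory, no copy
assert torch.equal(full, torch.cat([a, b]))  # same values as an out-of-place concat
# If the views' storage were smaller than the concatenated shape requires,
# as_strided would read out of bounds; hence the torch.cat fallback below.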
+ if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride: + return torch.cat(tensors, dim=dim) + # No-op concatenation - out = tensors[0].new() - out.set_( - tensors[0].untyped_storage(), - tensors[0].storage_offset(), - out_shape, - strides, - ) + out = tensors[0].as_strided(out_shape, strides) out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 665ffe359c..32da121cce 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -14,8 +14,6 @@ SReLU, SReGLU, SiLU, - SwiGLU, - ClampedSwiGLU, ) from .add_extra_input import AddExtraInput from .all_gather import AllGather @@ -24,6 +22,7 @@ from .bias import Bias from .constant_scale import ConstantScale from .dropout import Dropout +from .grouped_linear import GroupedLinear from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm @@ -32,3 +31,4 @@ from .reduce_scatter import ReduceScatter from .reshape import Reshape from .rmsnorm import RMSNorm +from .swiglu import ClampedSwiGLU, ScaledSwiGLU, SwiGLU diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 9d54e12dba..2f1debdf5e 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -27,8 +27,6 @@ "SReLU", "SReGLU", "SiLU", - "SwiGLU", - "ClampedSwiGLU", ] @@ -355,76 +353,3 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dsilu(*args, **kwargs) - - -class SwiGLU(_ActivationOperation): - r"""Swish gated linear unit - - The input tensor is split into chunks :math:`a` and :math:`b` - along the last dimension and the following is computed: - - .. math:: - - \text{GEGLU}(a,b) = \text{SiLU}(a) * b - - where - - .. math:: - - \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} - - .. warning:: - - Transformer Engine's gated activations and PyTorch's GLU - activation follow opposite conventions for :math:`a` and - :math:`b`. Transformer Engine applies the gating function to - the first half of the input tensor, while PyTorch applies it to - the second half. - - The Sigmoid Linear Unit (SiLU) gating function is also known as - the swish function. See - `GLU Variants Improve Transformer`__ - and `Gaussian Error Linear Units (GELUs)`__. - - """ - - def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.swiglu(*args, **kwargs) - - def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dswiglu(*args, **kwargs) - - -class ClampedSwiGLU(_ActivationOperation): - r"""GPT-OSS - Implementation based on `GPT-OSS`__. - - This activation has two differences compared to the original SwiGLU - 1. Both gate and pre-activations are clipped based on parameter limit. - 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. - - .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt - from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. - - Parameters - ---------- - limit : float - The clamp limit. - alpha : float - The scaling factor for the sigmoid function used in the activation. 
-    cache_quantized_input : bool, default = False
-        Quantize input tensor when caching for use in the backward pass.
-    """
-
-    def __init__(
-        self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
-    ):
-        super().__init__(cache_quantized_input=cache_quantized_input)
-        self.limit = limit
-        self.alpha = alpha
-
-    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
-
-    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
-        return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py
new file mode 100644
index 0000000000..b3dc460c71
--- /dev/null
+++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py
@@ -0,0 +1,700 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Fusible operation for a grouped linear layer."""
+
+from __future__ import annotations
+from collections.abc import Callable, Iterable, Sequence
+import contextlib
+import math
+from typing import Any, Optional
+
+import torch
+
+import transformer_engine_torch as tex
+from ...cpp_extensions import general_grouped_gemm
+from ...distributed import CudaRNGStatesTracker
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+)
+from ...quantization import FP8GlobalStateManager, Recipe
+from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
+from ...utils import (
+    canonicalize_device,
+    canonicalize_dtype,
+    clear_tensor_data,
+    devices_match,
+    round_up_to_nearest_multiple,
+)
+from .._common import is_quantized_tensor, maybe_dequantize
+from ..op import BasicOperation, OperationContext
+
+
+class GroupedLinear(BasicOperation):
+    r"""Apply multiple linear transformations: :math:`y_i = x_i W_i^T + b_i`
+
+    This is equivalent to splitting the input tensor along its first
+    dimension, applying a separate ``torch.nn.Linear`` to each split,
+    and concatenating along the first dimension.
+
+    Parameters
+    ----------
+    num_groups : int
+        Number of linear transformations.
+    in_features : int
+        Inner dimension of input tensor.
+    out_features : int
+        Inner dimension of output tensor.
+    bias : bool, default = ``True``
+        Apply additive bias.
+    device : torch.device, default = default CUDA device
+        Tensor device.
+    dtype : torch.dtype, default = default dtype
+        Tensor datatype.
+    rng_state_tracker_function : callable
+        Function that returns ``CudaRNGStatesTracker``, which is used
+        for model-parallel weight initialization.
+    accumulate_into_main_grad : bool, default = ``False``
+        Whether to directly accumulate weight gradients into the
+        weight's ``main_grad`` attribute instead of relying on PyTorch
+        autograd. The weight's ``main_grad`` must be set externally
+        and there is no guarantee that ``grad`` will be set or be
+        meaningful. This is primarily intended to integrate with
+        Megatron-LM. If the weight tensor also has the attribute
+        ``overwrite_main_grad`` set to ``True``, ``main_grad`` is
+        overwritten instead of accumulated into.
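To make the docstring concrete, here is a hypothetical plain-PyTorch reference for the semantics described above (a sketch for illustration; `grouped_linear_reference` is not part of the API):

import torch

def grouped_linear_reference(x, weights, biases, split_sizes):
    """Hypothetical helper: split x along dim 0, apply one linear per group, concatenate."""
    xs = torch.split(x, split_sizes)
    ys = [
        torch.nn.functional.linear(x_i, w_i, b_i)
        for x_i, w_i, b_i in zip(xs, weights, biases)
    ]
    return torch.cat(ys)

# Example: 3 groups, 8 input features, 4 output features
x = torch.randn(6, 8)
weights = [torch.randn(4, 8) for _ in range(3)]
biases = [torch.randn(4) for _ in range(3)]
y = grouped_linear_reference(x, weights, biases, split_sizes=[1, 2, 3])
assert y.shape == (6, 4)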
+ + """ + + # Operation expects input split sizes + num_extra_inputs: int = 1 + + def __init__( + self, + num_groups: int, + in_features: int, + out_features: int, + *, + bias: bool = True, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, + accumulate_into_main_grad: bool = False, + ) -> None: + super().__init__() + + # Weight tensor dimensions + self.num_groups: int = num_groups + self.in_features: int = in_features + self.out_features: int = out_features + if self.num_groups <= 0: + raise ValueError(f"Invalid number of groups ({self.num_groups})") + if self.in_features <= 0: + raise ValueError(f"Invalid input size ({self.in_features})") + if self.out_features <= 0: + raise ValueError(f"Invalid output size ({self.out_features})") + + # Weight tensor attributes + device = canonicalize_device(device) + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Initialize recipe state if needed for natively quantized weight + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() + if self._with_quantized_weight: + self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + + # RNG state tracker + self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] + self._rng_state_tracker_function = rng_state_tracker_function + + # Register weights + self.weight0: torch.nn.Parameter + for group_idx in range(self.num_groups): + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device="meta", + dtype=dtype, + ) + self.register_parameter( + f"weight{group_idx}", + torch.nn.Parameter(weight_tensor), + ) + + # Register biases + self.bias0: Optional[torch.nn.Parameter] + for group_idx in range(self.num_groups): + bias_tensor = None + if bias: + bias_tensor = torch.empty( + self.out_features, + device="meta", + dtype=dtype, + ) + bias_tensor = torch.nn.Parameter(bias_tensor) + self.register_parameter(f"bias{group_idx}", bias_tensor) + + # Initialize weights if needed + if device.type != "meta": + self.reset_parameters() + + # Whether to accumulate weight gradient into main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 * self.num_groups + if mode == "backward": + return self.num_groups + return 0 + + @property + def has_bias(self) -> bool: + """Whether an additive bias is being applied""" + return self.bias0 is not None + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Parameter device + device = self.weight0.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize weight values + # Note: Allocate a single buffer in order to support grouped + # GEMM kernels that expect a single weight buffer. 
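The single-buffer layout mentioned in the note above can be sketched as follows (illustrative shapes, not TE code); each per-group weight is a view that aliases the packed allocation, which is what lets grouped GEMM kernels treat the group as one contiguous buffer:

import torch

num_groups, out_features, in_features = 4, 16, 8
packed = torch.empty(num_groups, out_features, in_features)
weights = [packed[idx] for idx in range(num_groups)]  # views, no extra memory
assert all(
    w.untyped_storage().data_ptr() == packed.untyped_storage().data_ptr() for w in weights
)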
+ packed_weights = torch.empty( + self.num_groups, + self.out_features, + self.in_features, + dtype=self.weight0.dtype, + device=device, + ) + weights = [packed_weights[idx] for idx in range(self.num_groups)] + for weight in weights: + init_context = contextlib.nullcontext() + if self._rng_state_tracker_function is not None: + init_context = self._rng_state_tracker_function().fork() + with init_context: + torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + + # Quantize weights if needed + if self._with_quantized_weight: + + # Configure quantizers + quantizers = [ + self.get_quantizer("forward", 2 * idx + 1) for idx in range(self.num_groups) + ] + with_rowwise_usage = True + with_columnwise_usage = torch.is_grad_enabled() + for quantizer in quantizers: + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because the weight was initialized " + "within quantized_model_init, but the forward pass was not " + "performed within autocast." + ) + quantizer.set_usage( + rowwise=with_rowwise_usage, + columnwise=with_columnwise_usage, + ) + quantizer.internal = False + + # Quantize weights + weights = self._quantize_weights(weights, quantizers) + + # Register weights + for group_idx, weight in enumerate(weights): + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + setattr(self, f"weight{group_idx}", weight) + + # Initialize biases if needed + if self.bias0 is not None: + packed_biases = torch.zeros( + self.num_groups, + self.out_features, + dtype=self.bias0.dtype, + device=device, + ) + for group_idx in range(self.num_groups): + bias = torch.nn.Parameter(packed_biases[group_idx]) + setattr(self, f"bias{group_idx}", bias) + + def _quantize_weights( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[torch.Tensor]: + """Construct quantized weight tensors.""" + + # Manually construct MXFP8 weights + if isinstance(quantizers[0], MXFP8Quantizer): + return self._quantize_weights_mxfp8(weights, quantizers) + + # Use quantizers to construct quantized weights + with torch.no_grad(): + return [quantizer(weight) for quantizer, weight in zip(quantizers, weights)] + + def _quantize_weights_mxfp8( + self, + weights: Sequence[torch.Tensor], + quantizers: Sequence[Quantizer], + ) -> Sequence[MXFP8Tensor]: + """Construct MXFP8 weight tensors. + + Instead of allocating separate buffers for each weight tensor, + this function constructs large buffers and assigns subviews to + each tensor. This is intended to support grouped GEMM kernels + that expect packed buffers. 
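For reference, the padded per-group scale shapes constructed below follow this pattern (a sketch assuming one FP8 scaling factor per 32 elements along the scaled dimension, with the alignment padding used in the code; `round_up` stands in for the module's `round_up_to_nearest_multiple`):

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

out_features, in_features = 256, 1024
rowwise_scale_shape = (round_up(out_features, 128), round_up(in_features // 32, 4))
columnwise_scale_shape = (round_up(out_features // 32, 4), round_up(in_features, 128))
assert rowwise_scale_shape == (256, 32)
assert columnwise_scale_shape == (8, 1024)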
+ + """ + + # Tensor dimensions + num_groups = len(weights) + out_features, in_features = weights[0].size() + packed_shape = (num_groups, out_features, in_features) + unpacked_shape = (out_features, in_features) + + # Tensor attributes + device = weights[0].device + dtype = weights[0].dtype + requires_grad = torch.is_grad_enabled() + with_rowwise_usage = quantizers[0].rowwise_usage + with_columnwise_usage = quantizers[0].columnwise_usage + + # Construct packed buffers + rowwise_data = [None] * num_groups + rowwise_scales = [None] * num_groups + columnwise_data = [None] * num_groups + columnwise_scales = [None] * num_groups + if with_rowwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features, 128), + round_up_to_nearest_multiple(in_features // 32, 4), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + rowwise_data = [packed_data[idx] for idx in range(num_groups)] + rowwise_scales = [packed_scales[idx] for idx in range(num_groups)] + if with_columnwise_usage: + scale_shape = ( + num_groups, + round_up_to_nearest_multiple(out_features // 32, 4), + round_up_to_nearest_multiple(in_features, 128), + ) + packed_data = torch.empty(packed_shape, dtype=torch.uint8, device=device) + packed_scales = torch.empty(scale_shape, dtype=torch.uint8, device=device) + columnwise_data = [packed_data[idx] for idx in range(num_groups)] + columnwise_scales = [packed_scales[idx] for idx in range(num_groups)] + + # Construct MXFP8 tensors and cast to MXFP8 + out = [] + with torch.no_grad(): + for group_idx in range(num_groups): + weight = MXFP8Tensor( + shape=unpacked_shape, + dtype=dtype, + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise_data=rowwise_data[group_idx], + rowwise_scale_inv=rowwise_scales[group_idx], + columnwise_data=columnwise_data[group_idx], + columnwise_scale_inv=columnwise_scales[group_idx], + quantizer=quantizers[group_idx], + requires_grad=requires_grad, + with_gemm_swizzled_scales=False, + ) + weight.copy_(weights[group_idx]) + out.append(weight) + + return out + + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + + # Initialize params if needed + if any(param.device.type == "meta" for param in self.parameters()): + self.reset_parameters() + + # Check that weights are consistent + dtype = self.weight0.dtype + device = self.weight0.device + weight_requires_grad = self.weight0.requires_grad + weight_tensor_type = type(self.weight0.data) + for group_idx in range(self.num_groups): + weight = getattr(self, f"weight{group_idx}") + if weight.dtype != dtype: + raise RuntimeError( + f"Weight {group_idx} has invalid dtype (expected {dtype}, got {weight.dtype})." + ) + if not devices_match(weight.device, device): + raise RuntimeError( + f"Weight {group_idx} has invalid device " + f"(expected {device}, got {weight.device})." + ) + if weight.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Weight {group_idx} has requires_grad={weight.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck + raise RuntimeError( + f"Weight {group_idx} has invalid tensor type " + f"(expected {weight_tensor_type.__name__}, " + f"got {type(weight.data).__name__})." 
+ ) + + # Check that biases are consistent + for group_idx in range(self.num_groups): + bias = getattr(self, f"bias{group_idx}") + if self.has_bias: + if bias is None: + raise RuntimeError(f"Expected biases, but bias {group_idx} is uninitialized") + if bias.dtype != dtype: + raise RuntimeError( + f"Bias {group_idx} has invalid dtype (expected {dtype}, got {bias.dtype})." + ) + if not devices_match(bias.device, device): + raise RuntimeError( + f"Bias {group_idx} has invalid device " + f"(expected {device}, got {bias.device})." + ) + if bias.requires_grad != weight_requires_grad: + raise RuntimeError( + f"Bias {group_idx} has requires_grad={bias.requires_grad}, " + f"but expected requires_grad={weight_requires_grad}." + ) + else: + if bias is not None: + raise RuntimeError(f"Expected no biases, but bias {group_idx} is initialized") + + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = requires_grad and self.weight0.requires_grad + + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + for group_idx in range(self.num_groups): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) + + for group_idx in range(self.num_groups): + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) + + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. 
+ if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = ( + recipe.fp8_quant_bwd_grad.power_2_scale + ) + grad_output_quantizer.amax_epsilon_scales = ( + recipe.fp8_quant_bwd_grad.amax_epsilon + ) + + def op_forward(self, *args, **kwargs): + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs): + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Check which grads are required + ctx = basic_op_ctxs[0] + input_requires_grad = ctx.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight0.requires_grad + + # Quantizers + input_quantizers = [None] * num_groups + weight_quantizers = [None] * num_groups + grad_output_quantizers = [None] * num_groups + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + for group_idx in range(num_groups): + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = self.weight0.dtype + + # Extract split sizes from extra input + split_sizes = basic_op_extra_inputs[0][0] + split_sizes_int = [int(s) for s in split_sizes.tolist()] + if len(split_sizes_int) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") + + # Extract params + weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] + bs = None + if has_bias: + bs = [maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups)] + + # Convert weight dtype if needed + ws = [] + for w, quantizer in zip(weights, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + w = quantizer(w) + ws.append(w) + + # Split input tensor and convert dtypes if needed + x = maybe_dequantize(input_, dtype) + xs = None + if with_quantized_compute: + for quantizer in input_quantizers: + quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + xs = tex.split_quantize(x, split_sizes_int, input_quantizers) + 
else: + xs = torch.split(x, split_sizes_int) + + # Allocate output tensor + in_shape = list(input_.size()) + out_shape = in_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=device) + + # Perform GEMMs + general_grouped_gemm( + ws, + xs, + [out], + [None] * num_groups, # quantization_params + dtype, + m_splits=split_sizes_int, + bias=bs, + use_bias=has_bias, + use_split_accumulator=_2X_ACC_FPROP, + single_output=True, + ) + + # Prepare weight tensors for backward pass + if not input_requires_grad: + ws = [None] * num_groups + elif with_quantized_compute: + for w, weight_param in zip(ws, weights): + if w is not weight_param: + w.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Prepare input tensor for backward pass + if not weight_requires_grad: + xs = [None] * num_groups + elif with_quantized_compute: + for x in xs: + x.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Save state for backward pass + if ctx.requires_grad: + ctx.save_for_backward(split_sizes, *xs, *ws) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + num_groups = self.num_groups + has_bias = self.has_bias + device = self.weight0.device + + # Saved tensors from forward pass + ctx = basic_op_ctxs[0] + saved_tensors = ctx.saved_tensors + split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + xs, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + + # Split grad output tensor and convert dtypes if needed + split_sizes_int = [int(s) for s in split_sizes.tolist()] + dy = maybe_dequantize(grad_output, ctx.dtype) + dys = None + grad_biases = [None] * num_groups + if ctx.with_quantized_compute: + for quantizer in ctx.grad_output_quantizers: + quantizer.set_usage( + rowwise=ctx.input_requires_grad, + columnwise=ctx.weight_requires_grad, + ) + dys = tex.split_quantize(dy, split_sizes_int, ctx.grad_output_quantizers) + if has_bias: + grad_biases = [ + dy.reshape(-1, dy.size(-1)).sum(dim=0) + for dy in torch.split(grad_output, split_sizes_int) + ] + else: + dys = torch.split(dy, split_sizes_int) + if has_bias: + grad_biases = [dy.reshape(-1, dy.size(-1)).sum(dim=0) for dy in dys] + + # Initialize grad weight buffers + accumulate_into_main_grad = self._accumulate_into_main_grad + grad_weights = [None] * num_groups + if ctx.weight_requires_grad: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion + # Note: Get grad tensors from params so we can + # accumulate directly into it. 
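The main_grad fusion referenced in this note follows the Megatron-LM convention; a hedged sketch of the assumed contract (not TE or Megatron-LM code) is:

import torch

# Assumed contract: the training framework preallocates an FP32 `main_grad`
# buffer on each parameter, the wgrad GEMM accumulates directly into it, and
# autograd only ever sees a dummy gradient for the parameter.
w = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))
w.main_grad = torch.zeros(4, 4, dtype=torch.float32)  # set externally by the framework
dw = torch.randn(4, 4, dtype=torch.float32)           # stand-in for the wgrad GEMM output
w.main_grad += dw                                     # accumulate=True path
# w.grad is left unset or dummy; the optimizer reads w.main_grad instead.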
+            for group_idx in range(num_groups):
+                weight_param = getattr(self, f"weight{group_idx}")
+                if hasattr(weight_param, "__fsdp_param__"):
+                    weight_param.main_grad = weight_param.get_main_grad()
+                grad_weights[group_idx] = weight_param.main_grad
+            accumulate_into_main_grad = not getattr(self.weight0, "overwrite_main_grad", False)
+        else:
+            weight_shape = ws[0].size()
+            for group_idx in range(num_groups):
+                grad_weights[group_idx] = torch.empty(
+                    weight_shape,
+                    dtype=ctx.dtype,
+                    device=device,
+                )
+        else:
+            accumulate_into_main_grad = False
+
+        # Perform dgrad GEMMs
+        grad_input = None
+        if ctx.input_requires_grad:
+            out_shape = list(grad_output.size())
+            in_shape = out_shape[:-1] + [self.in_features]
+            grad_input = torch.empty(
+                in_shape,
+                dtype=ctx.dtype,
+                device=device,
+            )
+            general_grouped_gemm(
+                ws,
+                dys,
+                [grad_input],
+                [None] * num_groups,  # quantization_params
+                ctx.dtype,
+                layout="NN",
+                m_splits=split_sizes_int,
+                use_split_accumulator=_2X_ACC_DGRAD,
+                single_output=True,
+            )
+
+        # Perform wgrad GEMMs
+        if ctx.weight_requires_grad:
+            general_grouped_gemm(
+                xs,
+                dys,
+                grad_weights,
+                [None] * num_groups,  # quantization_params
+                ctx.dtype,
+                layout="NT",
+                m_splits=split_sizes_int,
+                use_split_accumulator=_2X_ACC_WGRAD,
+                accumulate=accumulate_into_main_grad,
+            )
+
+        # Clear input tensors if possible
+        clear_tensor_data(*xs)
+
+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weights = [None] * num_groups
+            for group_idx in range(num_groups):
+                weight_param = getattr(self, f"weight{group_idx}")
+                if hasattr(weight_param, "grad_added_to_main_grad"):
+                    weight_param.grad_added_to_main_grad = True
+                grad_weights[group_idx] = get_dummy_wgrad(
+                    list(weight_param.size()),
+                    weight_param.dtype,
+                    zero=getattr(weight_param, "zero_out_wgrad", False),
+                )
+
+        grad_params = grad_weights + grad_biases if has_bias else grad_weights
+        return grad_input, [grad_params], [(None,)]
diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py
new file mode 100644
index 0000000000..104a84f148
--- /dev/null
+++ b/transformer_engine/pytorch/ops/basic/swiglu.py
@@ -0,0 +1,415 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Fusible operations for SwiGLU activations."""
+
+from __future__ import annotations
+from collections.abc import Iterable
+from typing import Any, Optional
+
+import torch
+
+import transformer_engine_torch as tex
+from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
+from ...tensor import Float8CurrentScalingQuantizer, Quantizer
+from ...utils import clear_tensor_data
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize
+
+__all__ = ["SwiGLU", "ClampedSwiGLU", "ScaledSwiGLU"]
+
+
+class SwiGLU(BasicOperation):
+    r"""Swish gated linear unit
+
+    The input tensor is split into chunks :math:`a` and :math:`b`
+    along the last dimension and the following is computed:
+
+    .. math::
+
+       \text{SwiGLU}(a,b) = \text{SiLU}(a) * b
+
+    where
+
+    .. math::
+
+       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}
+
+    .. warning::
+
+       Transformer Engine's gated activations and PyTorch's GLU
+       activation follow opposite conventions for :math:`a` and
+       :math:`b`. Transformer Engine applies the gating function to
+       the first half of the input tensor, while PyTorch applies it to
+       the second half.
+ + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + ``GLU Variants Improve Transformer``__ + and ``Gaussian Error Linear Units (GELUs)``__. + + """ + + def __init__( + self, *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None + ): + super().__init__() + self.cache_quantized_input: bool = cache_quantized_input + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + input_ = maybe_dequantize(input_.contiguous(), dtype) + + # Remove interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Launch kernel + out = tex.swiglu(swiglu_in, next_op_input_quantizer) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer( + tex.DType.kFloat8E4M3, + input_.device, + ) + input_quantizer.set_usage(rowwise=True, columnwise=False) + input_ = input_quantizer(input_) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.save_for_backward(input_) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Remove interleaving if needed + swiglu_in = x + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Quantizer for grad input + quantizer = ctx.prev_op_grad_output_quantizer + if self.glu_interleave_size is not None: + quantizer = None + + # Launch kernel + grad_swiglu_in = tex.dswiglu(dy, swiglu_in, quantizer) + + # Apply interleaving if needed + dx = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = dx.size() + dx = dx.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + dx = dx.transpose(1, 2).contiguous() + dx = dx.view(shape) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ClampedSwiGLU(BasicOperation): + r"""GPT-OSS + Implementation based on ``GPT-OSS``__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. 
Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is different + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit : float + The clamp limit. + alpha : float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input : bool, default = ``False`` + Quantize input tensor when caching for use in the backward pass. + """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__() + self.limit: float = limit + self.alpha: float = alpha + self.cache_quantized_input: bool = cache_quantized_input + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = maybe_dequantize(input_.contiguous(), dtype) + + # Launch kernel + y = tex.clamped_swiglu( + x, + next_op_input_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Quantize input to FP8 before caching if needed + if self.cache_quantized_input: + input_quantizer = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, x.device) + input_quantizer.set_usage(rowwise=True, columnwise=False) + x = input_quantizer(x) + + # Save state for backward pass + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (input_,) = ctx.saved_tensors + + # Make sure tensors have correct dtypes + x = maybe_dequantize(input_.contiguous(), ctx.dtype) + dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype) + + # Launch kernel + dx = tex.clamped_dswiglu( + dy, + x, + ctx.prev_op_grad_output_quantizer, + limit=self.limit, + alpha=self.alpha, + ) + + # Clear input tensor if possible + clear_tensor_data(input_) + + return dx, () + + +class ScaledSwiGLU(BasicOperation): + r"""SwiGLU with post-scaling + + If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is + multiplied with an extra input tensor of shape + ``(d_1, ..., d_{n-1})``. + + """ + + # Operation expects scales + num_extra_inputs: int = 1 + + def __init__(self, glu_interleave_size: Optional[int] = None): + super().__init__() + self.glu_interleave_size: Optional[int] = glu_interleave_size + + def op_forward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. " + "It overrides `fuser_forward` instead of `op_forward`." + ) + + def op_backward(self, *args, **kwargs) -> None: + raise RuntimeError( + f"{self.__class__.__name__} operation has " + f"{self.num_extra_inputs} extra tensor inputs " + f"and {self.num_extra_outputs} extra tensor outputs. 
" + "It overrides `fuser_backward` instead of `op_backward`." + ) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + extra_input = basic_op_extra_inputs[0][0] + + # Determine compute dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + elif isinstance(input_, torch.Tensor): + dtype = input_.dtype + else: + dtype = extra_input.dtype + + # Make sure inputs are in correct dtype + input_ = maybe_dequantize(input_, dtype) + scales = maybe_dequantize(extra_input, dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute scaled SwiGLU + swiglu_out = tex.swiglu(swiglu_in, None) + out = swiglu_out * scales.unsqueeze(-1) + + # Save state for backward pass + ctx = basic_op_ctxs[0] + if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(input_) + ctx.input_requires_grad = True + ctx.extra_input_requires_grad = extra_input.requires_grad + ctx.dtype = dtype + ctx.save_for_backward( + input_, + scales if ctx.input_requires_grad else None, + ) + + return out, [()] + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + input_, scales = ctx.saved_tensors + input_ = maybe_dequantize(input_, ctx.dtype) + if scales is not None: + scales = maybe_dequantize(scales, ctx.dtype) + grad_output = maybe_dequantize(grad_output, ctx.dtype) + + # Remove gate interleaving if needed + swiglu_in = input_ + if self.glu_interleave_size is not None: + shape = swiglu_in.size() + swiglu_in = swiglu_in.reshape( + -1, + shape[-1] // (2 * self.glu_interleave_size), + 2, + self.glu_interleave_size, + ) + swiglu_in = swiglu_in.transpose(1, 2).contiguous() + swiglu_in = swiglu_in.view(shape) + + # Compute input grad + grad_input = None + if ctx.input_requires_grad: + grad_swiglu_out = grad_output * scales.unsqueeze(-1) + grad_swiglu_in = tex.dswiglu(grad_swiglu_out, swiglu_in, None) + grad_input = grad_swiglu_in + if self.glu_interleave_size is not None: + shape = grad_input.size() + grad_input = grad_input.reshape( + -1, + 2, + shape[-1] // (2 * self.glu_interleave_size), + self.glu_interleave_size, + ) + grad_input = grad_input.transpose(1, 2).contiguous() + grad_input = grad_input.view(shape) + + # Compute scales grad by recomputing SwiGLU + grad_extra_input = None + if ctx.extra_input_requires_grad: + swiglu_out = tex.swiglu(swiglu_in, None) + grad_extra_input = torch.linalg.vecdot(swiglu_out, grad_output) + + # Clear input tensor if possible + clear_tensor_data(ctx.saved_tensors[0]) # input_ + + return grad_input, [()], [(grad_extra_input,)]