From 79def34966bc227cbc459efec0be1304877010de Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Date: Tue, 16 Jun 2026 11:39:05 -0700 Subject: [PATCH 1/6] Enable NVFP4 RHT amax for grouped SReLU MLP Signed-off-by: Siddhartha Raman --- tests/pytorch/test_fusible_ops.py | 46 +++++++++++++++++-- .../pytorch/ops/fused/grouped_mlp.py | 42 ++++++++++++----- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 43c7965518..abfc0f75f6 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3435,8 +3435,9 @@ def test_grouped_mlp( quantization: Optional[str], device: torch.device = "cuda", split_alignment: int = 256, + activation: str = "scaled_swiglu", ) -> None: - """GroupedLinear + ScaledSwiGLU + GroupedLinear""" + """GroupedLinear + scaled activation + GroupedLinear""" # Split sizes split_sizes = [split_alignment * (i) for i in range(group_size)] @@ -3446,13 +3447,20 @@ def test_grouped_mlp( # Make input shape in_shape = (split_sizes.sum().item(), hidden_size) out_shape = in_shape - fc1_out_features = 2 * hidden_size + if activation == "scaled_swiglu": + fc1_out_features = 2 * hidden_size + elif activation == "scaled_srelu": + fc1_out_features = hidden_size + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") # Skip invalid configurations with_quantization = quantization is not None maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if with_quantization and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if activation == "scaled_srelu" and quantization == "nvfp4_rht" and bias: + pytest.skip("NVFP4 RHT SReLU grouped MLP coverage is limited to no-bias") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -3535,8 +3543,13 @@ def test_grouped_mlp( fc1_out = torch.nn.functional.linear( x, fc1_ws_ref[group_idx], bias=fc1_bs_ref[group_idx] ) - act_in1, act_in2 = fc1_out.chunk(2, dim=-1) - act_out = torch.nn.functional.silu(act_in1) * act_in2 + if activation == "scaled_swiglu": + act_in1, act_in2 = fc1_out.chunk(2, dim=-1) + act_out = torch.nn.functional.silu(act_in1) * act_in2 + elif activation == "scaled_srelu": + act_out = torch.nn.functional.relu(fc1_out).square() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") fc2_in = act_out * probs[group_idx].unsqueeze(-1) y = torch.nn.functional.linear(fc2_in, fc2_ws_ref[group_idx]) if bias: @@ -3565,7 +3578,13 @@ def test_grouped_mlp( dtype=dtype, scale_bias=bias, ) - module = te.ops.Sequential(fc1, te_ops.ScaledSwiGLU(), fc2) + if activation == "scaled_swiglu": + activation_op = te_ops.ScaledSwiGLU() + elif activation == "scaled_srelu": + activation_op = te_ops.ScaledSReLU() + else: + raise ValueError(f"Unexpected grouped MLP activation ({activation})") + module = te.ops.Sequential(fc1, activation_op, fc2) # Copy weights with torch.no_grad(): @@ -3585,6 +3604,8 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} + if quantization == "nvfp4_rht": + tols = {"rtol": 0.25, "atol": 0.5} # Check values assert_close(y_test, y_ref, **tols) @@ -3597,6 +3618,21 @@ def test_grouped_mlp( assert_close_grads(getattr(fc2, f"bias{group_idx}"), fc2_bs_ref[group_idx], **tols) assert_close_grads(getattr(fc1, f"bias{group_idx}"), fc1_bs_ref[group_idx], **tols) + def test_grouped_mlp_nvfp4_rht_srelu( + self, + *, + device: torch.device = "cuda", + ) -> None: + """GroupedLinear + ScaledSReLU + GroupedLinear with NVFP4 RHT amax.""" + + self.test_grouped_mlp( + bias=False, + dtype=torch.bfloat16, + quantization="nvfp4_rht", + device=device, + activation="scaled_srelu", + ) + class TestCustomOps: """Test with ops that are defined externally""" diff --git a/transformer_engine/pytorch/ops/fused/grouped_mlp.py b/transformer_engine/pytorch/ops/fused/grouped_mlp.py index 39180f098e..d83311b6a5 100644 --- a/transformer_engine/pytorch/ops/fused/grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/grouped_mlp.py @@ -871,7 +871,7 @@ def fuser_forward( basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, _, fc2_op = self.basic_ops + fc1_op, activation_op, fc2_op = self.basic_ops fc1_ctx, _activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties @@ -1151,17 +1151,20 @@ def fuser_forward( fc1_alpha_tensor = alpha_tensor use_tmem_post_rht_amax = _use_tmem_post_rht_amax() - use_fc1_glu_hadamard = False + use_fc1_act_hadamard = False + use_fc1_act_hadamard_srelu = False use_nvfp4_rht_amax = ( use_nvfp4 and isinstance(fc2_input_quantizer, NVFP4Quantizer) and fc2_input_quantizer.with_rht and fc2_input_quantizer.with_post_rht_amax ) - if use_nvfp4_rht_amax and self._cudnn_act_func == "swiglu": - kernel_getter = getattr(self, "grouped_gemm_glu_hadamard_kernel", None) + activation_is_srelu = isinstance(activation_op, ScaledSReLU) + if use_nvfp4_rht_amax and (self._cudnn_act_func == "swiglu" or activation_is_srelu): + kernel_getter = getattr(self, "grouped_gemm_act_hadamard_kernel", None) if kernel_getter is not None: - use_fc1_glu_hadamard = kernel_getter() is not None + use_fc1_act_hadamard = kernel_getter() is not None + use_fc1_act_hadamard_srelu = use_fc1_act_hadamard and activation_is_srelu fc1_activation_kwargs = { "a_tensor": fc1_x_data, @@ -1178,9 +1181,11 @@ def fuser_forward( "current_stream": current_stream, "use_dynamic_sched": True, } - if self._cudnn_act_func is not None: + if use_fc1_act_hadamard_srelu: + fc1_activation_kwargs["act_func"] = "srelu" + elif self._cudnn_act_func is not None: fc1_activation_kwargs["act_func"] = self._cudnn_act_func - if use_fc1_glu_hadamard: + if use_fc1_act_hadamard: fc1_activation_kwargs["use_tmem_post_rht_amax"] = use_tmem_post_rht_amax else: fc1_activation_kwargs["norm_const_tensor"] = fc1_norm_const_tensor @@ -1234,8 +1239,8 @@ def fuser_forward( fc1_activation_kwargs["b_dtype"] = data_dtype fc1_activation_kwargs["b_major"] = "k" - if use_fc1_glu_hadamard: - fc1_kernel_out = self.grouped_gemm_glu_hadamard_kernel()(**fc1_activation_kwargs) + if use_fc1_act_hadamard: + fc1_kernel_out = self.grouped_gemm_act_hadamard_kernel()(**fc1_activation_kwargs) else: fc1_kernel_out = self.grouped_gemm_activation_kernel()(**fc1_activation_kwargs) @@ -1269,7 +1274,7 @@ def fuser_forward( fc2_in = fc2_in.view(in_shape[0], fc2_weight_shape[1]).contiguous() fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc2_input_quantizer.optimize_for_gemm = True - if use_fc1_glu_hadamard: + if use_fc1_act_hadamard: grouped_fc2_x = _group_quantize_with_amax_for_grouped_mlp( fc2_in, fc2_input_quantizer, @@ -2109,8 +2114,8 @@ def grouped_gemm_activation_kernel(cls) -> Callable: @classmethod @functools.lru_cache(maxsize=None) - def grouped_gemm_glu_hadamard_kernel(cls) -> Optional[Callable]: - """Fused grouped GEMM GLU kernel that also emits NVFP4 RHT amaxes.""" + def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]: + """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" try: from cudnn import ( grouped_gemm_glu_hadamard_wrapper_sm100, @@ -2146,6 +2151,19 @@ def grouped_gemm_activation_kernel(cls) -> Callable: return grouped_gemm_srelu_wrapper_sm100 + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]: + """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" + try: + from cudnn import ( + grouped_gemm_glu_hadamard_wrapper_sm100, + ) # pylint: disable=no-name-in-module,import-outside-toplevel + except ImportError: + return None + + return grouped_gemm_glu_hadamard_wrapper_sm100 + @classmethod @functools.lru_cache(maxsize=None) def grouped_gemm_dactivation_kernel(cls) -> Callable: From b8d2dd3a4a230d21c1e2f1cdb1966751acb75f0d Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Sundara Raman Date: Tue, 23 Jun 2026 12:29:57 -0500 Subject: [PATCH 2/6] Update tests/pytorch/test_fusible_ops.py Co-authored-by: vthumbe1503 Signed-off-by: Siddhartha Raman Sundara Raman --- tests/pytorch/test_fusible_ops.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index abfc0f75f6..f7e3a20fec 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3603,9 +3603,7 @@ def test_grouped_mlp( y_test.backward(dy_test) # Loose tols for sanity checking - tols = {"rtol": 0.125, "atol": 0.25} - if quantization == "nvfp4_rht": - tols = {"rtol": 0.25, "atol": 0.5} +tols = quantization_tols(quantization) # Check values assert_close(y_test, y_ref, **tols) From 3c1db0fd725ecd77b50ec665fda78361ec769c02 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Jun 2026 16:32:39 -0700 Subject: [PATCH 3/6] Fix indentation for tols assignment in test Signed-off-by: vthumbe1503 --- tests/pytorch/test_fusible_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index f7e3a20fec..c4e10af5a9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3603,7 +3603,7 @@ def test_grouped_mlp( y_test.backward(dy_test) # Loose tols for sanity checking -tols = quantization_tols(quantization) + tols = quantization_tols(quantization) # Check values assert_close(y_test, y_ref, **tols) From 6ce52594dab51c7f26039ea4dd95dd3d95df09d5 Mon Sep 17 00:00:00 2001 From: Siddhartha Raman Sundara Raman Date: Tue, 23 Jun 2026 20:13:14 -0700 Subject: [PATCH 4/6] Guard SReLU hadamard kernel by cuDNN frontend version Signed-off-by: Siddhartha Raman Sundara Raman --- transformer_engine/pytorch/ops/fused/grouped_mlp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fused/grouped_mlp.py b/transformer_engine/pytorch/ops/fused/grouped_mlp.py index d83311b6a5..11e97e02e3 100644 --- a/transformer_engine/pytorch/ops/fused/grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/grouped_mlp.py @@ -83,6 +83,11 @@ def _cudnn_frontend_supports_grouped_gemm_srelu() -> bool: return _cudnn_frontend_version_at_least("1.24.0") +def _cudnn_frontend_supports_grouped_gemm_srelu_hadamard() -> bool: + """Check cuDNN frontend min version for grouped GEMM SReLU hadamard kernels.""" + return _cudnn_frontend_version_at_least("1.26.0") + + def _nvidia_cudnn_frontend_supports_wgrad() -> bool: """Check cuDNN FE min version for grouped GEMM wgrad kernel.""" return _cudnn_frontend_version_supported() @@ -1160,7 +1165,10 @@ def fuser_forward( and fc2_input_quantizer.with_post_rht_amax ) activation_is_srelu = isinstance(activation_op, ScaledSReLU) - if use_nvfp4_rht_amax and (self._cudnn_act_func == "swiglu" or activation_is_srelu): + activation_supports_hadamard = self._cudnn_act_func == "swiglu" or ( + activation_is_srelu and _cudnn_frontend_supports_grouped_gemm_srelu_hadamard() + ) + if use_nvfp4_rht_amax and activation_supports_hadamard: kernel_getter = getattr(self, "grouped_gemm_act_hadamard_kernel", None) if kernel_getter is not None: use_fc1_act_hadamard = kernel_getter() is not None @@ -2155,6 +2163,9 @@ def grouped_gemm_activation_kernel(cls) -> Callable: @functools.lru_cache(maxsize=None) def grouped_gemm_act_hadamard_kernel(cls) -> Optional[Callable]: """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" + if not _cudnn_frontend_supports_grouped_gemm_srelu_hadamard(): + return None + try: from cudnn import ( grouped_gemm_glu_hadamard_wrapper_sm100, From a04aea08040a31ec2b0785fc45a381a0c9a991c5 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 23 Jun 2026 20:29:48 -0700 Subject: [PATCH 5/6] Update tolerance handling for quantization tests Set default tolerance values for quantization checks. Signed-off-by: vthumbe1503 --- tests/pytorch/test_fusible_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index c4e10af5a9..cc9a9587b2 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3603,7 +3603,7 @@ def test_grouped_mlp( y_test.backward(dy_test) # Loose tols for sanity checking - tols = quantization_tols(quantization) + tols = quantization_tols(quantization) if quantization is not None else {"rtol": 0.125, "atol": 0.25} # Check values assert_close(y_test, y_ref, **tols) From 46a6b32c769e5f2c2e6c77b60a8fa5c544d44d72 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jun 2026 03:30:45 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index cc9a9587b2..d5703b02f7 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -3603,7 +3603,11 @@ def test_grouped_mlp( y_test.backward(dy_test) # Loose tols for sanity checking - tols = quantization_tols(quantization) if quantization is not None else {"rtol": 0.125, "atol": 0.25} + tols = ( + quantization_tols(quantization) + if quantization is not None + else {"rtol": 0.125, "atol": 0.25} + ) # Check values assert_close(y_test, y_ref, **tols)