From f3bc607be9f5e77cdd3f6fbf64a3a8589d7a78cc Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Thu, 23 Apr 2026 11:49:06 +0200 Subject: [PATCH 1/8] guard fuser grad checks on non-leaf nodes Signed-off-by: CarlosGomes98 --- transformer_engine/pytorch/ops/fuser.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a3c7e1bac7..5786c12ef0 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -138,7 +138,8 @@ def forward( ) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - y.requires_grad_(idx >= fuser.first_op_requiring_backward) + if func_ctx is not None: + y.requires_grad_(idx >= fuser.first_op_requiring_backward) extra_outputs[idx] = ys # Flatten list of extra outputs @@ -190,7 +191,8 @@ def forward( for tensor in [x] + extra_outputs_flat: tensor._do_not_clear = True - x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) + if func_ctx is not None: + x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: return x, *extra_outputs_flat From e05942f279e71c53d7997a9923d9ad4f8a0236bb Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 29 Apr 2026 15:03:20 +0200 Subject: [PATCH 2/8] rely on set_output_requires_grad flag, update docstring Signed-off-by: CarlosGomes98 --- transformer_engine/pytorch/ops/fuser.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 5786c12ef0..9ce567da84 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -65,6 +65,7 @@ def forward( input_: torch.Tensor, fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], + set_output_requires_grad: bool, *params_and_extra_inputs: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -79,6 +80,8 @@ def forward( Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation + set_output_requires_grad: bool + Whether to set ``requires_grad`` flags on returned tensors *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -138,7 +141,7 @@ def forward( ) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - if func_ctx is not None: + if set_output_requires_grad: y.requires_grad_(idx >= fuser.first_op_requiring_backward) extra_outputs[idx] = ys @@ -191,7 +194,7 @@ def forward( for tensor in [x] + extra_outputs_flat: tensor._do_not_clear = True - if func_ctx is not None: + if set_output_requires_grad: x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: @@ -295,6 +298,7 @@ def backward( dx, # input_ None, # fuser None, # basic_op_kwargs + None, # set_output_requires_grad *grad_params_flat, *grad_extra_inputs_flat, ) @@ -503,20 +507,19 @@ def __call__( op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) # Fuser forward pass - if is_grad_enabled: - forward_func = _OperationFuserAutogradFunction.apply - args = [] - else: - forward_func = _OperationFuserAutogradFunction.forward - args = [None] - args += ( + args = ( input, self, basic_op_kwargs, + is_grad_enabled, *self._flat_basic_op_params, *extra_inputs, ) - return forward_func(*args) + + if is_grad_enabled: + return _OperationFuserAutogradFunction.apply(*args) + + return _OperationFuserAutogradFunction.forward(None, *args) def register_forward_fusion( From d308cb923866c87693454e821ecb89e44526a398 Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 29 Apr 2026 15:12:48 +0200 Subject: [PATCH 3/8] make code clearer Signed-off-by: CarlosGomes98 --- 3rdparty/cudnn-frontend | 2 +- transformer_engine/pytorch/ops/fuser.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 97f6cb3b88..d33027a41a 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3 +Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 9ce567da84..2cbae4bd26 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -507,19 +507,24 @@ def __call__( op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) # Fuser forward pass + # When is_grad_enabled is False, we call forward directly. + # This does not register a PyTorch autograd node, so + # no fuser backward will run. We pass set_output_requires_grad=False + # to avoid setting requires_grad on outputs in + # this path since they may be non-leaf tensors from the inner ops. args = ( input, self, basic_op_kwargs, - is_grad_enabled, + is_grad_enabled, # set_output_requires_grad *self._flat_basic_op_params, *extra_inputs, ) - if is_grad_enabled: - return _OperationFuserAutogradFunction.apply(*args) + if not is_grad_enabled: + return _OperationFuserAutogradFunction.forward(None, *args) - return _OperationFuserAutogradFunction.forward(None, *args) + return _OperationFuserAutogradFunction.apply(*args) def register_forward_fusion( From 2c2f8e99a8728c24bff61ed8eb988e69bda92ec7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:13:43 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: CarlosGomes98 --- transformer_engine/pytorch/ops/fuser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 2cbae4bd26..e8b039a2a0 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -509,14 +509,14 @@ def __call__( # Fuser forward pass # When is_grad_enabled is False, we call forward directly. # This does not register a PyTorch autograd node, so - # no fuser backward will run. We pass set_output_requires_grad=False + # no fuser backward will run. We pass set_output_requires_grad=False # to avoid setting requires_grad on outputs in # this path since they may be non-leaf tensors from the inner ops. args = ( input, self, basic_op_kwargs, - is_grad_enabled, # set_output_requires_grad + is_grad_enabled, # set_output_requires_grad *self._flat_basic_op_params, *extra_inputs, ) From bdb0be2d3802f2a7664ad436cea843de19c39b8e Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 29 Apr 2026 15:03:20 +0200 Subject: [PATCH 5/8] rely on set_output_requires_grad flag, update docstring Signed-off-by: CarlosGomes98 --- transformer_engine/pytorch/ops/fuser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index e8b039a2a0..b60dd59355 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -516,7 +516,7 @@ def __call__( input, self, basic_op_kwargs, - is_grad_enabled, # set_output_requires_grad + is_grad_enabled, *self._flat_basic_op_params, *extra_inputs, ) From a67b89c7af70579e80291b4b1c02a86640cff54c Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 29 Apr 2026 15:12:48 +0200 Subject: [PATCH 6/8] make code clearer Signed-off-by: CarlosGomes98 --- transformer_engine/pytorch/ops/fuser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index b60dd59355..2cbae4bd26 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -509,14 +509,14 @@ def __call__( # Fuser forward pass # When is_grad_enabled is False, we call forward directly. # This does not register a PyTorch autograd node, so - # no fuser backward will run. We pass set_output_requires_grad=False + # no fuser backward will run. We pass set_output_requires_grad=False # to avoid setting requires_grad on outputs in # this path since they may be non-leaf tensors from the inner ops. args = ( input, self, basic_op_kwargs, - is_grad_enabled, + is_grad_enabled, # set_output_requires_grad *self._flat_basic_op_params, *extra_inputs, ) From 3eff5431a3f95f0c3e4b19157c6de25de871de21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 13:21:27 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/fuser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 2cbae4bd26..e8b039a2a0 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -509,14 +509,14 @@ def __call__( # Fuser forward pass # When is_grad_enabled is False, we call forward directly. # This does not register a PyTorch autograd node, so - # no fuser backward will run. We pass set_output_requires_grad=False + # no fuser backward will run. We pass set_output_requires_grad=False # to avoid setting requires_grad on outputs in # this path since they may be non-leaf tensors from the inner ops. args = ( input, self, basic_op_kwargs, - is_grad_enabled, # set_output_requires_grad + is_grad_enabled, # set_output_requires_grad *self._flat_basic_op_params, *extra_inputs, ) From 07157978dd5c8a3a8ce230501559051fb891d64d Mon Sep 17 00:00:00 2001 From: CarlosGomes98 Date: Wed, 29 Apr 2026 17:02:51 +0200 Subject: [PATCH 8/8] Revert cudnn-frontend submodule bump Signed-off-by: CarlosGomes98 Made-with: Cursor --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index d33027a41a..97f6cb3b88 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit d33027a41a93af9c85f089c6364ab415fce98982 +Subproject commit 97f6cb3b88cacff507cca1280db5650a457d92b3