From 2f4c6b814625bd9303ac24642e126f8574fe46f8 Mon Sep 17 00:00:00 2001
From: Lifu Zhang
Date: Fri, 6 Feb 2026 10:33:10 -0800
Subject: [PATCH 1/2] Fix on TE to support Mcore Vision Encoder CUDA Graph

Signed-off-by: Lifu Zhang
---
 transformer_engine/pytorch/graph.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f587ca9946..80b5f07a2f 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -452,9 +452,9 @@ def hook_fn(
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
                     with _none_grad_context_wrapper(inputs):
                         torch.autograd.backward(
-                            tuple(o for o in outputs if o.requires_grad),
+                            tuple(o for o in outputs if o is not None and o.requires_grad),
                             grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o.requires_grad
+                                torch.empty_like(o) for o in outputs if o is not None and o.requires_grad
                             ),
                         )
                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +616,19 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                            (o.shape, o.dtype, o.layout) for o in static_outputs if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                         else:
                             static_grad_outputs = tuple(
-                                torch.empty_like(o) if o.requires_grad else None
+                                torch.empty_like(o) if o is not None and o.requires_grad else None
                                 for o in static_outputs
                             )
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +636,7 @@ def hook_fn(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o.requires_grad),
+                                tuple(o for o in static_outputs if o is not None and o.requires_grad),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
@@ -719,7 +719,7 @@ def hook_fn(
     ):
         # For now, assumes all static_outputs require grad
         static_grad_outputs = tuple(
-            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
         )
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +727,7 @@ def hook_fn(
                 bwd_graph, pool=mempool
             ):
                 torch.autograd.backward(
-                    tuple(o for o in static_outputs if o.requires_grad),
+                    tuple(o for o in static_outputs if o is not None and o.requires_grad),
                     grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
@@ -794,7 +794,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
             # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)
 
         @staticmethod
         @torch.autograd.function.once_differentiable

From 503a64458641f1a0c033d8de77e3bbf8f628f8d8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 6 Feb 2026 18:39:18 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/graph.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 80b5f07a2f..cc7aade361 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -454,7 +454,9 @@ def hook_fn(
                         torch.autograd.backward(
                             tuple(o for o in outputs if o is not None and o.requires_grad),
                             grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o is not None and o.requires_grad
+                                torch.empty_like(o)
+                                for o in outputs
+                                if o is not None and o.requires_grad
                             ),
                         )
                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,7 +618,9 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o is not None and o.requires_grad
+                            (o.shape, o.dtype, o.layout)
+                            for o in static_outputs
+                            if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
@@ -628,7 +632,8 @@ def hook_fn(
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None
+                            for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +641,9 @@ def hook_fn(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o is not None and o.requires_grad),
+                                tuple(
+                                    o for o in static_outputs if o is not None and o.requires_grad
+                                ),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
@@ -719,7 +726,8 @@ def hook_fn(
     ):
         # For now, assumes all static_outputs require grad
        static_grad_outputs = tuple(
-            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
+            torch.empty_like(o) if o is not None and o.requires_grad else None
+            for o in static_outputs
         )
         if is_training:
             inputs = tuple(i for i in static_input_surface if i.requires_grad)
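A minimal standalone sketch of the situation these "o is not None" checks handle (the tensor shapes and the optional second output below are assumptions for illustration, not code from graph.py): when a graphed callable such as a vision encoder returns a tuple with a None entry, evaluating o.requires_grad on that entry raises AttributeError, so None outputs must be skipped before building grad_tensors and calling torch.autograd.backward.

import torch

# Hypothetical callable output: second element is absent (None) in this
# configuration, e.g. an auxiliary head that is disabled.
x = torch.randn(4, 8, requires_grad=True)
outputs = (x * 2, None)

# Same filter as in the patch: drop None entries and tensors without grad
# before constructing matching grad_tensors.
needs_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
torch.autograd.backward(
    needs_grad,
    grad_tensors=tuple(torch.empty_like(o) for o in needs_grad),
)
assert x.grad is not None  # gradient flowed despite the None output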