diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f587ca9946..cc7aade361 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -452,9 +452,11 @@ def hook_fn(
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
                     with _none_grad_context_wrapper(inputs):
                         torch.autograd.backward(
-                            tuple(o for o in outputs if o.requires_grad),
+                            tuple(o for o in outputs if o is not None and o.requires_grad),
                             grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o.requires_grad
+                                torch.empty_like(o)
+                                for o in outputs
+                                if o is not None and o.requires_grad
                             ),
                         )
                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +618,22 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                            (o.shape, o.dtype, o.layout)
+                            for o in static_outputs
+                            if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                         else:
                             static_grad_outputs = tuple(
-                                torch.empty_like(o) if o.requires_grad else None
+                                torch.empty_like(o) if o is not None and o.requires_grad else None
                                 for o in static_outputs
                             )
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None
+                            for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +641,9 @@ def hook_fn(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o.requires_grad),
+                                tuple(
+                                    o for o in static_outputs if o is not None and o.requires_grad
+                                ),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
@@ -719,7 +726,8 @@ def hook_fn(
        ):
            # For now, assumes all static_outputs require grad
            static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
            )
            if is_training:
                inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +735,7 @@ def hook_fn(
                    bwd_graph, pool=mempool
                ):
                    torch.autograd.backward(
-                        tuple(o for o in static_outputs if o.requires_grad),
+                        tuple(o for o in static_outputs if o is not None and o.requires_grad),
                        grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                        retain_graph=retain_graph_in_backward,
                    )
@@ -794,7 +802,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
            # Replay forward graph
            fwd_graph.replay()
            assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)

        @staticmethod
        @torch.autograd.function.once_differentiable
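
Note (illustrative, not part of the patch): the added guards handle graphed callables whose flattened outputs may contain None entries, since calling `.requires_grad` or `.detach()` on None raises an AttributeError during warmup, capture, or replay. A minimal sketch of the failure mode and the fix, using a hypothetical callable `func_with_optional_output`:

    import torch

    def func_with_optional_output(x):
        # Hypothetical callable: the second output is sometimes None.
        return x * 2, None

    x = torch.randn(4, requires_grad=True)
    outputs = func_with_optional_output(x)

    # Without the None guard this raises:
    #   AttributeError: 'NoneType' object has no attribute 'requires_grad'
    # needs_grad = tuple(o for o in outputs if o.requires_grad)

    # With the guard, None entries are skipped and backward proceeds normally.
    needs_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
    torch.autograd.backward(
        needs_grad,
        grad_tensors=tuple(torch.empty_like(o) for o in needs_grad),
    )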