From 6e16c630b11fe32704bde034cbe340a352129693 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Tue, 28 Apr 2026 01:18:39 -0700 Subject: [PATCH 1/2] Fix CUDA graph parameter grad lifetime Signed-off-by: Robin Zhang --- tests/pytorch/test_cuda_graphs.py | 156 +++++++++++++++++++++++++++- transformer_engine/pytorch/graph.py | 41 +++++++- 2 files changed, 192 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index a782dadc60..e62dc9f401 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -630,11 +630,137 @@ def test_make_graphed_callables_with_kwargs( assert_all_equal(outputs, graph_outputs) +def test_make_graphed_callables_returns_owned_parameter_grads() -> None: + """Parameter grads returned from graph replay must not alias static graph buffers.""" + reset_rng_states() + model_config = model_configs["small"] + dtype = torch.float32 + model = torch.nn.Linear( + model_config.hidden_size, + model_config.hidden_size, + bias=False, + device="cuda", + dtype=dtype, + ) + model = make_graphed_callables( + model, + (generate_data(model_config, dtype, warmup=True, requires_grad=False),), + ) + + seen_grads = [] + + def save_grad(grad): + seen_grads.append(grad) + return grad + + hook = model.weight.register_hook(save_grad) + try: + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 1 + first_grad = seen_grads[0] + first_grad_ptr = first_grad.data_ptr() + first_grad_snapshot = first_grad.clone() + + model.zero_grad(set_to_none=True) + + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 2 + assert first_grad.data_ptr() == first_grad_ptr + assert seen_grads[1].data_ptr() != first_grad_ptr + torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0) + finally: + hook.remove() + reset_graphs(model) + + +def test_make_graphed_callables_accumulates_owned_parameter_grads() -> None: + """Parameter grad accumulation must not reuse overwritten static graph buffers.""" + reset_rng_states() + model_config = model_configs["small"] + dtype = torch.float32 + model = torch.nn.Linear( + model_config.hidden_size, + model_config.hidden_size, + bias=False, + device="cuda", + dtype=dtype, + ) + model = make_graphed_callables( + model, + (generate_data(model_config, dtype, warmup=True, requires_grad=False),), + ) + + input_1 = generate_data(model_config, dtype, requires_grad=False) + grad_1 = generate_data(model_config, dtype, requires_grad=False) + input_2 = generate_data(model_config, dtype, requires_grad=False) + grad_2 = generate_data(model_config, dtype, requires_grad=False) + expected_grad = torch.einsum("...o,...i->oi", grad_1, input_1) + torch.einsum( + "...o,...i->oi", grad_2, input_2 + ) + + try: + model.zero_grad(set_to_none=True) + model(input_1).backward(grad_1) + model(input_2).backward(grad_2) + torch.testing.assert_close(model.weight.grad, expected_grad, rtol=0, atol=0) + finally: + reset_graphs(model) + + +def test_make_graphed_callables_preserves_skipped_parameter_grad_alias() -> None: + """Delayed-wgrad parameters are excluded from returned-grad clone handling.""" + reset_rng_states() + model_config = model_configs["small"] + dtype = torch.float32 + model = torch.nn.Linear( + model_config.hidden_size, + model_config.hidden_size, + bias=False, + device="cuda", + dtype=dtype, + ) + model.weight.skip_backward_post_hook = True + model = make_graphed_callables( + model, + (generate_data(model_config, dtype, warmup=True, requires_grad=False),), + ) + + seen_grads = [] + + def save_grad(grad): + seen_grads.append(grad) + return grad + + hook = model.weight.register_hook(save_grad) + try: + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 1 + first_grad_ptr = seen_grads[0].data_ptr() + + model.zero_grad(set_to_none=True) + + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 2 + assert seen_grads[1].data_ptr() == first_grad_ptr + finally: + hook.remove() + reset_graphs(model) + + def _test_cuda_graphs_with_interleaved_pipeline_parallelism( *, with_graph: bool, model_config: ModelConfig, dtype: torch.dtype, + reuse_graph_input_output_buffers: bool = False, ) -> List[torch.Tensor]: """Simulate Megatron-LM interleaved pipeline parallelism.""" reset_rng_states() @@ -675,6 +801,7 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( sample_args, allow_unused_input=True, _order=layer_order, + _reuse_graph_input_output_buffers=reuse_graph_input_output_buffers, ) layer_forwards = { (i // num_microbatches, i % num_microbatches): forward @@ -701,11 +828,15 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( # Cache for layer outputs. outputs = {} + output_snapshots = {} if reuse_graph_input_output_buffers else None def forward(layer_idx: int, microbatch_idx: int): """Helper function for forward steps""" idxs = (layer_idx, microbatch_idx) outputs[idxs] = layer_forwards[idxs](inputs[idxs]) + if output_snapshots is not None: + # Reused graph output buffers are only valid until their corresponding backward. + output_snapshots[idxs] = outputs[idxs].detach().clone() def backward(layer_idx: int, microbatch_idx: int): """Helper function for backward steps""" @@ -728,8 +859,9 @@ def backward(layer_idx: int, microbatch_idx: int): # Optimizer step. optimizer.step() - outputs = [y for _, y in sorted(outputs.items())] - outputs = get_outputs(model, outputs) + output_values = output_snapshots if output_snapshots is not None else outputs + output_values = [y for _, y in sorted(output_values.items())] + outputs = get_outputs(model, output_values) if with_graph: reset_graphs(layer_forwards) return outputs @@ -752,3 +884,23 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism( **kwargs, ) assert_all_equal(outputs, graph_outputs) + + +def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers( + *, + model_config: str = "small", + dtype: torch.dtype = torch.float16, +) -> None: + """Test CUDA graphs with reused input/output buffers.""" + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) + outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=False, + **kwargs, + ) + graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + with_graph=True, + reuse_graph_input_output_buffers=True, + **kwargs, + ) + assert_all_equal(outputs, graph_outputs) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 075db1394b..23ab002514 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -407,6 +407,15 @@ def _make_graphed_callables( bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] + def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params): + """Return whether a static grad slot is consumed through Graphed.backward.""" + module_param_start = len(static_grad_inputs) - len(module_params) + if idx < module_param_start: + return False + return not getattr( + module_params[idx - module_param_start], "skip_backward_post_hook", False + ) + # For cases with multiple active RNG states, e.g. TP. if graph_safe_rng_available(): for _, state in get_all_rng_states().items(): @@ -728,6 +737,24 @@ def hook_fn( static_outputs ) + # Parameter grads are cloned before being returned from + # Graphed.backward, so their static buffers can be weak-refed now. + static_grad_inputs = per_callable_static_grad_inputs[per_callable_bwd_idx] + module_params = per_callable_module_params[per_callable_bwd_idx] + per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple( + ( + make_weak_ref(grad_input) + if ( + _is_returned_param_grad_slot( + idx, static_grad_inputs, module_params + ) + and grad_input is not None + ) + else grad_input + ) + for idx, grad_input in enumerate(static_grad_inputs) + ) + # Weak ref the static grad inputs of the previous backward pass within the # same chunk. if previous_per_callable_bwd_idx is not None: @@ -911,9 +938,17 @@ def backward(ctx, *grads): "Expected static_grad_inputs to be a tuple, but got" f" {type(static_grad_inputs).__name__}" ) - return (None, None, None) + tuple( - b.detach() if b is not None else b for b in static_grad_inputs - ) + grad_inputs = [] + for idx, grad_input in enumerate(static_grad_inputs): + if grad_input is None: + grad_inputs.append(None) + elif _is_returned_param_grad_slot(idx, static_grad_inputs, module_params): + # Returned parameter grads may be installed directly as param.grad. + # Clone to avoid exposing CUDA graph static buffers to autograd users. + grad_inputs.append(grad_input.detach().clone()) + else: + grad_inputs.append(grad_input.detach()) + return (None, None, None) + tuple(grad_inputs) def functionalized(*user_args, **user_kwargs): From 4077b8570b517c9694b8d904bb4278ac2a0d754e Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Tue, 28 Apr 2026 02:40:12 -0700 Subject: [PATCH 2/2] Address CUDA graph grad lifetime review feedback Signed-off-by: Robin Zhang --- tests/pytorch/test_cuda_graphs.py | 62 ++++++++++++++++++++++++++--- transformer_engine/pytorch/graph.py | 46 ++++++++++++++------- 2 files changed, 87 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index e62dc9f401..165cbad3dd 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -755,13 +755,60 @@ def save_grad(grad): reset_graphs(model) +def test_make_graphed_callables_snapshots_parameter_grad_clone_policy() -> None: + """Parameter grad clone policy is fixed at capture time.""" + reset_rng_states() + model_config = model_configs["small"] + dtype = torch.float32 + model = torch.nn.Linear( + model_config.hidden_size, + model_config.hidden_size, + bias=False, + device="cuda", + dtype=dtype, + ) + model = make_graphed_callables( + model, + (generate_data(model_config, dtype, warmup=True, requires_grad=False),), + ) + model.weight.skip_backward_post_hook = True + + seen_grads = [] + + def save_grad(grad): + seen_grads.append(grad) + return grad + + hook = model.weight.register_hook(save_grad) + try: + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 1 + first_grad = seen_grads[0] + first_grad_ptr = first_grad.data_ptr() + first_grad_snapshot = first_grad.clone() + + model.zero_grad(set_to_none=True) + + output = model(generate_data(model_config, dtype, requires_grad=False)) + output.backward(generate_data(model_config, dtype, requires_grad=False)) + + assert len(seen_grads) == 2 + assert seen_grads[1].data_ptr() != first_grad_ptr + torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0) + finally: + hook.remove() + reset_graphs(model) + + def _test_cuda_graphs_with_interleaved_pipeline_parallelism( *, with_graph: bool, model_config: ModelConfig, dtype: torch.dtype, reuse_graph_input_output_buffers: bool = False, -) -> List[torch.Tensor]: +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Simulate Megatron-LM interleaved pipeline parallelism.""" reset_rng_states() @@ -862,9 +909,10 @@ def backward(layer_idx: int, microbatch_idx: int): output_values = output_snapshots if output_snapshots is not None else outputs output_values = [y for _, y in sorted(output_values.items())] outputs = get_outputs(model, output_values) + final_weights = [param.detach().clone() for param in model.parameters()] if with_graph: reset_graphs(layer_forwards) - return outputs + return outputs, final_weights def test_make_graphed_callables_with_interleaved_pipeline_parallelism( @@ -875,15 +923,16 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism( """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" model_config = model_configs[model_config] kwargs = dict(model_config=model_config, dtype=dtype) - outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=False, **kwargs, ) - graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=True, **kwargs, ) assert_all_equal(outputs, graph_outputs) + assert_all_equal(weights, graph_weights) def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers( @@ -894,13 +943,14 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buf """Test CUDA graphs with reused input/output buffers.""" model_config = model_configs[model_config] kwargs = dict(model_config=model_config, dtype=dtype) - outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=False, **kwargs, ) - graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( + graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=True, reuse_graph_input_output_buffers=True, **kwargs, ) assert_all_equal(outputs, graph_outputs) + assert_all_equal(weights, graph_weights) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 23ab002514..fba2178786 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -407,13 +407,15 @@ def _make_graphed_callables( bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] - def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params): - """Return whether a static grad slot is consumed through Graphed.backward.""" + def _returned_param_grad_slots(static_grad_inputs, module_params): + """Snapshot static grad slots that are consumed through Graphed.backward.""" module_param_start = len(static_grad_inputs) - len(module_params) - if idx < module_param_start: - return False - return not getattr( - module_params[idx - module_param_start], "skip_backward_post_hook", False + return tuple( + idx >= module_param_start + and not getattr( + module_params[idx - module_param_start], "skip_backward_post_hook", False + ) + for idx in range(len(static_grad_inputs)) ) # For cases with multiple active RNG states, e.g. TP. @@ -578,6 +580,7 @@ def hook_fn( per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) per_callable_static_grad_outputs = [None] * len(flatten_sample_args) per_callable_static_grad_inputs = [None] * len(flatten_sample_args) + per_callable_returned_param_grad_slots = [None] * len(flatten_sample_args) fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks static_grad_outputs_dict = {} @@ -725,6 +728,13 @@ def hook_fn( per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs + returned_param_grad_slots = _returned_param_grad_slots( + static_grad_inputs, + per_callable_module_params[per_callable_bwd_idx], + ) + per_callable_returned_param_grad_slots[per_callable_bwd_idx] = ( + returned_param_grad_slots + ) # Weak ref the static outputs and static grad inputs that are no longer needed # in the following steps. These two type of tensors are both in cudagraph @@ -740,16 +750,10 @@ def hook_fn( # Parameter grads are cloned before being returned from # Graphed.backward, so their static buffers can be weak-refed now. static_grad_inputs = per_callable_static_grad_inputs[per_callable_bwd_idx] - module_params = per_callable_module_params[per_callable_bwd_idx] per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple( ( make_weak_ref(grad_input) - if ( - _is_returned_param_grad_slot( - idx, static_grad_inputs, module_params - ) - and grad_input is not None - ) + if returned_param_grad_slots[idx] and grad_input is not None else grad_input ) for idx, grad_input in enumerate(static_grad_inputs) @@ -796,6 +800,7 @@ def hook_fn( # Capture backward graphs in reverse order per_callable_static_grad_outputs = [] per_callable_static_grad_inputs = [] + per_callable_returned_param_grad_slots = [] for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip( reversed(per_callable_static_input_surfaces), reversed(per_callable_static_outputs), @@ -840,10 +845,19 @@ def hook_fn( per_callable_static_grad_outputs.append(static_grad_outputs) per_callable_static_grad_inputs.append(static_grad_inputs) + per_callable_returned_param_grad_slots.append( + _returned_param_grad_slots( + static_grad_inputs, + per_callable_module_params[bwd_idx], + ) + ) - # Reverses the most recent two lists + # Reverse the most recent per-callable lists. per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs)) per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) + per_callable_returned_param_grad_slots = list( + reversed(per_callable_returned_param_grad_slots) + ) # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. def make_graphed_autograd_function( @@ -857,6 +871,7 @@ def make_graphed_autograd_function( static_outputs, static_grad_outputs, static_grad_inputs, + returned_param_grad_slots, ): class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @@ -942,7 +957,7 @@ def backward(ctx, *grads): for idx, grad_input in enumerate(static_grad_inputs): if grad_input is None: grad_inputs.append(None) - elif _is_returned_param_grad_slot(idx, static_grad_inputs, module_params): + elif returned_param_grad_slots[idx]: # Returned parameter grads may be installed directly as param.grad. # Clone to avoid exposing CUDA graph static buffers to autograd users. grad_inputs.append(grad_input.detach().clone()) @@ -1043,6 +1058,7 @@ def reset(): per_callable_static_outputs[i], per_callable_static_grad_outputs[i], per_callable_static_grad_inputs[i], + per_callable_returned_param_grad_slots[i], ) func = graph_callables[i]