Commits (24 total; diff shown from 14 commits)
33db7c4
fix: fix LinearFunctionForZeroStage3 to support torch.func transforms
roycho96 Mar 21, 2026
39b1755
fix: always pass bias arg in zero3_linear_wrap to avoid setup_context…
roycho96 Mar 21, 2026
6df37af
fix: remove @autocast_custom_fwd from forward, move autocast state to…
roycho96 Mar 22, 2026
c0b9694
fix(zero3): replace custom_bwd with explicit autocast for functorch-s…
zhangj1an Mar 22, 2026
5e83d05
fix(zero): use setup_context for offload pre/post backward Functions
zhangj1an Mar 22, 2026
7483701
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 24, 2026
a1e798d
run pre-commit checks
zhangj1an Mar 25, 2026
8762d00
update unit tests to reproduce main branch error
zhangj1an Mar 25, 2026
dd037da
add reproduce scripts
zhangj1an Mar 25, 2026
01ee5a6
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 25, 2026
f69c1f1
update reproduce script
zhangj1an Mar 25, 2026
e58ac18
update reproduce script to skip repeated env setup
zhangj1an Mar 25, 2026
3121a7f
update reproduce script to remove duplicated code
zhangj1an Mar 25, 2026
60d20da
update reproduce script to print test env
zhangj1an Mar 25, 2026
bb245b2
drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear
roycho96 Mar 29, 2026
04c456f
change PyTorch version in README
roycho96 Mar 29, 2026
703aad3
resolve conflict with master
tohtana Mar 29, 2026
8468149
Merge pull request #1 from tohtana/tohtana/pr7916-merge-master-resolve
zhangj1an Mar 30, 2026
e309a6f
remove repro scripts
zhangj1an Mar 30, 2026
e425569
update unit test
zhangj1an Mar 30, 2026
39f7e3c
drop support for pytorch<2.0
zhangj1an Mar 30, 2026
39c9a73
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 31, 2026
a5aa09a
Merge branch 'master' into fix/support-func-torch
roycho96 Apr 4, 2026
10906ae
Merge branch 'master' into fix/support-func-torch
roycho96 Apr 20, 2026
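
The incompatibility these commits address: an autograd.Function whose forward() takes ctx directly is rejected by torch.func transforms, which since PyTorch 2.0 require the separate forward() + setup_context() pattern. A minimal standalone sketch of the failure and the fix (illustrative only — not the PR's actual repro script or unit test):

import torch

class ScaleLegacy(torch.autograd.Function):
    # Legacy style: forward receives ctx directly; torch.func transforms reject this.

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

class ScaleWithSetupContext(torch.autograd.Function):
    # torch.func-compatible style: forward is ctx-free, bookkeeping moves to setup_context.

    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x, ) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(3)
try:
    torch.func.grad(lambda t: ScaleLegacy.apply(t).sum())(x)
except RuntimeError as e:
    print("legacy Function rejected by torch.func:", e)

print(torch.func.grad(lambda t: ScaleWithSetupContext.apply(t).sum())(x))  # tensor([2., 2., 2.])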
187 changes: 126 additions & 61 deletions deepspeed/runtime/zero/linear.py
@@ -35,76 +35,141 @@ def print_rank_0(message, debug=False, force=False):
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())

# PyTorch >= 2.0 supports setup_context, which is required for
# torch.func transforms (vmap, grad, jvp, jacrev, etc.)
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context')

if _SUPPORTS_SETUP_CONTEXT:

    class LinearFunctionForZeroStage3(torch.autograd.Function):

        @staticmethod
        # bias is an optional argument
        def forward(input, weight, bias=None):

            if input.dim() == 2 and bias is not None:
                # fused op is marginally faster
                ret = torch.addmm(bias, input, weight.t())
            else:
                output = input.matmul(weight.t())
                if bias is not None:
                    output += bias
                ret = output

            return ret

        @staticmethod
        def setup_context(ctx, inputs, output):
            # Replicate autocast state that @autocast_custom_fwd normally sets on ctx,
            # since the decorator assumes args[0] is ctx which is unavailable in the
            # separate forward() + setup_context() pattern.
            device_type = get_accelerator().device_name()
            ctx._dtype = torch.get_autocast_dtype(device_type)
            ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
            input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
            ctx.save_for_backward(input, weight, bias)

        # This function has only a single output, so it gets only one gradient
        @staticmethod
        def backward(ctx, grad_output):
            # Do not use @autocast_custom_bwd here: it pairs with @autocast_custom_fwd on
            # legacy forward(ctx, ...). With forward + setup_context, use AMP state from setup_context.
            device_type = get_accelerator().device_name()
            if getattr(ctx, "_fwd_used_autocast", False):
                with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype):
                    return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
            return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)

        @staticmethod
        def _backward_core(ctx, grad_output):
            input, weight, bias = ctx.saved_tensors

            grad_input = grad_weight = grad_bias = None

            dim = grad_output.dim()
            if ctx.needs_input_grad[0]:
                grad_input = grad_output.matmul(weight)
            if ctx.needs_input_grad[1]:
                if dim > 2:
                    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                        input.reshape(-1, input.shape[-1]))
                else:
                    grad_weight = grad_output.t().matmul(input)
            if bias is not None and ctx.needs_input_grad[2]:
                if dim > 2:
                    grad_bias = grad_output.sum([i for i in range(dim - 1)])
                else:
                    grad_bias = grad_output.sum(0)
            return grad_input, grad_weight, grad_bias

else:

    class LinearFunctionForZeroStage3(torch.autograd.Function):

        # Note that both forward and backward are @staticmethods
        @staticmethod
        @autocast_custom_fwd
        # bias is an optional argument
        def forward(ctx, input, weight, bias=None):

            ctx.save_for_backward(input, weight, bias)

            if input.dim() == 2 and bias is not None:
                # fused op is marginally faster
                ret = torch.addmm(bias, input, weight.t())
            else:
                output = input.matmul(weight.t())
                if bias is not None:
                    output += bias
                ret = output

            return ret

        # This function has only a single output, so it gets only one gradient
        @staticmethod
        @autocast_custom_bwd
        def backward(ctx, grad_output):
            # This is a pattern that is very convenient - at the top of backward
            # unpack saved_tensors and initialize all gradients w.r.t. inputs to
            # None. Thanks to the fact that additional trailing Nones are
            # ignored, the return statement is simple even when the function has
            # optional inputs.
            input, weight, bias = ctx.saved_tensors

            grad_input = grad_weight = grad_bias = None

            #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
            # These needs_input_grad checks are optional and there only to
            # improve efficiency. If you want to make your code simpler, you can
            # skip them. Returning gradients for inputs that don't require it is
            # not an error.
            dim = grad_output.dim()
            if ctx.needs_input_grad[0]:
                #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
                grad_input = grad_output.matmul(weight)
                #print(f"Computed grad input {grad_input.shape}")
            if ctx.needs_input_grad[1]:
                #print("Computing grad weight")
                if dim > 2:
                    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                        input.reshape(-1, input.shape[-1]))
                else:
                    grad_weight = grad_output.t().matmul(input)
                #print(f"Computed grad weight grad_weight {grad_weight.shape}")
            if bias is not None and ctx.needs_input_grad[2]:
                #print("Computing grad bias")
                if dim > 2:
                    grad_bias = grad_output.sum([i for i in range(dim - 1)])
                else:
                    grad_bias = grad_output.sum(0)
                #print("Done computing grad bias")
            #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
            return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    return LinearFunctionForZeroStage3.apply(input, weight, bias)


class LinearModuleForZeroStage3(Module):
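With forward() and setup_context() split as above, the ZeRO-3 linear Function composes with torch.func transforms. A rough usage sketch — the shapes are made up for illustration, and it assumes DeepSpeed is installed so the import resolves:

import torch
from deepspeed.runtime.zero.linear import LinearFunctionForZeroStage3

x = torch.randn(4, 8)   # (batch, in_features) -- illustrative shapes
w = torch.randn(16, 8)  # (out_features, in_features)
b = torch.randn(16)

def loss_fn(weight):
    return LinearFunctionForZeroStage3.apply(x, weight, b).sum()

# The legacy forward(ctx, ...) path is rejected by torch.func; with this change
# torch.func.grad can differentiate through the custom Function.
grad_w = torch.func.grad(loss_fn)(w)
print(grad_w.shape)  # torch.Size([16, 8])
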
139 changes: 99 additions & 40 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -18,6 +18,10 @@

FWD_MODULE_STACK = list()

# PyTorch >= 2.0: setup_context on autograd.Function is required for torch.func transforms.
# Match deepspeed/runtime/zero/linear.py: keep legacy forward(ctx, ...) when unavailable.
_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")


#for each tensor in outputs run the forward_function and register backward_function as hook
def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
@@ -401,23 +405,45 @@ def _run_before_backward_function(sub_module):
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

if _SUPPORTS_SETUP_CONTEXT:

    class PreBackwardFunctionForModule(torch.autograd.Function):

        @staticmethod
        def forward(outputs):
            return outputs.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.module = module
            ctx.pre_backward_function = _run_before_backward_function
            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                ctx.module.applied_pre_backward_ref_cnt = 0
            ctx.module.applied_pre_backward_ref_cnt += 1

        @staticmethod
        def backward(ctx, *args):
            ctx.pre_backward_function(ctx.module)
            return args

else:

    class PreBackwardFunctionForModule(torch.autograd.Function):

        @staticmethod
        def forward(ctx, outputs):
            # Capture `module` and _run_before_backward_function
            ctx.module = module
            ctx.pre_backward_function = _run_before_backward_function
            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                ctx.module.applied_pre_backward_ref_cnt = 0
            ctx.module.applied_pre_backward_ref_cnt += 1
            outputs = outputs.detach()
            return outputs

        @staticmethod
        def backward(ctx, *args):
            ctx.pre_backward_function(ctx.module)
            return args

module.pre_bwd_fn = PreBackwardFunctionForModule

@@ -431,31 +457,64 @@ def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

if _SUPPORTS_SETUP_CONTEXT:

    class PostBackwardFunctionModule(torch.autograd.Function):

        @staticmethod
        def forward(output):
            return output.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            (output_in, ) = inputs
            ctx.module = module
            if output_in.requires_grad:
                #TODO SOME TIMES post backward does not seem to be triggered debug in detail
                #Should only cause increase in memory not correctness issue
                #if output.grad_fn.__class__.__name__ == 'ViewBackward':
                #    ctx.view=True
                #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
                #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
                #if module.ds_grads_remaining == 0:
                #    print(f"Before Forward: {ctx.module.__class__.__name__}")
                module.ds_grads_remaining += 1
                ctx.post_backward_function = _run_after_backward_function

        @staticmethod
        def backward(ctx, *args):
            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
            if ctx.module.ds_grads_remaining == 0:
                ctx.post_backward_function(ctx.module)
            return args

else:

    class PostBackwardFunctionModule(torch.autograd.Function):

        @staticmethod
        def forward(ctx, output):
            ctx.module = module
            if output.requires_grad:
                #TODO SOME TIMES post backward does not seem to be triggered debug in detail
                #Should only cause increase in memory not correctness issue
                #if output.grad_fn.__class__.__name__ == 'ViewBackward':
                #    ctx.view=True
                #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
                #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
                #if module.ds_grads_remaining == 0:
                #    print(f"Before Forward: {ctx.module.__class__.__name__}")
                module.ds_grads_remaining += 1
                ctx.post_backward_function = _run_after_backward_function
            output = output.detach()
            return output

        @staticmethod
        def backward(ctx, *args):
            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
            if ctx.module.ds_grads_remaining == 0:
                ctx.post_backward_function(ctx.module)
            return args

module.post_bwd_fn = PostBackwardFunctionModule
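
For reference, a self-contained sketch of the hook pattern above applied to a plain nn.Linear. The helper name and the printed message are illustrative only; the real code closes over `module` and the ZeRO offload callbacks:

import torch
import torch.nn as nn

def attach_pre_backward(module, run_before_backward):

    class PreBackwardFunction(torch.autograd.Function):
        # forward() only detaches; module bookkeeping lives in setup_context(),
        # keeping forward() free of ctx/module mutation for torch.func compatibility.

        @staticmethod
        def forward(outputs):
            return outputs.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.module = module
            if not hasattr(module, "applied_pre_backward_ref_cnt"):
                module.applied_pre_backward_ref_cnt = 0
            module.applied_pre_backward_ref_cnt += 1

        @staticmethod
        def backward(ctx, *grads):
            run_before_backward(ctx.module)
            return grads

    return PreBackwardFunction

layer = nn.Linear(4, 4)
PreBwd = attach_pre_backward(layer, lambda m: print("pre-backward hook fired"))
out = PreBwd.apply(layer(torch.randn(2, 4)))
out.sum().backward()                        # prints before gradients reach `layer`
print(layer.applied_pre_backward_ref_cnt)   # 1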
