Skip to content
97 changes: 97 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2559,6 +2559,103 @@ def no_sync(self):
finally:
self.inside_no_sync_ctxt = False

@contextmanager
def coalesce_grad_reduction(self):
    r"""Coalesce ZeRO 1/2/3 gradient reduction across multiple engine.backward() calls.

    One with-block == one optimizer step: every backward inside
    leaves grads locally on params, and the flush on exit issues a single
    reduction pass that populates averaged_gradients for the next step().

    Constraints:
      - engine.step() inside the block raises.
      - Reentry / nesting with engine.no_sync() raises.
      - Do not span multiple gradient_accumulation_steps with multiple
        with-blocks; the flush overwrites averaged_gradients each exit.

    Unsupported (NotImplementedError): ZeRO stage 0, BF16/FP16_Optimizer
    wrappers, PipelineModule.

    Raises:
        NotImplementedError: if the ZeRO stage is not 1/2/3, if pipeline
            parallelism is enabled, or if the optimizer does not expose a
            ``_coalesce_grad_reduction`` attribute.
        AssertionError: if ``inside_no_sync_ctxt`` is already set, i.e. the
            context is entered while another no-sync style context is active.
    """
    stage = self.zero_optimization_stage()
    if stage not in (ZeroStageEnum.optimizer_states, ZeroStageEnum.gradients, ZeroStageEnum.weights):
        raise NotImplementedError(f"coalesce_grad_reduction requires ZeRO stage 1/2/3, got stage {int(stage)}")
    if self.pipeline_parallelism:
        raise NotImplementedError("coalesce_grad_reduction is not supported under pipeline parallelism")
    optimizer = self.optimizer
    if not hasattr(optimizer, "_coalesce_grad_reduction"):
        # BF16_Optimizer / FP16_Optimizer route grads through their own
        # backward_epilogue path, bypassing DeepSpeedZeroOptimizer's
        # per-param hooks that this context relies on.
        raise NotImplementedError(
            f"coalesce_grad_reduction does not yet support optimizer wrapper {type(optimizer).__name__}")
    if self.inside_no_sync_ctxt:
        # Raised explicitly instead of via an `assert` statement so the guard
        # is still enforced under `python -O`. Without it, the exit path below
        # would reset inside_no_sync_ctxt to False underneath the outer
        # context. The exception type is kept as AssertionError so existing
        # callers that catch it keep working.
        raise AssertionError("coalesce_grad_reduction cannot be nested inside another no_sync context")

    # Engine boundary is the source of truth; optimizer's copy is overwritten
    # by _backward_prologue from the engine value on each backward, so we
    # only need to save/restore the engine flag.
    saved_engine_boundary = self._is_gradient_accumulation_boundary
    self.inside_no_sync_ctxt = True
    optimizer._coalesce_grad_reduction = True
    try:
        yield
    finally:
        # NOTE(review): because this is a `finally`, the flush below also runs
        # when the with-body raised. Confirm that reducing partially
        # accumulated grads during exception unwinding is intended.
        #
        # Reset _coalesce_grad_reduction BEFORE the flush so the reducer calls
        # we drive in the flush helpers do NOT short-circuit at our guard
        # in process_gradients / reduce_ready_partitions_and_remove_grads.
        optimizer._coalesce_grad_reduction = False
        self.inside_no_sync_ctxt = False
        # Force "boundary" on both the engine and the optimizer for the
        # duration of the flush.
        self._is_gradient_accumulation_boundary = True
        optimizer.is_gradient_accumulation_boundary = True
        try:
            # Drive a single reduction pass over locally accumulated grads.
            # Iterate explicitly (rather than calling reduce_gradients) so
            # the path works regardless of overlap_comm / contiguous_gradients,
            # both of which alter reduce_gradients's control flow.
            if stage == ZeroStageEnum.weights:
                self._flush_coalesced_reduction_zero3(optimizer)
            else:
                self._flush_coalesced_reduction_zero12(optimizer)
        finally:
            # Restore the engine flag even if the flush itself raised.
            self._is_gradient_accumulation_boundary = saved_engine_boundary

def _flush_coalesced_reduction_zero12(self, optimizer):
    """Run one reduction pass over locally accumulated grads (ZeRO stage 1/2).

    Called on exit from ``coalesce_grad_reduction``. Walks every parameter in
    ``optimizer.bit16_groups`` that requires grad and has a gradient available
    for reduction, hands it to the optimizer's reducer, then runs the
    optimizer's reduce epilogue.

    Args:
        optimizer: the ZeRO stage 1/2 optimizer whose reducer is driven.
    """
    # With overlap_comm the reducer works on a separate stream with a
    # double-buffered ipg bucket. Drain that stream before re-entering the
    # reducer so the bucket swap cannot race against leftover reduction work
    # from the previous step. Skipped when the accelerator resolves data
    # dependencies on its own.
    needs_stream_sync = (getattr(optimizer, "overlap_comm", False) and hasattr(optimizer, "reduction_stream")
                         and not get_accelerator().resolves_data_dependency())
    if needs_stream_sync:
        optimizer.reduction_stream.synchronize()

    # process_gradients normally allocates the ipg bucket buffers through
    # setup_buckets, but it was short-circuited for the whole coalesce
    # window, so make sure they exist now.
    # Original author's note: micro_step_id advances by 1 here for the whole
    # coalesce block, which is fine -- copy_grads_in_partition's accumulate
    # condition uses micro_step_id > 0 OR not boundary, and boundary is
    # forced to True by the caller.
    optimizer.setup_buckets()

    for group_idx, param_group in enumerate(optimizer.bit16_groups):
        for p in param_group:
            # get_gradient_for_reduction covers both storage modes: with
            # use_grad_accum_attribute=True the accumulated grad lives in
            # param.grad_accum, otherwise in param.grad.
            if not p.requires_grad or optimizer.get_gradient_for_reduction(p) is None:
                continue
            optimizer.reduce_ready_partitions_and_remove_grads(p, group_idx)

    optimizer.overlapping_partition_gradients_reduce_epilogue()

def _flush_coalesced_reduction_zero3(self, optimizer):
    """Run one reduction pass over locally accumulated grads (ZeRO stage 3).

    Called on exit from ``coalesce_grad_reduction``. Every parameter in
    ``optimizer.fp16_groups`` that requires grad and currently holds a
    ``.grad`` is handed to the optimizer's reducer, after which the
    optimizer's gradient-partition epilogue runs.

    Args:
        optimizer: the ZeRO stage 3 optimizer whose reducer is driven.
    """
    # Original author's note: the leaf-module zero-fill for unused params
    # (stage3.py:1336-1337) runs from the leaf module's own backward hook,
    # BEFORE the reducer call that the coalesce window suppresses. So by
    # flush time leaf params already carry a grad (real or zero-filled)
    # regardless of _coalesce_grad_reduction.
    all_params = (p for param_group in optimizer.fp16_groups for p in param_group)
    for p in all_params:
        if not p.requires_grad or p.grad is None:
            continue
        optimizer.reduce_ready_partitions_and_remove_grads(p)
    optimizer.independent_gradient_partition_epilogue()

def scale(self, loss):
r"""Apply loss scaler for manual backward pass.

Expand Down
5 changes: 5 additions & 0 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ def _enforce_optimizer_offload():

self.is_gradient_accumulation_boundary: bool = True

# Toggled by DeepSpeedEngine.coalesce_grad_reduction().
self._coalesce_grad_reduction = False

self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
# TODO. make this configurable via JSON
self.max_param_reduce_events: int = 2
Expand Down Expand Up @@ -1812,6 +1815,8 @@ def _partitioned_buffers_all_gather(self, params: List[Parameter], buffers_to_al
return output

def reduce_ready_partitions_and_remove_grads(self, param):
    """Forward ``param`` to the bucketed reducer unless reduction is being coalesced.

    While ``self._coalesce_grad_reduction`` is set (toggled by
    DeepSpeedEngine.coalesce_grad_reduction()), this is a no-op so the
    gradient stays on the parameter until the engine drives the flush.
    """
    if not self._coalesce_grad_reduction:
        #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True)
        self.reduce_independent_p_g_buckets_and_remove_grads(param)

Expand Down
5 changes: 5 additions & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ def __init__(self,

self.is_gradient_accumulation_boundary = True

# Toggled by DeepSpeedEngine.coalesce_grad_reduction().
self._coalesce_grad_reduction = False

# CPU-Offload requires contiguous gradients
self.contiguous_gradients = contiguous_gradients or self.cpu_offload

Expand Down Expand Up @@ -1613,6 +1616,8 @@ def reduce_ipg_grads(self, comm_dtype=None):
#####################################################################

def process_gradients(self, param, i):
if self._coalesce_grad_reduction:
return
self.setup_buckets()
if self.use_grad_accum_attribute:
self._fill_param_grad_accum_attribute(param)
Expand Down
69 changes: 69 additions & 0 deletions docs/code-docs/source/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,75 @@ Gradient Accumulation
---------------------
.. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary

Coalesced Gradient Reduction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: deepspeed.DeepSpeedEngine.coalesce_grad_reduction

Use this when one optimizer step needs multiple ``engine.backward()`` calls
and per-backward reduction is wasted work. Typical cases are GradCache-style
cached contrastive losses that replay backward over chunked representations,
and custom ``torch.autograd.Function`` subclasses that call
``torch.autograd.backward`` from inside their ``forward``. Results are
bit-exact against the per-backward baseline.

Under ZeRO-3, each backward inside the block leaves full, parameter-shaped
gradients on the parameters instead of triggering the per-backward
reduce-scatter. On exit, a single pass drives the reducer over the accumulated
gradients and populates the partitioned ``averaged_gradients`` for ``step()``.

.. code-block:: python

    for batch in data_loader:
        chunks = batch.split(chunk_size)
        with model_engine.coalesce_grad_reduction():
            for chunk in chunks:
                loss = model_engine(chunk)
                model_engine.backward(loss)
        model_engine.step()

Communication
^^^^^^^^^^^^^

With ``N`` back-to-back ``backward()`` calls per step, ZeRO-2 and ZeRO-3
normally issue ``N`` gradient collectives (one per backward). Inside
``coalesce_grad_reduction()`` those collapse to one collective on exit.
ZeRO-1 already reduces only at the accumulation boundary, so its collective
count is unchanged; the context still removes the per-backward bucket setup
cost.

Memory
^^^^^^

Suppressing the per-backward reduction means each rank holds a full local
gradient copy for the duration of the ``with`` block.

* ZeRO-2: window-resident memory equals that of ZeRO-1 with
  :meth:`deepspeed.DeepSpeedEngine.no_sync`: one full gradient per rank,
  held until the flush. On a 2-GPU, 134M-parameter bf16 setup with ``N=4``,
  peak window memory drops from 640 MiB (baseline) to 384 MiB.
* ZeRO-3: window-resident memory is one full gradient per rank, versus the
  ``1/world_size`` partition that the per-backward path holds throughout.
  Peak memory is roughly equal to the baseline (the in-flight backward
  already needs full-gradient room and the accumulator reuses it).

Constraints
^^^^^^^^^^^

* ZeRO stage 0 and pipeline parallelism raise ``NotImplementedError``.
* The BF16/FP16 optimizer wrappers (``BF16_Optimizer``, ``FP16_Optimizer``)
  route grads through their own ``backward_epilogue`` path and are not yet
  supported; the context raises ``NotImplementedError`` at entry. Use raw
  ZeRO-1/2/3 for now.
* ``engine.step()`` inside the ``with`` block raises.
* Cannot be nested inside :meth:`deepspeed.DeepSpeedEngine.no_sync`.
* Do not split one ``gradient_accumulation_steps`` window across multiple
  ``with`` blocks: the flush overwrites ``averaged_gradients`` on each exit.

:meth:`deepspeed.DeepSpeedEngine.no_sync` raises ``AssertionError`` for
ZeRO-2 and ZeRO-3 (``zero_optimization_partition_gradients()`` is true for
stage >= 2), so it cannot collapse collectives for those stages.
``coalesce_grad_reduction()`` is the equivalent for ZeRO-2/3.


Mixed Precision Training
-------------------------
Expand Down
Loading
Loading