1 change: 1 addition & 0 deletions monai/engines/__init__.py
@@ -15,6 +15,7 @@
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
from .utils import (
DiffusionPrepareBatch,
GradientAccumulation,
IterationEvents,
PrepareBatch,
PrepareBatchDefault,
120 changes: 120 additions & 0 deletions monai/engines/utils.py
@@ -41,6 +41,7 @@
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
"GradientAccumulation",
]


@@ -360,3 +361,122 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:

"""
return current_metric > prev_best


def _noop(*args: Any, **kwargs: Any) -> None:
"""No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""


class GradientAccumulation:
"""
Callable class implementing gradient accumulation for use with ``SupervisedTrainer``.

Gradients are accumulated over ``accumulation_steps`` mini-batches before calling
``optimizer.step()``, simulating a larger effective batch size on memory-constrained
hardware.

Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``::

trainer = SupervisedTrainer(
...,
iteration_update=GradientAccumulation(accumulation_steps=4),
)

All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``,
``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so
existing handlers (checkpoint savers, metric loggers, etc.) are unaffected.

When ``epoch_length`` is known, a final ``optimizer.step()`` is forced at the end of
each epoch even if ``epoch_length % accumulation_steps != 0``, so no accumulated
gradients are silently discarded. For iterable datasets (``epoch_length=None``) this
end-of-epoch flush does not apply.

The loss stored in ``engine.state.output[CommonKeys.LOSS]`` is the **unscaled**
original loss value, so metrics and loggers report the true loss. Internally
the loss is divided by ``accumulation_steps`` for the backward pass only.

Args:
accumulation_steps: number of mini-batches to accumulate before updating
weights. Must be a positive integer. Default: 2.

Raises:
ValueError: when ``accumulation_steps`` is not a positive integer.
"""

def __init__(self, accumulation_steps: int = 2) -> None:
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
self.accumulation_steps = accumulation_steps

def __repr__(self) -> str:
return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
"""
Execute one iteration with gradient accumulation.

Args:
engine: the Ignite engine (usually ``SupervisedTrainer``).
batchdata: batch data for this iteration.

Returns:
the output dict from ``engine._iteration()``.
"""
acc = self.accumulation_steps

if acc == 1:
    return engine._iteration(engine, batchdata)

# engine.state.iteration is 1-indexed and already incremented before __call__
epoch_length = engine.state.epoch_length # None for iterable datasets
if epoch_length is not None:
local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch
should_zero_grad = local_iter % acc == 0
should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
else:
local_iter = engine.state.iteration - 1 # 0-indexed global
should_zero_grad = local_iter % acc == 0
should_step = (local_iter + 1) % acc == 0
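
# Worked example (illustrative numbers): with acc=4 and epoch_length=10, local_iter
# 0-3 and 4-7 form full accumulation cycles, while at local_iter 9 the
# `(local_iter + 1) == epoch_length` clause forces a step so the partial cycle 8-9
# still produces a weight update.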

# Save and conditionally suppress zero_grad. Only clear gradients at the start of an accumulation cycle.
original_zero_grad = engine.optimizer.zero_grad
if not should_zero_grad:
engine.optimizer.zero_grad = _noop

# Save and wrap loss_function to scale by 1/accumulation_steps. Backpropagating the scaled
# loss means the gradients summed over one accumulation cycle equal the gradient of the
# mean loss across the cycle, matching what a single larger batch would produce.
original_loss_fn = engine.loss_function
engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc

# Save and conditionally suppress optimizer.step. Only update weights at the end of an accumulation cycle.
# Also patch GradScaler.step and GradScaler.update when step is suppressed, for mixed-precision training.
original_step = engine.optimizer.step
original_scaler_step = None
original_scaler_update = None
if not should_step:
engine.optimizer.step = _noop
if hasattr(engine, "scaler") and engine.scaler is not None:
original_scaler_step = engine.scaler.step
original_scaler_update = engine.scaler.update
engine.scaler.step = _noop
engine.scaler.update = _noop
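# Suppressing `update` as well keeps the loss scale constant across the accumulation
# window; otherwise `GradScaler.update` would recalibrate the scale factor on
# iterations where no real optimizer step ran.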

try:
result = engine._iteration(engine, batchdata)
finally:
engine.optimizer.zero_grad = original_zero_grad
engine.loss_function = original_loss_fn
engine.optimizer.step = original_step
if original_scaler_step is not None:
engine.scaler.step = original_scaler_step
engine.scaler.update = original_scaler_update

# Restore the unscaled loss for logging and metrics. The backward pass
# already used the scaled value, so this only affects what handlers see.
if CommonKeys.LOSS in result:
result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc

return result
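
As a sanity check on the loss-scaling comments above, here is a minimal standalone sketch (plain PyTorch, independent of MONAI and of this class; all names are illustrative) confirming that backpropagating ``loss / acc`` over ``acc`` equal-sized mini-batches accumulates the same gradient as a single backward pass over the full batch:

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(8, 4)  # full batch of 8 samples
acc = 2  # accumulate over 2 mini-batches of 4

# gradient of the mean loss over the full batch
(x @ w).pow(2).mean().backward()
full_grad = w.grad.clone()

# accumulated gradients: scale each mini-batch loss by 1/acc before backward
w.grad = None
for chunk in x.chunk(acc):
    ((chunk @ w).pow(2).mean() / acc).backward()

assert torch.allclose(w.grad, full_grad, atol=1e-6)

The same arithmetic is why the handler-visible loss is multiplied back by ``acc`` after ``_iteration`` returns.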