Add GradientAccumulation utility for SupervisedTrainer #8763
aymuos15 wants to merge 4 commits into Project-MONAI:dev
Conversation
📝 Walkthrough

Adds a new callable GradientAccumulation in monai/engines/utils.py that implements gradient accumulation: it validates accumulation_steps, scales the loss for backward, suppresses per-mini-batch optimizer/GradScaler updates, and restores original behavior, including on exceptions. Exposes GradientAccumulation in the module's `__all__` and re-exports it from monai/engines/__init__.py. Adds tests/engines/test_gradient_accumulation.py with unit and integration tests covering validation, repr, passthrough for accumulation_steps == 1, optimizer/scaler patching and restoration (including exception cases), loss scaling and return behavior, batch forwarding, and integration comparisons.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: 4 passed, 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
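The suppress-and-restore behavior described in the walkthrough can be sketched in plain Python. This is an illustrative sketch with hypothetical names (GradientAccumulationSketch, `_count`); the actual MONAI implementation differs in detail and also scales the loss, which is omitted here for brevity:

```python
from typing import Any


class GradientAccumulationSketch:
    """Illustrative sketch of an iteration_update-style callable (not the MONAI implementation)."""

    def __init__(self, accumulation_steps: int = 2) -> None:
        # bool subclasses int, so reject it explicitly before the range check
        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
            raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
        self.accumulation_steps = accumulation_steps
        self._count = 0

    def __call__(self, engine: Any, batchdata: Any) -> dict:
        self._count += 1
        at_boundary = self._count % self.accumulation_steps == 0
        if self.accumulation_steps == 1 or at_boundary:
            # boundary mini-batch: run the normal iteration, the optimizer steps as usual
            return engine._iteration(engine, batchdata)
        # intermediate mini-batch: suppress optimizer.step, restoring even on exceptions
        original_step = engine.optimizer.step
        engine.optimizer.step = lambda *args, **kwargs: None
        try:
            return engine._iteration(engine, batchdata)
        finally:
            engine.optimizer.step = original_step
```

The try/finally mirrors the restore-on-exception guarantee the walkthrough mentions; GradScaler handling follows the same patch-and-restore pattern.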
🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)
339-350: Drop the unused `init_weight` from the helper return.
It is not consumed by callers, so removing it tightens the helper contract and avoids dead unpacks downstream.
♻️ Proposed cleanup

```diff
 def _make_model_pair(lr):
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/engines/test_gradient_accumulation.py` around lines 339 - 350, The helper _make_model_pair currently returns an unused init_weight which tightens its contract unnecessarily; remove the creation or cloning of init_weight from inside _make_model_pair (or keep the local init copy only to set test_model weights) and update the return tuple from _make_model_pair to return only (ref_model, test_model, ref_opt, test_opt), then update any callers that unpack the result to stop expecting the fifth value so there are no dead unpacks.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (3)
- monai/engines/__init__.py
- monai/engines/utils.py
- tests/engines/test_gradient_accumulation.py
…ject-MONAI#6100)
Closes Project-MONAI#6100
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
597e086 to 1db8cc1 (Compare)
🧹 Nitpick comments (2)
tests/engines/test_gradient_accumulation.py (1)
105-105: Consider marking intentionally-unused bindings with `_` prefixes. This keeps tests clear while avoiding lint noise.
🧹 Optional cleanup

```diff
-    def fake_iteration(eng, batch):
+    def fake_iteration(eng, _batch):
@@
-    def check_scaler(eng, batch):
+    def check_scaler(eng, _batch):
@@
-    def fake_iteration(*args, **kwargs):
+    def fake_iteration(*_args, **_kwargs):
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
```

Also applies to: 188-188, 234-234, 257-257, 287-287, 318-318
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/engines/test_gradient_accumulation.py` at line 105, The test defines callback functions like fake_iteration(eng, batch) with parameters that are intentionally unused; update these function signatures (and the other occurrences at the same pattern) to mark unused parameters with leading underscores (e.g., _eng, _batch or _batch_idx) so linters know the bindings are intentionally unused—search for the function name fake_iteration and the similar callback definitions at the other noted locations and rename the unused parameters with _ prefixes.

monai/engines/utils.py (1)
366-368: Align new definitions with Google-style docstring sections.
`_noop`, `__init__`, and `__repr__` should include explicit Args/Returns (and Raises where applicable) sections to match repo docstring policy.
♻️ Suggested docstring adjustments

```diff
 def _noop(*args: Any, **kwargs: Any) -> None:
-    """No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""
+    """No-op callable used to suppress optimizer/scaler methods.
+
+    Args:
+        *args: Ignored positional arguments.
+        **kwargs: Ignored keyword arguments.
+
+    Returns:
+        None.
+    """

 class GradientAccumulation:
@@
     def __init__(self, accumulation_steps: int = 2) -> None:
+        """Initialize gradient accumulation behavior.
+
+        Args:
+            accumulation_steps: Number of mini-batches to accumulate before stepping.
+
+        Raises:
+            ValueError: If `accumulation_steps` is not a positive integer.
+        """
         if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
@@
     def __repr__(self) -> str:
+        """Return a debug-friendly representation.
+
+        Returns:
+            String representation with configured accumulation steps.
+        """
         return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"
```

As per coding guidelines, "Docstrings should be present for all definitions, describing each variable, return value, and raised exception in the appropriate section of Google-style docstrings."
Also applies to: 405-413
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/engines/utils.py` around lines 366 - 368, Add Google-style docstring sections to the new definitions: for _noop include an "Args" section describing *args and **kwargs and a "Returns" section noting it returns None; for the class __init__ add an "Args" section for each parameter and a "Returns" section if applicable (or state None) and an optional "Raises" section if it can raise exceptions; for __repr__ add a "Returns" section describing the returned str. Update the docstrings in functions/methods named _noop, __init__, and __repr__ to follow the repo's Google-style (Args, Returns, and Raises where needed) and mirror the format used elsewhere in the file.
📒 Files selected for processing (3)
- monai/engines/__init__.py
- monai/engines/utils.py
- tests/engines/test_gradient_accumulation.py
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Actionable comments posted: 3
🧹 Nitpick comments (1)
monai/engines/utils.py (1)
413-413: Widen the `batchdata` type hint in `__call__`.
`batchdata: dict[str, Any]` is tighter than common trainer inputs. Consider `Any` to avoid misleading static typing for tuple/list batch payloads.
Proposed fix

```diff
-    def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
+    def __call__(self, engine: Any, batchdata: Any) -> dict:
```
Verify each finding against the current code and only fix it if needed. In `@monai/engines/utils.py` at line 413, The type hint for the __call__ method currently restricts batchdata to dict[str, Any], which is too narrow for trainers that pass tuples/lists; change the signature of __call__ to accept batchdata: Any (or more permissive Union types) so it can handle dict, tuple, list, etc., and update any related type annotations/comments in the same function (named __call__) and its callers to reflect the broader type.
📒 Files selected for processing (2)
- monai/engines/utils.py
- tests/engines/test_gradient_accumulation.py
```python
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
    raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```
Reject boolean values for accumulation_steps.
True currently passes validation because bool is an int subclass, so invalid config can silently map to 1.
Proposed fix
```diff
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```

🧰 Tools
🪛 Ruff (0.15.2)
[warning] 407-407: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/engines/utils.py` around lines 406 - 407, The validation for
accumulation_steps currently allows booleans because bool is an int subclass;
update the check in monai.engines.utils (the accumulation_steps validation) to
explicitly reject bools — e.g., require type(accumulation_steps) is int or add
"and not isinstance(accumulation_steps, bool)" to the isinstance check — and
keep the existing lower-bound check (accumulation_steps < 1) so True/False no
longer pass validation.
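To see why this matters, `bool` really is an `int` subclass in Python, so a plain `isinstance` check lets `True` through. A standalone validator with the stricter check suggested above (`validate_accumulation_steps` is a hypothetical name for illustration) behaves as the review intends:

```python
def validate_accumulation_steps(accumulation_steps):
    """Stricter check: rejects bools explicitly, since isinstance(True, int) is True."""
    if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
        raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
    return accumulation_steps


# bool is a subclass of int, so without the explicit bool check True would validate
assert issubclass(bool, int)
assert isinstance(True, int)
```

With only `isinstance(accumulation_steps, int)`, `True` would pass validation and silently behave as an accumulation step count of 1.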
```python
saw_original: list[bool] = []
```

```python
def fake_iteration(eng, batch):
```
Prefix intentionally unused test callback parameters with _.
These currently trigger ARG001 warnings; renaming keeps tests clear and lint-clean.
Proposed fix
```diff
-    def fake_iteration(eng, batch):
+    def fake_iteration(eng, _batch):
         saw_original.append(getattr(eng.optimizer, attr_name) is original)
         return {CommonKeys.LOSS: torch.tensor(1.0)}
@@
-    def check_scaler(eng, batch):
+    def check_scaler(eng, _batch):
         scaler_was_patched.append(eng.scaler.step is not original_scaler_step)
         return {CommonKeys.LOSS: torch.tensor(0.5)}
@@
-    def fake_iteration(*args, **kwargs):
+    def fake_iteration(*_args, **_kwargs):
         scaled = engine.loss_function()
         return {CommonKeys.LOSS: scaled}
```

Also applies to: 174-174, 220-220
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 91-91: Unused function argument: batch
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/engines/test_gradient_accumulation.py` at line 91, The test callback
function parameters that are intentionally unused (for example in function
fake_iteration) are triggering ARG001; rename those parameters by prefixing with
an underscore (e.g., change def fake_iteration(eng, batch): to def
fake_iteration(_eng, _batch):) and apply the same pattern to the other callbacks
mentioned (the occurrences around the other reported locations) so unused
arguments are clearly marked and lint-clean.
```python
acc_steps, lr = 4, 0.1
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
```
Remove unused init_weight plumbing from model-pair helper.
init_weight is returned/unpacked but never used by tests.
Proposed fix
```diff
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-    ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+    ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
 def _make_model_pair(lr):
     """Create a reference and test model pair with identical initial weights."""
     ref_model = nn.Linear(4, 1, bias=False)
     init_weight = ref_model.weight.data.clone()
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
```

Also applies to: 271-271, 303-303, 328-339
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 240-240: Unpacked variable init_weight is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/engines/test_gradient_accumulation.py` at line 240, The helper
_make_model_pair currently returns a third value init_weight that callers (tests
in tests/engines/test_gradient_accumulation.py) unpack but never use; remove
this unused plumbing by changing _make_model_pair to return only (ref_model,
test_model, ref_opt, test_opt) and update all call sites (e.g., the unpack at
the shown line and similar occurrences at the other locations) to stop expecting
init_weight — adjust any tuple unpacking in the tests to four variables matching
the function's new signature.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/engines/utils.py (1)
406-407: ⚠️ Potential issue | 🟡 Minor
Reject `bool` for `accumulation_steps`.
`True` currently passes because `bool` is an `int` subclass, so invalid config can silently map to `1`.
Proposed fix

```diff
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/engines/utils.py` around lines 406 - 407, The current validation for accumulation_steps accepts bool because bool is an int subclass; update the check so booleans are rejected — e.g., replace the isinstance(accumulation_steps, int) test with a stricter type check (such as type(accumulation_steps) is int or add an explicit not isinstance(accumulation_steps, bool) condition) so that accumulation_steps must be a genuine int and >= 1; adjust the ValueError message if needed to reflect the stricter type requirement and keep the existing check for accumulation_steps < 1.
📒 Files selected for processing (2)
- monai/engines/utils.py
- tests/engines/test_gradient_accumulation.py
```python
INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
```
Add explicit bool invalid-input coverage.
This suite misses True, which is the key edge case for the bool-as-int validation bug.
Proposed fix
```diff
-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,), (False,)]
```

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
Also applies to: 58-63
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/engines/test_gradient_accumulation.py` around lines 28 - 29, The test
data in INVALID_ACCUMULATION_STEPS misses a boolean edge case; update the tuples
in INVALID_ACCUMULATION_STEPS (and the similar list at lines 58-63 referenced in
the comment) to include True as an invalid input (e.g., add (True,) alongside
(0,), (-1,), (2.5,), ("2",)) so the test suite covers the bool-as-int validation
bug for the functions that consume INVALID_ACCUMULATION_STEPS.
Summary
- `GradientAccumulation` callable class in `monai.engines.utils` for use as `iteration_update` in `SupervisedTrainer`, enabling gradient accumulation over multiple mini-batches to simulate larger effective batch sizes on memory-constrained hardware
- Follows the `iteration_update` pattern established by `Interaction` in `monai.apps.deepedit` (as referenced by @wyli in "Add gradient accumulation logic to SupervisedTrainer" #6101)
- `IterationEvents` fire every mini-batch, so existing handlers are unaffected
- Handles `epoch_length % accumulation_steps != 0`
- AMP (`GradScaler`) support included

Closes #6100
Supersedes #6101
Usage
Types of changes
Test plan
- Passthrough with `accumulation_steps=1`
- `zero_grad`/`optimizer.step` suppression patterns verified across full epochs
- `epoch_length` not divisible by `accumulation_steps`
- `epoch_length=None` (no epoch flush)
- Restoration on exceptions (`try/finally`)
- `GradScaler` patching when step suppressed, not patched when stepping
- `scaler=None` edge cases
- `_iteration`