Add GradientAccumulation utility for SupervisedTrainer #8763

Open
aymuos15 wants to merge 4 commits into Project-MONAI:dev from aymuos15:feat/grad-accum-supervisedtrainer

Conversation


@aymuos15 aymuos15 commented Mar 3, 2026

Summary

  • Adds GradientAccumulation callable class in monai.engines.utils for use as iteration_update in SupervisedTrainer, enabling gradient accumulation over multiple mini-batches to simulate larger effective batch sizes on memory-constrained hardware
  • Follows the callable-class iteration_update pattern established by Interaction in monai.apps.deepedit (as referenced by @wyli in Add gradient accumulation logic to SupervisedTrainer #6101)
  • All IterationEvents fire every mini-batch, so existing handlers are unaffected
  • Epoch boundary flush ensures no gradients are silently discarded when epoch_length % accumulation_steps != 0
  • Mixed-precision (GradScaler) support included

Closes #6100
Supersedes #6101

Usage

from monai.engines import SupervisedTrainer, GradientAccumulation

trainer = SupervisedTrainer(
    ...,
    iteration_update=GradientAccumulation(accumulation_steps=4),
)
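For context, the mechanics behind such a callable can be sketched roughly as follows. This is an illustrative standalone sketch, not the actual MONAI implementation (which patches engine/optimizer methods, fires all IterationEvents, and flushes at epoch boundaries); the class and method names here are hypothetical.

```python
import torch


class GradAccumSketch:
    """Illustrative gradient-accumulation callable (hypothetical, simplified)."""

    def __init__(self, accumulation_steps: int = 2) -> None:
        # reject bools explicitly: bool is an int subclass and would slip through
        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
            raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
        self.accumulation_steps = accumulation_steps

    def step(self, model, optimizer, loss_fn, batch, iteration):
        x, y = batch
        loss = loss_fn(model(x), y)
        # scale the loss so accumulated gradients match the average over the window
        (loss / self.accumulation_steps).backward()
        if iteration % self.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()  # report the unscaled loss for loggers/metrics
```

The key details mirrored from the PR description: the loss is scaled before `backward()` but the unscaled value is returned, and the optimizer only steps every `accumulation_steps` iterations.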

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • In-line docstrings updated.

Test plan

  • Input validation (zero, negative, float, string)
  • Passthrough when accumulation_steps=1
  • zero_grad / optimizer.step suppression patterns verified across full epochs
  • Epoch boundary flush when epoch_length not divisible by accumulation_steps
  • Iterable dataset (epoch_length=None) — no epoch flush
  • Patching/restoration of all engine methods after each call
  • Restoration after exception (try/finally)
  • GradScaler patching when step suppressed, not patched when stepping
  • No scaler attribute / scaler=None edge cases
  • Batch data forwarded correctly to _iteration
  • Output loss is unscaled (original value for loggers/metrics)
  • Integration: gradient equivalence with manual accumulation (requires ignite)
  • Integration: epoch boundary flush equivalence (requires ignite)
  • Integration: multi-epoch correctness (requires ignite)


coderabbitai bot commented Mar 3, 2026

📝 Walkthrough

Walkthrough

Adds a new callable GradientAccumulation in monai/engines/utils.py that implements gradient accumulation (validates accumulation_steps, scales the loss for backward, suppresses per-mini-batch optimizer/GradScaler updates, and restores original behavior, including on exceptions). Exposes GradientAccumulation in the module's `__all__` and re-exports it from monai/engines/__init__.py. Adds tests/engines/test_gradient_accumulation.py with unit and integration tests covering validation, `__repr__`, passthrough for accumulation_steps == 1, optimizer/scaler patching and restoration (including exception cases), loss scaling and return behavior, batch forwarding, and integration comparisons.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 44.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): the title clearly and concisely summarizes the main change: adding a GradientAccumulation utility for SupervisedTrainer.
  • Description check (✅ Passed): the description covers all required sections: summary, usage example, types of changes, and detailed test plan. All template checkboxes are properly marked.
  • Linked Issues check (✅ Passed): the changes fully meet issue #6100 objectives: built-in gradient accumulation, correct event semantics, epoch-boundary flush, iterable dataset support, GradScaler support, and a clean callable-class API.
  • Out of Scope Changes check (✅ Passed): all changes are scoped to gradient accumulation: new utility class, public API exports, and comprehensive tests. No unrelated modifications detected.



@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)

339-350: Drop the unused init_weight from the helper return.
It is not consumed by callers, so removing it tightens the helper contract and avoids dead unpacks downstream.

♻️ Proposed cleanup
@@
-def _make_model_pair(lr):
+def _make_model_pair(lr):
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/engines/test_gradient_accumulation.py` around lines 339 - 350, The
helper _make_model_pair currently returns an unused init_weight which tightens
its contract unnecessarily; remove the creation or cloning of init_weight from
inside _make_model_pair (or keep the local init copy only to set test_model
weights) and update the return tuple from _make_model_pair to return only
(ref_model, test_model, ref_opt, test_opt), then update any callers that unpack
the result to stop expecting the fifth value so there are no dead unpacks.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 894068a and 597e086.

📒 Files selected for processing (3)
  • monai/engines/__init__.py
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

…ject-MONAI#6100)

Closes Project-MONAI#6100

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@aymuos15 aymuos15 force-pushed the feat/grad-accum-supervisedtrainer branch from 597e086 to 1db8cc1 on March 3, 2026 at 11:08
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (2)
tests/engines/test_gradient_accumulation.py (1)

105-105: Consider marking intentionally-unused bindings with _ prefixes.

This keeps tests clear while avoiding avoidable lint noise.

🧹 Optional cleanup
-        def fake_iteration(eng, batch):
+        def fake_iteration(eng, _batch):
@@
-        def check_scaler(eng, batch):
+        def check_scaler(eng, _batch):
@@
-        def fake_iteration(*args, **kwargs):
+        def fake_iteration(*_args, **_kwargs):
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)

Also applies to: 188-188, 234-234, 257-257, 287-287, 318-318

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/engines/test_gradient_accumulation.py` at line 105, The test defines
callback functions like fake_iteration(eng, batch) with parameters that are
intentionally unused; update these function signatures (and the other
occurrences at the same pattern) to mark unused parameters with leading
underscores (e.g., _eng, _batch or _batch_idx) so linters know the bindings are
intentionally unused—search for the function name fake_iteration and the similar
callback definitions at the other noted locations and rename the unused
parameters with _ prefixes.
monai/engines/utils.py (1)

366-368: Align new definitions with Google-style docstring sections.

_noop, __init__, and __repr__ should include explicit Args/Returns (and Raises where applicable) sections to match repo docstring policy.

♻️ Suggested docstring adjustments
 def _noop(*args: Any, **kwargs: Any) -> None:
-    """No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""
+    """No-op callable used to suppress optimizer/scaler methods.
+
+    Args:
+        *args: Ignored positional arguments.
+        **kwargs: Ignored keyword arguments.
+
+    Returns:
+        None.
+    """

 class GradientAccumulation:
@@
     def __init__(self, accumulation_steps: int = 2) -> None:
+        """Initialize gradient accumulation behavior.
+
+        Args:
+            accumulation_steps: Number of mini-batches to accumulate before stepping.
+
+        Raises:
+            ValueError: If `accumulation_steps` is not a positive integer.
+        """
         if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
@@
     def __repr__(self) -> str:
+        """Return a debug-friendly representation.
+
+        Returns:
+            String representation with configured accumulation steps.
+        """
         return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

Also applies to: 405-413

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/engines/utils.py` around lines 366 - 368, Add Google-style docstring
sections to the new definitions: for _noop include an "Args" section describing
*args and **kwargs and a "Returns" section noting it returns None; for the class
__init__ add an "Args" section for each parameter and a "Returns" section if
applicable (or state None) and an optional "Raises" section if it can raise
exceptions; for __repr__ add a "Returns" section describing the returned str.
Update the docstrings in functions/methods named _noop, __init__, and __repr__
to follow the repo's Google-style (Args, Returns, and Raises where needed) and
mirror the format used elsewhere in the file.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 597e086 and 1db8cc1.

📒 Files selected for processing (3)
  • monai/engines/__init__.py
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
monai/engines/utils.py (1)

413-413: Widen batchdata type hint in __call__.

batchdata: dict[str, Any] is tighter than common trainer inputs. Consider Any to avoid misleading static typing for tuple/list batch payloads.

Proposed fix
-    def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
+    def __call__(self, engine: Any, batchdata: Any) -> dict:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/engines/utils.py` at line 413, The type hint for the __call__ method
currently restricts batchdata to dict[str, Any], which is too narrow for
trainers that pass tuples/lists; change the signature of __call__ to accept
batchdata: Any (or more permissive Union types) so it can handle dict, tuple,
list, etc., and update any related type annotations/comments in the same
function (named __call__) and its callers to reflect the broader type.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 1db8cc1 and a3eca14.

📒 Files selected for processing (2)
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Comment on lines +406 to +407
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")

⚠️ Potential issue | 🟡 Minor

Reject boolean values for accumulation_steps.

True currently passes validation because bool is an int subclass, so invalid config can silently map to 1.

Proposed fix
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
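The bug the reviewer flags is easy to demonstrate: because `bool` is a subclass of `int` in Python, `isinstance(True, int)` passes a naive check and `True` silently behaves as `1`. A minimal sketch of the stricter validation (the `validate` helper here is hypothetical, standing in for the `__init__` check):

```python
# bool is a subclass of int, so a naive isinstance check lets True slip through
assert isinstance(True, int)       # passes: bool subclasses int
assert True == 1 and False == 0    # bools compare equal to ints


def validate(accumulation_steps):
    # explicit bool rejection first, as in the proposed fix
    if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
        raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
    return accumulation_steps
```

With this ordering, `validate(True)` raises instead of mapping the invalid config to an accumulation window of 1.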
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 407-407: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/engines/utils.py` around lines 406 - 407, The validation for
accumulation_steps currently allows booleans because bool is an int subclass;
update the check in monai.engines.utils (the accumulation_steps validation) to
explicitly reject bools — e.g., require type(accumulation_steps) is int or add
"and not isinstance(accumulation_steps, bool)" to the isinstance check — and
keep the existing lower-bound check (accumulation_steps < 1) so True/False no
longer pass validation.


saw_original: list[bool] = []

def fake_iteration(eng, batch):

⚠️ Potential issue | 🟡 Minor

Prefix intentionally unused test callback parameters with _.

These currently trigger ARG001 warnings; renaming keeps tests clear and lint-clean.

Proposed fix
-        def fake_iteration(eng, batch):
+        def fake_iteration(eng, _batch):
             saw_original.append(getattr(eng.optimizer, attr_name) is original)
             return {CommonKeys.LOSS: torch.tensor(1.0)}

-        def check_scaler(eng, batch):
+        def check_scaler(eng, _batch):
             scaler_was_patched.append(eng.scaler.step is not original_scaler_step)
             return {CommonKeys.LOSS: torch.tensor(0.5)}

-        def fake_iteration(*args, **kwargs):
+        def fake_iteration(*_args, **_kwargs):
             scaled = engine.loss_function()
             return {CommonKeys.LOSS: scaled}

Also applies to: 174-174, 220-220

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 91-91: Unused function argument: batch

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/engines/test_gradient_accumulation.py` at line 91, The test callback
function parameters that are intentionally unused (for example in function
fake_iteration) are triggering ARG001; rename those parameters by prefixing with
an underscore (e.g., change def fake_iteration(eng, batch): to def
fake_iteration(_eng, _batch):) and apply the same pattern to the other callbacks
mentioned (the occurrences around the other reported locations) so unused
arguments are clearly marked and lint-clean.

acc_steps, lr = 4, 0.1
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)

⚠️ Potential issue | 🟡 Minor

Remove unused init_weight plumbing from model-pair helper.

init_weight is returned/unpacked but never used by tests.

Proposed fix
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
...
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
...
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)

-def _make_model_pair(lr):
+def _make_model_pair(lr):
     """Create a reference and test model pair with identical initial weights."""
     ref_model = nn.Linear(4, 1, bias=False)
     init_weight = ref_model.weight.data.clone()
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt

Also applies to: 271-271, 303-303, 328-339

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 240-240: Unpacked variable init_weight is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/engines/test_gradient_accumulation.py` at line 240, The helper
_make_model_pair currently returns a third value init_weight that callers (tests
in tests/engines/test_gradient_accumulation.py) unpack but never use; remove
this unused plumbing by changing _make_model_pair to return only (ref_model,
test_model, ref_opt, test_opt) and update all call sites (e.g., the unpack at
the shown line and similar occurrences at the other locations) to stop expecting
init_weight — adjust any tuple unpacking in the tests to four variables matching
the function's new signature.

aymuos15 added 2 commits March 3, 2026 15:00
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/engines/utils.py (1)

406-407: ⚠️ Potential issue | 🟡 Minor

Reject bool for accumulation_steps.

True currently passes because bool is an int subclass, so invalid config can silently map to 1.

Proposed fix
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/engines/utils.py` around lines 406 - 407, The current validation for
accumulation_steps accepts bool because bool is an int subclass; update the
check so booleans are rejected — e.g., replace the
isinstance(accumulation_steps, int) test with a stricter type check (such as
type(accumulation_steps) is int or add an explicit not
isinstance(accumulation_steps, bool) condition) so that accumulation_steps must
be a genuine int and >= 1; adjust the ValueError message if needed to reflect
the stricter type requirement and keep the existing check for accumulation_steps
< 1.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between a3eca14 and 53c5dc5.

📒 Files selected for processing (2)
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Comment on lines +28 to +29
INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]


⚠️ Potential issue | 🟡 Minor

Add explicit bool invalid-input coverage.

This suite misses True, which is the key edge case for the bool-as-int validation bug.

Proposed fix
-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,), (False,)]

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

Also applies to: 58-63

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/engines/test_gradient_accumulation.py` around lines 28 - 29, The test
data in INVALID_ACCUMULATION_STEPS misses a boolean edge case; update the tuples
in INVALID_ACCUMULATION_STEPS (and the similar list at lines 58-63 referenced in
the comment) to include True as an invalid input (e.g., add (True,) alongside
(0,), (-1,), (2.5,), ("2",)) so the test suite covers the bool-as-int validation
bug for the functions that consume INVALID_ACCUMULATION_STEPS.


Development

Successfully merging this pull request may close these issues.

Add gradient accumulation functionality to SupervisedTrainer
