8498 trainerclassabstraction#8950
Conversation
for more information, see https://pre-commit.ci
DCO Remediation Commit for Enoch Mok <enochmokny@gmail.com> I, Enoch Mok <enochmokny@gmail.com>, hereby add my Signed-off-by to this commit: d02da49 Signed-off-by: Enoch Mok <enochmokny@gmail.com>
for more information, see https://pre-commit.ci
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
…tch; used ./runtests.sh autofix Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
for more information, see https://pre-commit.ci
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
… updated download_url function Signed-off-by: Enoch Mok <enochmokny@gmail.com>
for more information, see https://pre-commit.ci
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Enoch Mok <65853622+e-mny@users.noreply.github.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
…ner base classes Signed-off-by: Enoch Mok <enochmokny@gmail.com>
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthrough
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 3❌ Failed checks (2 warnings, 1 inconclusive)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/test_utils.py (1)
80-84: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
skip_if_downloading_fails()still won’t skip hash mismatches.
download_url()andextractall()now raiseHashCheckError(ValueError)on hash failure, but this helper only filtersRuntimeError/OSError. So the new"hash check"entry never catches the changed failure mode.Suggested diff
- except (RuntimeError, OSError) as rt_e: + except (RuntimeError, OSError, ValueError) as rt_e: err_str = str(rt_e) if any(k in err_str for k in DOWNLOAD_FAIL_MSGS): raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_utils.py` around lines 80 - 84, The skip_if_downloading_fails() helper still misses hash mismatch failures because download_url() and extractall() now raise HashCheckError(ValueError) instead of only RuntimeError/OSError. Update the exception handling in skip_if_downloading_fails() to recognize the new HashCheckError path and ensure the existing DOWNLOAD_FAIL_MSGS "hash check" entry is matched when hash validation fails, so those tests are skipped consistently.
🧹 Nitpick comments (3)
monai/apps/utils.py (1)
51-52: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winAdd a docstring to the new public exception.
HashCheckErroris part of the module API now. Give it a short Google-style docstring so callers know which failure paths raise it.Suggested diff
class HashCheckError(ValueError): + """Raised when download or extraction fails hash validation.""" passAs per path instructions,
**/*.py:Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/apps/utils.py` around lines 51 - 52, Add a short Google-style docstring to HashCheckError in utils.py since it is now part of the public API. Update the class definition to describe that this exception is raised for hash check failures, and make sure the docstring is concise and follows the module’s docstring style conventions.Source: Path instructions
monai/engines/trainer.py (2)
86-104: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winUse
_unpack_batch()inAdversarialTrainertoo.The new helper centralizes this contract, but Lines 846-851 still duplicate the same 2/4-tuple branch.
Also applies to: 846-851
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/engines/trainer.py` around lines 86 - 104, The batch unpacking logic is duplicated in AdversarialTrainer instead of using the new _unpack_batch() helper from Trainer. Update the AdversarialTrainer batch handling path to call _unpack_batch() for the prepared batch so it follows the same 2-tuple/4-tuple contract as the rest of the trainer code, and remove the local branching that manually expands inputs, targets, args, and kwargs.
61-61: 🩺 Stability & Availability | 🔵 TrivialUse the legacy
torch.cuda.amp.GradScalerto ensure consistency.While the
torch>=2.8.0constraint ensurestorch.amp.GradScaleris available, other parts of the codebase, specificallymonai/engines/workflow.pyandtests/integration/test_integration_fast_train.py, utilize the legacytorch.cuda.amp.GradScaler. Align these instances for consistency.File: monai/engines/trainer.py
# Line 61 self.scaler = torch.amp.GradScaler("cuda") if self.amp else None # Lines 782-783 self.state.g_scaler = torch.amp.GradScaler("cuda") if self.amp else None self.state.d_scaler = torch.amp.GradScaler("cuda") if self.amp else None🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/engines/trainer.py` at line 61, The AMP scaler usage in Trainer is inconsistent with the rest of the codebase: `Trainer.__init__` and the GAN-related scaler setup in `Trainer.run`/state initialization use `torch.amp.GradScaler("cuda")` instead of the legacy `torch.cuda.amp.GradScaler`. Update the `scaler`, `g_scaler`, and `d_scaler` assignments in `monai/engines/trainer.py` to use `torch.cuda.amp.GradScaler` so they match `workflow.py` and the integration test expectations.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/apps/utils.py`:
- Around line 227-235: The docstring contract for the existing-file branch is
stale: the filepath/hash-validation path in the download helper now raises
HashCheckError(ValueError), not RuntimeError. Update the Google-style Raises
section in the same helper that checks filepath.exists() and calls check_hash so
it accurately lists HashCheckError for this branch and keeps the public
exception contract aligned with the actual behavior.
In `@monai/engines/trainer.py`:
- Around line 39-46: The export list in Trainer’s __all__ is not sorted and
triggers Ruff RUF022. Reorder the names in the __all__ array in
monai/engines/trainer.py so they are alphabetically sorted, keeping the same
exported symbols (Trainer, SingleNetworkTrainer, SupervisedTrainer,
DualNetworkTrainer, GanTrainer, AdversarialTrainer).
- Around line 145-150: The `Trainer` initializer currently only checks that
`accumulation_steps` is >= 1, so non-integer values like `1.5` still pass and
later break the stepping logic. Update the validation in `Trainer.__init__` to
enforce the documented integer contract for `accumulation_steps` by rejecting
non-integer values before the existing positive check, while keeping the current
error handling for values less than 1.
- Around line 87-104: The _unpack_batch helper in trainer.py currently passes
through any batch that is not length 2, which allows 3- or 5-item prepared
batches to fail later with a generic unpacking error. Update _unpack_batch to
explicitly accept only 2-item and 4-item batches, and raise a clear exception
for any other length before returning, so callers of _unpack_batch get an
immediate documented failure instead of an अस्पष्ट unpack error.
- Around line 145-156: The `compile` parameter in `Trainer.__init__` and the
related constructor usage shadows the Python built-in, so rename that parameter
to `_compile` while keeping the `self.compile` attribute unchanged. Update the
`torch.compile(...)` branch to check `_compile` instead of `compile`, and make
the same rename anywhere the constructor is called or referenced so the API
behavior stays the same and Ruff A002 is satisfied.
In `@tests/apps/test_download_and_extract.py`:
- Around line 38-74: The tests in test_download_and_extract currently reuse
testing_dir/testing_data, which makes download_url and extractall short-circuit
on preexisting fixtures and makes the hash-mismatch assertions order-dependent.
Update the test setup to use an isolated temporary workspace per test (for
example in the test class setup/teardown or inside each test) and point
filepath/output_dir there, so download_and_extract, download_url, and extractall
exercise fresh files and directories and the mismatch cases are meaningful.
---
Outside diff comments:
In `@tests/test_utils.py`:
- Around line 80-84: The skip_if_downloading_fails() helper still misses hash
mismatch failures because download_url() and extractall() now raise
HashCheckError(ValueError) instead of only RuntimeError/OSError. Update the
exception handling in skip_if_downloading_fails() to recognize the new
HashCheckError path and ensure the existing DOWNLOAD_FAIL_MSGS "hash check"
entry is matched when hash validation fails, so those tests are skipped
consistently.
---
Nitpick comments:
In `@monai/apps/utils.py`:
- Around line 51-52: Add a short Google-style docstring to HashCheckError in
utils.py since it is now part of the public API. Update the class definition to
describe that this exception is raised for hash check failures, and make sure
the docstring is concise and follows the module’s docstring style conventions.
In `@monai/engines/trainer.py`:
- Around line 86-104: The batch unpacking logic is duplicated in
AdversarialTrainer instead of using the new _unpack_batch() helper from Trainer.
Update the AdversarialTrainer batch handling path to call _unpack_batch() for
the prepared batch so it follows the same 2-tuple/4-tuple contract as the rest
of the trainer code, and remove the local branching that manually expands
inputs, targets, args, and kwargs.
- Line 61: The AMP scaler usage in Trainer is inconsistent with the rest of the
codebase: `Trainer.__init__` and the GAN-related scaler setup in
`Trainer.run`/state initialization use `torch.amp.GradScaler("cuda")` instead of
the legacy `torch.cuda.amp.GradScaler`. Update the `scaler`, `g_scaler`, and
`d_scaler` assignments in `monai/engines/trainer.py` to use
`torch.cuda.amp.GradScaler` so they match `workflow.py` and the integration test
expectations.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 367296f9-e362-43ae-9a22-ca85fb7c8e60
📒 Files selected for processing (5)
monai/apps/utils.pymonai/bundle/utils.pymonai/engines/trainer.pytests/apps/test_download_and_extract.pytests/test_utils.py
| ValueError: When the hash validation of the ``url`` downloaded file fails. | ||
| """ | ||
| if not filepath: | ||
| filepath = Path(".", _basename(url)).resolve() | ||
| logger.info(f"Default downloading to '{filepath}'") | ||
| filepath = Path(filepath) | ||
| if filepath.exists(): | ||
| if not check_hash(filepath, hash_val, hash_type): | ||
| raise RuntimeError( | ||
| f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." | ||
| ) | ||
| raise HashCheckError(f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type=}.") |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Fix the stale Raises contract for existing files.
Line 235 now raises HashCheckError(ValueError) for the pre-existing-file branch too, but the Raises section still advertises RuntimeError for that path. The public contract is now inconsistent.
Suggested diff
- RuntimeError: When the hash validation of the ``filepath`` existing file fails.
+ ValueError: When the hash validation of the ``filepath`` existing file fails.As per path instructions, **/*.py: Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/apps/utils.py` around lines 227 - 235, The docstring contract for the
existing-file branch is stale: the filepath/hash-validation path in the download
helper now raises HashCheckError(ValueError), not RuntimeError. Update the
Google-style Raises section in the same helper that checks filepath.exists() and
calls check_hash so it accurately lists HashCheckError for this branch and keeps
the public exception contract aligned with the actual behavior.
Source: Path instructions
| compile: bool = False, | ||
| compile_kwargs: dict | None = None, | ||
| **kwargs, | ||
| ) -> None: | ||
| if accumulation_steps < 1: | ||
| raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") | ||
| super().__init__(**kwargs) | ||
| if compile: | ||
| compile_kwargs = {} if compile_kwargs is None else compile_kwargs | ||
| network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] | ||
| self.network = network | ||
| self.compile = compile |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor
Resolve Ruff A002: Rename compile parameter
The parameter compile at lines 145 and 308 shadows the Python built-in. Rename the parameter to _compile to satisfy the linter while preserving the existing public self.compile attribute.
🧰 Tools
🪛 Ruff (0.15.18)
[error] 145-145: Function argument compile is shadowing a Python builtin
(A002)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/engines/trainer.py` around lines 145 - 156, The `compile` parameter in
`Trainer.__init__` and the related constructor usage shadows the Python
built-in, so rename that parameter to `_compile` while keeping the
`self.compile` attribute unchanged. Update the `torch.compile(...)` branch to
check `_compile` instead of `compile`, and make the same rename anywhere the
constructor is called or referenced so the API behavior stays the same and Ruff
A002 is satisfied.
| filepath = self.testing_dir / "MedNIST.tar.gz" | ||
| output_dir = self.testing_dir | ||
|
|
||
| with skip_if_downloading_fails(): | ||
| download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) | ||
| download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) | ||
| download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) | ||
|
|
||
| wrong_md5 = "0" | ||
| with self.assertLogs(logger="monai.apps", level="ERROR"): | ||
| try: | ||
| download_url(url, filepath, wrong_md5) | ||
| except (ContentTooShortError, HTTPError, RuntimeError) as e: | ||
| if isinstance(e, RuntimeError): | ||
| # FIXME: skip MD5 check as current downloading method may fail | ||
| self.assertTrue(str(e).startswith("md5 check")) | ||
| return # skipping this test due the network connection errors | ||
|
|
||
| try: | ||
| extractall(filepath, output_dir, wrong_md5) | ||
| except RuntimeError as e: | ||
| self.assertTrue(str(e).startswith("md5 check")) | ||
| self.assertTrue(filepath.exists(), "Downloaded file does not exist") | ||
| self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty") | ||
|
|
||
| @skip_if_quick | ||
| def test_download_url_hash_mismatch(self): | ||
| """download_url should raise ValueError on hash mismatch.""" | ||
| filepath = self.testing_dir / "MedNIST.tar.gz" | ||
|
|
||
| with skip_if_downloading_fails(): | ||
| # First ensure file is downloaded correctly | ||
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | ||
|
|
||
| # Now test incorrect hash | ||
| with self.assertRaises(ValueError) as ctx: | ||
| download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | ||
|
|
||
| self.assertIn("hash check", str(ctx.exception).lower()) | ||
|
|
||
| @skip_if_quick | ||
| @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) | ||
| def test_default(self, key, file_type): | ||
| def test_extractall_hash_mismatch(self): | ||
| """extractall should raise ValueError when hash is incorrect.""" | ||
| filepath = self.testing_dir / "MedNIST.tar.gz" | ||
| output_dir = self.testing_dir | ||
|
|
||
| with skip_if_downloading_fails(): | ||
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | ||
|
|
||
| with self.assertRaises(ValueError) as ctx: | ||
| extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | ||
|
|
||
| self.assertIn("hash check", str(ctx.exception).lower()) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Use an isolated temp workspace for these tests.
These tests write into tests/testing_data, but download_url() and extractall() intentionally short-circuit on existing files/directories. That makes the flow order-dependent, and Line 45 is effectively vacuous because tests/testing_data already contains fixtures.
Suggested diff
- filepath = self.testing_dir / "MedNIST.tar.gz"
- output_dir = self.testing_dir
-
- with skip_if_downloading_fails():
- download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
-
- self.assertTrue(filepath.exists(), "Downloaded file does not exist")
- self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty")
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_path = Path(tmp_dir)
+ filepath = tmp_path / "MedNIST.tar.gz"
+ output_dir = tmp_path / "extract"
+
+ with skip_if_downloading_fails():
+ download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
+
+ extracted_dir = output_dir / "MedNIST"
+ self.assertTrue(filepath.exists(), "Downloaded file does not exist")
+ self.assertTrue(extracted_dir.exists() and any(extracted_dir.iterdir()), "Extraction output is empty")As per path instructions, **/*.py: Ensure new or modified definitions will be covered by existing or new unit tests.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| filepath = self.testing_dir / "MedNIST.tar.gz" | |
| output_dir = self.testing_dir | |
| with skip_if_downloading_fails(): | |
| download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) | |
| download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) | |
| download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) | |
| wrong_md5 = "0" | |
| with self.assertLogs(logger="monai.apps", level="ERROR"): | |
| try: | |
| download_url(url, filepath, wrong_md5) | |
| except (ContentTooShortError, HTTPError, RuntimeError) as e: | |
| if isinstance(e, RuntimeError): | |
| # FIXME: skip MD5 check as current downloading method may fail | |
| self.assertTrue(str(e).startswith("md5 check")) | |
| return # skipping this test due the network connection errors | |
| try: | |
| extractall(filepath, output_dir, wrong_md5) | |
| except RuntimeError as e: | |
| self.assertTrue(str(e).startswith("md5 check")) | |
| self.assertTrue(filepath.exists(), "Downloaded file does not exist") | |
| self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty") | |
| @skip_if_quick | |
| def test_download_url_hash_mismatch(self): | |
| """download_url should raise ValueError on hash mismatch.""" | |
| filepath = self.testing_dir / "MedNIST.tar.gz" | |
| with skip_if_downloading_fails(): | |
| # First ensure file is downloaded correctly | |
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | |
| # Now test incorrect hash | |
| with self.assertRaises(ValueError) as ctx: | |
| download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | |
| self.assertIn("hash check", str(ctx.exception).lower()) | |
| @skip_if_quick | |
| @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) | |
| def test_default(self, key, file_type): | |
| def test_extractall_hash_mismatch(self): | |
| """extractall should raise ValueError when hash is incorrect.""" | |
| filepath = self.testing_dir / "MedNIST.tar.gz" | |
| output_dir = self.testing_dir | |
| with skip_if_downloading_fails(): | |
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | |
| with self.assertRaises(ValueError) as ctx: | |
| extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | |
| self.assertIn("hash check", str(ctx.exception).lower()) | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
| tmp_path = Path(tmp_dir) | |
| filepath = tmp_path / "MedNIST.tar.gz" | |
| output_dir = tmp_path / "extract" | |
| with skip_if_downloading_fails(): | |
| download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) | |
| extracted_dir = output_dir / "MedNIST" | |
| self.assertTrue(filepath.exists(), "Downloaded file does not exist") | |
| self.assertTrue(extracted_dir.exists() and any(extracted_dir.iterdir()), "Extraction output is empty") | |
| `@skip_if_quick` | |
| def test_download_url_hash_mismatch(self): | |
| """download_url should raise ValueError on hash mismatch.""" | |
| filepath = self.testing_dir / "MedNIST.tar.gz" | |
| with skip_if_downloading_fails(): | |
| # First ensure file is downloaded correctly | |
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | |
| # Now test incorrect hash | |
| with self.assertRaises(ValueError) as ctx: | |
| download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | |
| self.assertIn("hash check", str(ctx.exception).lower()) | |
| `@skip_if_quick` | |
| def test_extractall_hash_mismatch(self): | |
| """extractall should raise ValueError when hash is incorrect.""" | |
| filepath = self.testing_dir / "MedNIST.tar.gz" | |
| output_dir = self.testing_dir | |
| with skip_if_downloading_fails(): | |
| download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) | |
| with self.assertRaises(ValueError) as ctx: | |
| extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) | |
| self.assertIn("hash check", str(ctx.exception).lower()) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/apps/test_download_and_extract.py` around lines 38 - 74, The tests in
test_download_and_extract currently reuse testing_dir/testing_data, which makes
download_url and extractall short-circuit on preexisting fixtures and makes the
hash-mismatch assertions order-dependent. Update the test setup to use an
isolated temporary workspace per test (for example in the test class
setup/teardown or inside each test) and point filepath/output_dir there, so
download_and_extract, download_url, and extractall exercise fresh files and
directories and the mismatch cases are meaningful.
Source: Path instructions
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Fixes #8498
Description
Introduces two intermediate base classes to reduce duplication across the trainer hierarchy:
SingleNetworkTrainer(Trainer)— centralises ownership ofnetwork,optimizer,inferer,optim_set_to_none,accumulation_steps, andtorch.compilewrapping. Exposes_accumulation_flags()and_amp_step()helpers so subclasses implement only theirforward/loss logic.
SupervisedTrainernow inherits from this class.DualNetworkTrainer(Trainer)— centralises the generator/discriminator network pair(
g_*/d_*attributes).GanTrainerandAdversarialTrainernow inherit from this class.Also adds
Trainer._unpack_batch(), a static helper that eliminates the repeated2-vs-4-tuple batch unpacking pattern found in
SupervisedTrainerandAdversarialTrainer.All public
__init__signatures and runtime behaviour are unchanged.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.