Skip to content

8498 trainerclassabstraction#8950

Open
e-mny wants to merge 27 commits into
Project-MONAI:devfrom
e-mny:8498-trainerclassabstraction
Open

8498 trainerclassabstraction#8950
e-mny wants to merge 27 commits into
Project-MONAI:devfrom
e-mny:8498-trainerclassabstraction

Conversation

@e-mny

@e-mny e-mny commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Fixes #8498

Description

Introduces two intermediate base classes to reduce duplication across the trainer hierarchy:

  • SingleNetworkTrainer(Trainer) — centralises ownership of network, optimizer, inferer,
    optim_set_to_none, accumulation_steps, and torch.compile wrapping. Exposes
    _accumulation_flags() and _amp_step() helpers so subclasses implement only their
    forward/loss logic. SupervisedTrainer now inherits from this class.
  • DualNetworkTrainer(Trainer) — centralises the generator/discriminator network pair
    (g_* / d_* attributes). GanTrainer and AdversarialTrainer now inherit from this class.

Also adds Trainer._unpack_batch(), a static helper that eliminates the repeated
2-vs-4-tuple batch unpacking pattern found in SupervisedTrainer and AdversarialTrainer.
All public __init__ signatures and runtime behaviour are unchanged.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

e-mny and others added 24 commits April 17, 2026 13:04
DCO Remediation Commit for Enoch Mok <enochmokny@gmail.com>

I, Enoch Mok <enochmokny@gmail.com>, hereby add my Signed-off-by to this commit: d02da49

Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
…tch; used ./runtests.sh autofix

Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
… updated download_url function

Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Enoch Mok <65853622+e-mny@users.noreply.github.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
Signed-off-by: Enoch Mok <enochmokny@gmail.com>
…ner base classes

Signed-off-by: Enoch Mok <enochmokny@gmail.com>
@coderabbitai

coderabbitai Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cdfaf09c-e977-4d43-af51-a0a076755c25

📥 Commits

Reviewing files that changed from the base of the PR and between 32ebca7 and 54f5f60.

📒 Files selected for processing (1)
  • monai/engines/trainer.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • monai/engines/trainer.py

📝 Walkthrough

Walkthrough

monai/apps/utils.py adds HashCheckError and changes download_url and extractall to raise it for hash-validation failures. The related tests now assert ValueError and updated hash-check messages, and an older TIFF download test was removed. monai/engines/trainer.py introduces SingleNetworkTrainer and DualNetworkTrainer, updates supervised, GAN, and adversarial trainer wiring, and switches AMP scaler creation to torch.amp.GradScaler("cuda"). monai/bundle/utils.py reformats one distributed-rank setting string.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 3

❌ Failed checks (2 warnings, 1 inconclusive)

Check name Status Explanation Resolution
Out of Scope Changes check ⚠️ Warning Hash-validation and download/extraction test changes are unrelated to the trainer abstraction objective and appear out of scope. Split the hash-check and download/extract test updates into a separate PR focused on apps/utils behavior.
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title is generic and only references the issue number, so it does not clearly summarize the change. Use a descriptive title such as "Refactor trainer hierarchy with single- and dual-network base classes".
✅ Passed checks (2 passed)
Check name Status Explanation
Description check ✅ Passed The description matches the template and includes a summary plus change-type checkboxes.
Linked Issues check ✅ Passed The PR implements the requested trainer refactor by adding a generalized adversarial base class for GanTrainer and AdversarialTrainer.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/test_utils.py (1)

80-84: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

skip_if_downloading_fails() still won’t skip hash mismatches.

download_url() and extractall() now raise HashCheckError(ValueError) on hash failure, but this helper only filters RuntimeError/OSError. So the new "hash check" entry never catches the changed failure mode.

Suggested diff
-    except (RuntimeError, OSError) as rt_e:
+    except (RuntimeError, OSError, ValueError) as rt_e:
         err_str = str(rt_e)
         if any(k in err_str for k in DOWNLOAD_FAIL_MSGS):
             raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e  # incomplete download
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_utils.py` around lines 80 - 84, The skip_if_downloading_fails()
helper still misses hash mismatch failures because download_url() and
extractall() now raise HashCheckError(ValueError) instead of only
RuntimeError/OSError. Update the exception handling in
skip_if_downloading_fails() to recognize the new HashCheckError path and ensure
the existing DOWNLOAD_FAIL_MSGS "hash check" entry is matched when hash
validation fails, so those tests are skipped consistently.
🧹 Nitpick comments (3)
monai/apps/utils.py (1)

51-52: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Add a docstring to the new public exception.

HashCheckError is part of the module API now. Give it a short Google-style docstring so callers know which failure paths raise it.

Suggested diff
 class HashCheckError(ValueError):
+    """Raised when download or extraction fails hash validation."""
     pass

As per path instructions, **/*.py: Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/utils.py` around lines 51 - 52, Add a short Google-style docstring
to HashCheckError in utils.py since it is now part of the public API. Update the
class definition to describe that this exception is raised for hash check
failures, and make sure the docstring is concise and follows the module’s
docstring style conventions.

Source: Path instructions

monai/engines/trainer.py (2)

86-104: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Use _unpack_batch() in AdversarialTrainer too.

The new helper centralizes this contract, but Lines 846-851 still duplicate the same 2/4-tuple branch.

Also applies to: 846-851

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/engines/trainer.py` around lines 86 - 104, The batch unpacking logic is
duplicated in AdversarialTrainer instead of using the new _unpack_batch() helper
from Trainer. Update the AdversarialTrainer batch handling path to call
_unpack_batch() for the prepared batch so it follows the same 2-tuple/4-tuple
contract as the rest of the trainer code, and remove the local branching that
manually expands inputs, targets, args, and kwargs.

61-61: 🩺 Stability & Availability | 🔵 Trivial

Use the legacy torch.cuda.amp.GradScaler to ensure consistency.

While the torch>=2.8.0 constraint ensures torch.amp.GradScaler is available, other parts of the codebase, specifically monai/engines/workflow.py and tests/integration/test_integration_fast_train.py, utilize the legacy torch.cuda.amp.GradScaler. Align these instances for consistency.

File: monai/engines/trainer.py
# Line 61
        self.scaler = torch.amp.GradScaler("cuda") if self.amp else None

# Lines 782-783
        self.state.g_scaler = torch.amp.GradScaler("cuda") if self.amp else None
        self.state.d_scaler = torch.amp.GradScaler("cuda") if self.amp else None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/engines/trainer.py` at line 61, The AMP scaler usage in Trainer is
inconsistent with the rest of the codebase: `Trainer.__init__` and the
GAN-related scaler setup in `Trainer.run`/state initialization use
`torch.amp.GradScaler("cuda")` instead of the legacy
`torch.cuda.amp.GradScaler`. Update the `scaler`, `g_scaler`, and `d_scaler`
assignments in `monai/engines/trainer.py` to use `torch.cuda.amp.GradScaler` so
they match `workflow.py` and the integration test expectations.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/apps/utils.py`:
- Around line 227-235: The docstring contract for the existing-file branch is
stale: the filepath/hash-validation path in the download helper now raises
HashCheckError(ValueError), not RuntimeError. Update the Google-style Raises
section in the same helper that checks filepath.exists() and calls check_hash so
it accurately lists HashCheckError for this branch and keeps the public
exception contract aligned with the actual behavior.

In `@monai/engines/trainer.py`:
- Around line 39-46: The export list in Trainer’s __all__ is not sorted and
triggers Ruff RUF022. Reorder the names in the __all__ array in
monai/engines/trainer.py so they are alphabetically sorted, keeping the same
exported symbols (Trainer, SingleNetworkTrainer, SupervisedTrainer,
DualNetworkTrainer, GanTrainer, AdversarialTrainer).
- Around line 145-150: The `Trainer` initializer currently only checks that
`accumulation_steps` is >= 1, so non-integer values like `1.5` still pass and
later break the stepping logic. Update the validation in `Trainer.__init__` to
enforce the documented integer contract for `accumulation_steps` by rejecting
non-integer values before the existing positive check, while keeping the current
error handling for values less than 1.
- Around line 87-104: The _unpack_batch helper in trainer.py currently passes
through any batch that is not length 2, which allows 3- or 5-item prepared
batches to fail later with a generic unpacking error. Update _unpack_batch to
explicitly accept only 2-item and 4-item batches, and raise a clear exception
for any other length before returning, so callers of _unpack_batch get an
immediate documented failure instead of an अस्पष्ट unpack error.
- Around line 145-156: The `compile` parameter in `Trainer.__init__` and the
related constructor usage shadows the Python built-in, so rename that parameter
to `_compile` while keeping the `self.compile` attribute unchanged. Update the
`torch.compile(...)` branch to check `_compile` instead of `compile`, and make
the same rename anywhere the constructor is called or referenced so the API
behavior stays the same and Ruff A002 is satisfied.

In `@tests/apps/test_download_and_extract.py`:
- Around line 38-74: The tests in test_download_and_extract currently reuse
testing_dir/testing_data, which makes download_url and extractall short-circuit
on preexisting fixtures and makes the hash-mismatch assertions order-dependent.
Update the test setup to use an isolated temporary workspace per test (for
example in the test class setup/teardown or inside each test) and point
filepath/output_dir there, so download_and_extract, download_url, and extractall
exercise fresh files and directories and the mismatch cases are meaningful.

---

Outside diff comments:
In `@tests/test_utils.py`:
- Around line 80-84: The skip_if_downloading_fails() helper still misses hash
mismatch failures because download_url() and extractall() now raise
HashCheckError(ValueError) instead of only RuntimeError/OSError. Update the
exception handling in skip_if_downloading_fails() to recognize the new
HashCheckError path and ensure the existing DOWNLOAD_FAIL_MSGS "hash check"
entry is matched when hash validation fails, so those tests are skipped
consistently.

---

Nitpick comments:
In `@monai/apps/utils.py`:
- Around line 51-52: Add a short Google-style docstring to HashCheckError in
utils.py since it is now part of the public API. Update the class definition to
describe that this exception is raised for hash check failures, and make sure
the docstring is concise and follows the module’s docstring style conventions.

In `@monai/engines/trainer.py`:
- Around line 86-104: The batch unpacking logic is duplicated in
AdversarialTrainer instead of using the new _unpack_batch() helper from Trainer.
Update the AdversarialTrainer batch handling path to call _unpack_batch() for
the prepared batch so it follows the same 2-tuple/4-tuple contract as the rest
of the trainer code, and remove the local branching that manually expands
inputs, targets, args, and kwargs.
- Line 61: The AMP scaler usage in Trainer is inconsistent with the rest of the
codebase: `Trainer.__init__` and the GAN-related scaler setup in
`Trainer.run`/state initialization use `torch.amp.GradScaler("cuda")` instead of
the legacy `torch.cuda.amp.GradScaler`. Update the `scaler`, `g_scaler`, and
`d_scaler` assignments in `monai/engines/trainer.py` to use
`torch.cuda.amp.GradScaler` so they match `workflow.py` and the integration test
expectations.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 367296f9-e362-43ae-9a22-ca85fb7c8e60

📥 Commits

Reviewing files that changed from the base of the PR and between 87ff41a and 32ebca7.

📒 Files selected for processing (5)
  • monai/apps/utils.py
  • monai/bundle/utils.py
  • monai/engines/trainer.py
  • tests/apps/test_download_and_extract.py
  • tests/test_utils.py

Comment thread monai/apps/utils.py
Comment on lines +227 to +235
ValueError: When the hash validation of the ``url`` downloaded file fails.
"""
if not filepath:
filepath = Path(".", _basename(url)).resolve()
logger.info(f"Default downloading to '{filepath}'")
filepath = Path(filepath)
if filepath.exists():
if not check_hash(filepath, hash_val, hash_type):
raise RuntimeError(
f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}."
)
raise HashCheckError(f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type=}.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Fix the stale Raises contract for existing files.

Line 235 now raises HashCheckError(ValueError) for the pre-existing-file branch too, but the Raises section still advertises RuntimeError for that path. The public contract is now inconsistent.

Suggested diff
-        RuntimeError: When the hash validation of the ``filepath`` existing file fails.
+        ValueError: When the hash validation of the ``filepath`` existing file fails.

As per path instructions, **/*.py: Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/apps/utils.py` around lines 227 - 235, The docstring contract for the
existing-file branch is stale: the filepath/hash-validation path in the download
helper now raises HashCheckError(ValueError), not RuntimeError. Update the
Google-style Raises section in the same helper that checks filepath.exists() and
calls check_hash so it accurately lists HashCheckError for this branch and keeps
the public exception contract aligned with the actual behavior.

Source: Path instructions

Comment thread monai/engines/trainer.py
Comment thread monai/engines/trainer.py
Comment thread monai/engines/trainer.py
Comment thread monai/engines/trainer.py
Comment on lines +145 to +156
compile: bool = False,
compile_kwargs: dict | None = None,
**kwargs,
) -> None:
if accumulation_steps < 1:
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
super().__init__(**kwargs)
if compile:
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
self.network = network
self.compile = compile

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor

Resolve Ruff A002: Rename compile parameter

The parameter compile at lines 145 and 308 shadows the Python built-in. Rename the parameter to _compile to satisfy the linter while preserving the existing public self.compile attribute.

🧰 Tools
🪛 Ruff (0.15.18)

[error] 145-145: Function argument compile is shadowing a Python builtin

(A002)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/engines/trainer.py` around lines 145 - 156, The `compile` parameter in
`Trainer.__init__` and the related constructor usage shadows the Python
built-in, so rename that parameter to `_compile` while keeping the
`self.compile` attribute unchanged. Update the `torch.compile(...)` branch to
check `_compile` instead of `compile`, and make the same rename anywhere the
constructor is called or referenced so the API behavior stays the same and Ruff
A002 is satisfied.

Comment on lines +38 to +74
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir

with skip_if_downloading_fails():
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)

wrong_md5 = "0"
with self.assertLogs(logger="monai.apps", level="ERROR"):
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors

try:
extractall(filepath, output_dir, wrong_md5)
except RuntimeError as e:
self.assertTrue(str(e).startswith("md5 check"))
self.assertTrue(filepath.exists(), "Downloaded file does not exist")
self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty")

@skip_if_quick
def test_download_url_hash_mismatch(self):
"""download_url should raise ValueError on hash mismatch."""
filepath = self.testing_dir / "MedNIST.tar.gz"

with skip_if_downloading_fails():
# First ensure file is downloaded correctly
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)

# Now test incorrect hash
with self.assertRaises(ValueError) as ctx:
download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)

self.assertIn("hash check", str(ctx.exception).lower())

@skip_if_quick
@parameterized.expand((("icon", "tar"), ("favicon", "zip")))
def test_default(self, key, file_type):
def test_extractall_hash_mismatch(self):
"""extractall should raise ValueError when hash is incorrect."""
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir

with skip_if_downloading_fails():
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)

with self.assertRaises(ValueError) as ctx:
extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)

self.assertIn("hash check", str(ctx.exception).lower())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Use an isolated temp workspace for these tests.

These tests write into tests/testing_data, but download_url() and extractall() intentionally short-circuit on existing files/directories. That makes the flow order-dependent, and Line 45 is effectively vacuous because tests/testing_data already contains fixtures.

Suggested diff
-        filepath = self.testing_dir / "MedNIST.tar.gz"
-        output_dir = self.testing_dir
-
-        with skip_if_downloading_fails():
-            download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
-
-        self.assertTrue(filepath.exists(), "Downloaded file does not exist")
-        self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_path = Path(tmp_dir)
+            filepath = tmp_path / "MedNIST.tar.gz"
+            output_dir = tmp_path / "extract"
+
+            with skip_if_downloading_fails():
+                download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
+
+            extracted_dir = output_dir / "MedNIST"
+            self.assertTrue(filepath.exists(), "Downloaded file does not exist")
+            self.assertTrue(extracted_dir.exists() and any(extracted_dir.iterdir()), "Extraction output is empty")

As per path instructions, **/*.py: Ensure new or modified definitions will be covered by existing or new unit tests.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir
with skip_if_downloading_fails():
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)
download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
wrong_md5 = "0"
with self.assertLogs(logger="monai.apps", level="ERROR"):
try:
download_url(url, filepath, wrong_md5)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
if isinstance(e, RuntimeError):
# FIXME: skip MD5 check as current downloading method may fail
self.assertTrue(str(e).startswith("md5 check"))
return # skipping this test due the network connection errors
try:
extractall(filepath, output_dir, wrong_md5)
except RuntimeError as e:
self.assertTrue(str(e).startswith("md5 check"))
self.assertTrue(filepath.exists(), "Downloaded file does not exist")
self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty")
@skip_if_quick
def test_download_url_hash_mismatch(self):
"""download_url should raise ValueError on hash mismatch."""
filepath = self.testing_dir / "MedNIST.tar.gz"
with skip_if_downloading_fails():
# First ensure file is downloaded correctly
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)
# Now test incorrect hash
with self.assertRaises(ValueError) as ctx:
download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)
self.assertIn("hash check", str(ctx.exception).lower())
@skip_if_quick
@parameterized.expand((("icon", "tar"), ("favicon", "zip")))
def test_default(self, key, file_type):
def test_extractall_hash_mismatch(self):
"""extractall should raise ValueError when hash is incorrect."""
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir
with skip_if_downloading_fails():
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)
with self.assertRaises(ValueError) as ctx:
extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)
self.assertIn("hash check", str(ctx.exception).lower())
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
filepath = tmp_path / "MedNIST.tar.gz"
output_dir = tmp_path / "extract"
with skip_if_downloading_fails():
download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type)
extracted_dir = output_dir / "MedNIST"
self.assertTrue(filepath.exists(), "Downloaded file does not exist")
self.assertTrue(extracted_dir.exists() and any(extracted_dir.iterdir()), "Extraction output is empty")
`@skip_if_quick`
def test_download_url_hash_mismatch(self):
"""download_url should raise ValueError on hash mismatch."""
filepath = self.testing_dir / "MedNIST.tar.gz"
with skip_if_downloading_fails():
# First ensure file is downloaded correctly
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)
# Now test incorrect hash
with self.assertRaises(ValueError) as ctx:
download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)
self.assertIn("hash check", str(ctx.exception).lower())
`@skip_if_quick`
def test_extractall_hash_mismatch(self):
"""extractall should raise ValueError when hash is incorrect."""
filepath = self.testing_dir / "MedNIST.tar.gz"
output_dir = self.testing_dir
with skip_if_downloading_fails():
download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type)
with self.assertRaises(ValueError) as ctx:
extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type)
self.assertIn("hash check", str(ctx.exception).lower())
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/apps/test_download_and_extract.py` around lines 38 - 74, The tests in
test_download_and_extract currently reuse testing_dir/testing_data, which makes
download_url and extractall short-circuit on preexisting fixtures and makes the
hash-mismatch assertions order-dependent. Update the test setup to use an
isolated temporary workspace per test (for example in the test class
setup/teardown or inside each test) and point filepath/output_dir there, so
download_and_extract, download_url, and extractall exercise fresh files and
directories and the mismatch cases are meaningful.

Source: Path instructions

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Engines - refactor GANTrainer into a more general base class that can replace AdversarialTrainer

2 participants