diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 856bc64c9e..960330b4bd 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -48,6 +48,10 @@ SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} +class HashCheckError(ValueError): + pass + + def get_logger( module_name: str = "monai.apps", fmt: str = DEFAULT_FMT, @@ -220,8 +224,7 @@ def download_url( HTTPError: See urllib.request.urlretrieve. ContentTooShortError: See urllib.request.urlretrieve. IOError: See urllib.request.urlretrieve. - RuntimeError: When the hash validation of the ``url`` downloaded file fails. - + ValueError: When the hash validation of the ``url`` downloaded file fails. """ if not filepath: filepath = Path(".", _basename(url)).resolve() @@ -229,9 +232,7 @@ def download_url( filepath = Path(filepath) if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." - ) + raise HashCheckError(f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type=}.") logger.info(f"File exists: {filepath}, skipped downloading.") return try: @@ -260,6 +261,13 @@ def download_url( raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) + if not check_hash(tmp_name, hash_val, hash_type): + raise HashCheckError( + f"{hash_type} hash check of downloaded file failed: {url=}, " + f"{filepath=}, expected {hash_type}={hash_val}, " + f"The file may be corrupted or tampered with. " + "Please retry the download or verify the source." + ) file_dir = filepath.parent if file_dir: os.makedirs(file_dir, exist_ok=True) @@ -267,11 +275,6 @@ def download_url( except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows pass logger.info(f"Downloaded: {filepath}") - if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of downloaded file failed: URL={url}, " - f"filepath={filepath}, expected {hash_type}={hash_val}." - ) def _extract_zip(filepath, output_dir): @@ -325,10 +328,15 @@ def extractall( be False. Raises: - RuntimeError: When the hash validation of the ``filepath`` compressed file fails. + ValueError: When the hash validation of the ``filepath`` compressed file fails. NotImplementedError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"]. """ + filepath = Path(filepath) + if hash_val and not check_hash(filepath, hash_val, hash_type): + raise HashCheckError( + f"{hash_type} hash check of compressed file failed: " f"{filepath=}, expected {hash_type}={hash_val}." + ) if has_base: # the extracted files will be in this folder cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) @@ -337,11 +345,6 @@ def extractall( if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None: logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return - filepath = Path(filepath) - if hash_val and not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." - ) logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() if filepath.name.endswith("zip") or _file_type == "zip": diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index d37d7f1c05..53d619f234 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -124,8 +124,10 @@ "run_name": None, # may fill it at runtime "save_execute_config": True, - "is_not_rank0": ("$torch.distributed.is_available() \ - and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"), + "is_not_rank0": ( + "$torch.distributed.is_available() \ + and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0" + ), # MLFlowHandler config for the trainer "trainer": { "_target_": "MLFlowHandler", diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 921d54a59c..d314256b95 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -36,7 +36,14 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] +__all__ = [ + "AdversarialTrainer", + "DualNetworkTrainer", + "GanTrainer", + "SingleNetworkTrainer", + "SupervisedTrainer", + "Trainer", +] class Trainer(Workflow): @@ -51,7 +58,7 @@ def run(self) -> None: # type: ignore[override] If call this function multiple times, it will continuously run from the previous state. """ - self.scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.scaler = torch.amp.GradScaler("cuda") if self.amp else None super().run() def get_stats(self, *vars): @@ -76,10 +83,149 @@ def get_stats(self, *vars): stats[k] = getattr(self.state, k, None) return stats + @staticmethod + def _unpack_batch(batch) -> tuple: + """ + Unpack a prepared batch into ``(inputs, targets, args, kwargs)``. + + If the batch contains exactly 2 elements, ``args`` and ``kwargs`` default to an empty + tuple and dict respectively. Otherwise the batch is expected to already be a 4-tuple. + + Args: + batch: output of ``prepare_batch``, either a 2-tuple ``(inputs, targets)`` + or a 4-tuple ``(inputs, targets, args, kwargs)``. + + Returns: + A 4-tuple ``(inputs, targets, args, kwargs)``. + """ + batch_len = len(batch) + if batch_len == 2: + inputs, targets = batch + return inputs, targets, (), {} + if batch_len != 4: + raise ValueError(f"Expected prepared batch to have 2 or 4 items, got {batch_len}.") + + return batch + + +class SingleNetworkTrainer(Trainer): + """ + Intermediate base class for trainers that operate with a single network and optimizer pair, + inherits from ``Trainer``. + + Centralises the ownership and initialisation of ``network``, ``optimizer``, ``inferer``, + ``optim_set_to_none``, ``accumulation_steps``, and optional ``torch.compile`` wrapping so + that concrete subclasses (e.g. ``SupervisedTrainer``) only need to provide their + task-specific loss and forward logic. + + Provides two helper methods intended for use in ``_iteration`` implementations: + + - :meth:`_accumulation_flags` — returns ``(should_zero_grad, should_step)`` for the + current iteration based on the configured ``accumulation_steps``. + - :meth:`_amp_step` — runs the scaled backward pass and conditional optimizer step, + handling both AMP and non-AMP code paths uniformly. + + Args: + network: network to train, should be regular PyTorch ``torch.nn.Module``. + optimizer: the optimizer associated to the network. + inferer: inference method that executes the model forward on input data. + Defaults to ``SimpleInferer()``. + optim_set_to_none: when calling ``optimizer.zero_grad()``, set grads to ``None`` + instead of zero. Default: ``False``. + accumulation_steps: number of mini-batches over which to accumulate gradients. + Must be a positive integer. Default: ``1`` (no accumulation). + compile: whether to wrap the network with ``torch.compile``. Default: ``False``. + compile_kwargs: keyword arguments forwarded to ``torch.compile()``. + **kwargs: remaining arguments forwarded to ``Trainer`` / ``Workflow``. + """ + + def __init__( + self, + network: torch.nn.Module, + optimizer: Optimizer, + inferer: Inferer | None = None, + optim_set_to_none: bool = False, + accumulation_steps: int = 1, + compile: bool = False, + compile_kwargs: dict | None = None, + **kwargs, + ) -> None: + if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1: + raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") + + super().__init__(**kwargs) + if compile: + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + self.network = network + self.compile = compile + self.optimizer = optimizer + self.inferer = SimpleInferer() if inferer is None else inferer + self.optim_set_to_none = optim_set_to_none + self.accumulation_steps = accumulation_steps + + def _accumulation_flags(self, engine: SingleNetworkTrainer) -> tuple[bool, bool]: + """ + Return ``(should_zero_grad, should_step)`` for the current iteration. + + When ``accumulation_steps == 1`` both flags are always ``True``. + Otherwise the flags reflect whether the current iteration is the first or last step + in an accumulation window, with an end-of-epoch flush when ``epoch_length`` is known + and not evenly divisible by ``accumulation_steps``. + + Args: + engine: the trainer engine for the current iteration. + + Returns: + A 2-tuple ``(should_zero_grad, should_step)``. + """ + acc = engine.accumulation_steps + if acc == 1: + return True, True + epoch_length = engine.state.epoch_length + if epoch_length is not None: + local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length + else: + local_iter = engine.state.iteration - 1 # 0-indexed global + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 + return should_zero_grad, should_step + + def _amp_step(self, engine: SingleNetworkTrainer, loss: torch.Tensor, *, should_step: bool) -> None: + """ + Run the backward pass and, when ``should_step`` is ``True``, the optimizer step. -class SupervisedTrainer(Trainer): + Handles both the AMP (``engine.scaler`` present) and non-AMP code paths, and + divides the loss by ``accumulation_steps`` before calling ``.backward()`` so + gradients accumulate correctly across micro-batches. Fires + ``IterationEvents.BACKWARD_COMPLETED`` after the backward pass in both paths. + + Args: + engine: the trainer engine for the current iteration. + loss: the **unscaled** loss value (as stored in ``engine.state.output``). + should_step: whether to call ``optimizer.step()`` after the backward pass. + """ + acc = engine.accumulation_steps + scaled_loss = loss / acc if acc > 1 else loss + if engine.amp and engine.scaler is not None: + engine.scaler.scale(scaled_loss).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + if should_step: + engine.scaler.step(engine.optimizer) + engine.scaler.update() + else: + scaled_loss.backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + if should_step: + engine.optimizer.step() + + +class SupervisedTrainer(SingleNetworkTrainer): """ - Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``. + Standard supervised training method with image and label, inherits from ``SingleNetworkTrainer``, + ``Trainer`` and ``Workflow``. Args: device: an object representing the device on which to run. @@ -168,9 +314,16 @@ def __init__( compile_kwargs: dict | None = None, accumulation_steps: int = 1, ) -> None: - if accumulation_steps < 1: - raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") super().__init__( + # SingleNetworkTrainer args + network=network, + optimizer=optimizer, + inferer=inferer, + optim_set_to_none=optim_set_to_none, + accumulation_steps=accumulation_steps, + compile=compile, + compile_kwargs=compile_kwargs, + # Workflow args forwarded via **kwargs through SingleNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -190,16 +343,7 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - if compile: - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] - self.network = network - self.compile = compile - self.optimizer = optimizer self.loss_function = loss_function - self.inferer = SimpleInferer() if inferer is None else inferer - self.optim_set_to_none = optim_set_to_none - self.accumulation_steps = accumulation_steps def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict: """ @@ -221,12 +365,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso if batchdata is None: raise ValueError("Must provide batch data for current iteration.") batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) - if len(batch) == 2: - inputs, targets = batch - args: tuple = () - kwargs: dict = {} - else: - inputs, targets, args, kwargs = batch + inputs, targets, args, kwargs = self._unpack_batch(batch) # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None @@ -255,21 +394,7 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) - # Determine gradient accumulation state - acc = engine.accumulation_steps - if acc > 1: - epoch_length = engine.state.epoch_length - if epoch_length is not None: - local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length - else: - local_iter = engine.state.iteration - 1 # 0-indexed global - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 - else: - should_zero_grad = True - should_step = True + should_zero_grad, should_step = self._accumulation_flags(engine) engine.network.train() if should_zero_grad: @@ -278,19 +403,11 @@ def _compute_pred_loss(): if engine.amp and engine.scaler is not None: with torch.autocast("cuda", **engine.amp_kwargs): _compute_pred_loss() - loss = engine.state.output[Keys.LOSS] - engine.scaler.scale(loss / acc if acc > 1 else loss).backward() - engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - if should_step: - engine.scaler.step(engine.optimizer) - engine.scaler.update() else: _compute_pred_loss() - loss = engine.state.output[Keys.LOSS] - (loss / acc if acc > 1 else loss).backward() - engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - if should_step: - engine.optimizer.step() + + self._amp_step(engine, engine.state.output[Keys.LOSS], should_step=should_step) + # copy back meta info if self.compile: if inputs_meta is not None: @@ -309,10 +426,54 @@ def _compute_pred_loss(): return engine.state.output -class GanTrainer(Trainer): +class DualNetworkTrainer(Trainer): + """ + Intermediate base class for trainers that manage a generator and discriminator network pair, + inherits from ``Trainer``. + + Centralises the ownership and initialisation of the generator (``g_network``, + ``g_optimizer``, ``g_inferer``) and discriminator (``d_network``, ``d_optimizer``, + ``d_inferer``) components, along with ``optim_set_to_none``, so that concrete + subclasses (e.g. ``GanTrainer``, ``AdversarialTrainer``) only need to provide their + task-specific training logic. + + Args: + g_network: generator (G) network architecture. + g_optimizer: G optimizer. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + optim_set_to_none: when calling ``optimizer.zero_grad()``, set grads to ``None`` + instead of zero. Default: ``False``. + **kwargs: remaining arguments forwarded to ``Trainer`` / ``Workflow``. + """ + + def __init__( + self, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + optim_set_to_none: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.g_network = g_network + self.g_optimizer = g_optimizer + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_network = d_network + self.d_optimizer = d_optimizer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + self.optim_set_to_none = optim_set_to_none + + +class GanTrainer(DualNetworkTrainer): """ Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266, - inherits from ``Trainer`` and ``Workflow``. + inherits from ``DualNetworkTrainer``, ``Trainer`` and ``Workflow``. Training Loop: for each batch of data size `m` 1. Generate `m` fakes from random latent codes. @@ -407,6 +568,15 @@ def __init__( # set up Ignite engine and environments super().__init__( + # DualNetworkTrainer args + g_network=g_network, + g_optimizer=g_optimizer, + g_inferer=g_inferer, + d_network=d_network, + d_optimizer=d_optimizer, + d_inferer=d_inferer, + optim_set_to_none=optim_set_to_none, + # Workflow args forwarded via **kwargs through DualNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -423,19 +593,12 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - self.g_network = g_network - self.g_optimizer = g_optimizer self.g_loss_function = g_loss_function - self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer - self.d_network = d_network - self.d_optimizer = d_optimizer self.d_loss_function = d_loss_function - self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer self.d_train_steps = d_train_steps self.latent_shape = latent_shape self.g_prepare_batch = g_prepare_batch self.g_update_latents = g_update_latents - self.optim_set_to_none = optim_set_to_none def _iteration( self, engine: GanTrainer, batchdata: dict | Sequence @@ -498,9 +661,10 @@ def _iteration( } -class AdversarialTrainer(Trainer): +class AdversarialTrainer(DualNetworkTrainer): """ - Standard supervised training workflow for adversarial loss enabled neural networks. + Standard supervised training workflow for adversarial loss enabled neural networks, + inherits from ``DualNetworkTrainer``, ``Trainer`` and ``Workflow``. Args: device: an object representing the device on which to run. @@ -579,6 +743,15 @@ def __init__( amp_kwargs: dict | None = None, ): super().__init__( + # DualNetworkTrainer args + g_network=g_network, + g_optimizer=g_optimizer, + g_inferer=g_inferer, + d_network=d_network, + d_optimizer=d_optimizer, + d_inferer=d_inferer, + optim_set_to_none=optim_set_to_none, + # Workflow args forwarded via **kwargs through DualNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -601,22 +774,19 @@ def __init__( self.register_events(*AdversarialIterationEvents) - self.state.g_network = g_network - self.state.g_optimizer = g_optimizer + # Promote network/optimizer references into engine.state for checkpoint saving + self.state.g_network = self.g_network + self.state.g_optimizer = self.g_optimizer self.state.g_loss_function = g_loss_function self.state.recon_loss_function = recon_loss_function - self.state.d_network = d_network - self.state.d_optimizer = d_optimizer + self.state.d_network = self.d_network + self.state.d_optimizer = self.d_optimizer self.state.d_loss_function = d_loss_function - self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer - self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + self.state.g_scaler = torch.amp.GradScaler("cuda") if self.amp else None + self.state.d_scaler = torch.amp.GradScaler("cuda") if self.amp else None - self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None - self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None - - self.optim_set_to_none = optim_set_to_none self._complete_state_dict_user_keys() def _complete_state_dict_user_keys(self) -> None: diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index 6d16a72735..a1e5381d90 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -16,7 +16,6 @@ import unittest import zipfile from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from parameterized import parameterized @@ -26,39 +25,62 @@ @SkipIfNoModule("requests") class TestDownloadAndExtract(unittest.TestCase): + def setUp(self): + self.testing_dir = Path(__file__).parents[1] / "testing_data" + self.config = testing_data_config("images", "mednist") + self.url = self.config["url"] + self.hash_val = self.config["hash_val"] + self.hash_type = self.config["hash_type"] + @skip_if_quick - def test_actions(self): - testing_dir = Path(__file__).parents[1] / "testing_data" - config_dict = testing_data_config("images", "mednist") - url = config_dict["url"] - filepath = Path(testing_dir) / "MedNIST.tar.gz" - output_dir = Path(testing_dir) - hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"] + def test_download_and_extract_success(self): + """End-to-end: download and extract should succeed with correct hash.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + with skip_if_downloading_fails(): - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) + download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) - wrong_md5 = "0" - with self.assertLogs(logger="monai.apps", level="ERROR"): - try: - download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - if isinstance(e, RuntimeError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) - return # skipping this test due the network connection errors - - try: - extractall(filepath, output_dir, wrong_md5) - except RuntimeError as e: - self.assertTrue(str(e).startswith("md5 check")) + self.assertTrue(filepath.exists(), "Downloaded file does not exist") + self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty") + + @skip_if_quick + def test_download_url_hash_mismatch(self): + """download_url should raise ValueError on hash mismatch.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + + with skip_if_downloading_fails(): + # First ensure file is downloaded correctly + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + # Now test incorrect hash + with self.assertRaises(ValueError) as ctx: + download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) @skip_if_quick - @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) - def test_default(self, key, file_type): + def test_extractall_hash_mismatch(self): + """extractall should raise ValueError when hash is incorrect.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + + with skip_if_downloading_fails(): + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + with self.assertRaises(ValueError) as ctx: + extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) + + @skip_if_quick + @parameterized.expand([("icon", "tar"), ("favicon", "zip")]) + def test_download_and_extract_various_formats(self, key, file_type): + """Verify different archive formats download and extract correctly.""" with tempfile.TemporaryDirectory() as tmp_dir: + img_spec = testing_data_config("images", key) + with skip_if_downloading_fails(): - img_spec = testing_data_config("images", key) download_and_extract( img_spec["url"], output_dir=tmp_dir, @@ -67,6 +89,8 @@ def test_default(self, key, file_type): file_type=file_type, ) + self.assertTrue(any(Path(tmp_dir).iterdir()), f"Extraction failed for format: {file_type}") + class TestPathTraversalProtection(unittest.TestCase): """Test cases for path traversal attack protection in extractall function.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 5e21e48068..320a23d0cd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -81,7 +81,7 @@ "unexpected EOF", # incomplete download "network issue", "gdown dependency", # gdown not installed - "md5 check", + "hash check", # check hash value of downloaded file "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out @@ -186,37 +186,6 @@ def skip_if_downloading_fails(): raise rt_e -SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" -SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" -SAMPLE_TIFF_HASH_TYPE = "sha256" - - -class TestDownloadUrl(unittest.TestCase): - """Exercise ``download_url`` success and hash-mismatch paths.""" - - def test_download_url(self): - """Download a sample TIFF and validate hash handling. - - Raises: - RuntimeError: When the downloaded file's hash does not match. - """ - with tempfile.TemporaryDirectory() as tempdir: - with skip_if_downloading_fails(): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model.tiff"), - hash_val=SAMPLE_TIFF_HASH, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) - with self.assertRaises(RuntimeError): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model_bad.tiff"), - hash_val="0" * 64, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) - - def test_pretrained_networks(network, input_param, device): with skip_if_downloading_fails(): return network(**input_param).to(device)