From d02da499b348e445cd29f5244052c37065a1a533 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Fri, 17 Apr 2026 13:04:41 +0800 Subject: [PATCH 01/19] Refactor download_url test into TestDownloadUrl class --- tests/test_utils.py | 35 ++++++++++++++++ tests/testing_data/data_config.json | 62 ++++++++++++++--------------- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index f87b16fb71..03fa7abce3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -57,6 +57,8 @@ nib, _ = optional_import("nibabel") http_error, has_req = optional_import("requests", name="HTTPError") file_url_error, has_gdown = optional_import("gdown.exceptions", name="FileURLRetrievalError") +hf_http_error, has_hf_hub = optional_import("huggingface_hub.errors", name="HfHubHTTPError") +hf_local_entry_error, _has_hf_local = optional_import("huggingface_hub.errors", name="LocalEntryNotFoundError") quick_test_var = "QUICKTEST" @@ -70,6 +72,8 @@ DOWNLOAD_EXCEPTS += (http_error,) if has_gdown: DOWNLOAD_EXCEPTS += (file_url_error,) +if has_hf_hub: + DOWNLOAD_EXCEPTS += (hf_http_error, hf_local_entry_error) DOWNLOAD_FAIL_MSGS = ( "unexpected EOF", # incomplete download @@ -180,6 +184,37 @@ def skip_if_downloading_fails(): raise rt_e +SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" +SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" +SAMPLE_TIFF_HASH_TYPE = "sha256" + + +class TestDownloadUrl(unittest.TestCase): + """Exercise ``download_url`` success and hash-mismatch paths.""" + + def test_download_url(self): + """Download a sample TIFF and validate hash handling. + + Raises: + RuntimeError: When the downloaded file's hash does not match. + """ + with tempfile.TemporaryDirectory() as tempdir: + with skip_if_downloading_fails(): + download_url( + url=SAMPLE_TIFF, + filepath=os.path.join(tempdir, "model.tiff"), + hash_val=SAMPLE_TIFF_HASH, + hash_type=SAMPLE_TIFF_HASH_TYPE, + ) + with self.assertRaises(RuntimeError): + download_url( + url=SAMPLE_TIFF, + filepath=os.path.join(tempdir, "model_bad.tiff"), + hash_val="0" * 64, + hash_type=SAMPLE_TIFF_HASH_TYPE, + ) + + def test_pretrained_networks(network, input_param, device): with skip_if_downloading_fails(): return network(**input_param).to(device) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index 79033dd0d6..d05d7735f8 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -1,87 +1,87 @@ { "images": { "wsi_generic_tiff": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CMU-1.tiff", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff", "hash_type": "sha256", "hash_val": "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" }, "wsi_aperio_svs": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Aperio-CMU-1.svs", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/Aperio-CMU-1.svs", "hash_type": "sha256", "hash_val": "00a3d54482cd707abf254fe69dccc8d06b8ff757a1663f1290c23418c480eb30" }, "favicon": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/favicon.ico.zip", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/favicon.ico.zip", "hash_type": "sha256", "hash_val": "3a3635c8d8adb81feebc5926b4106e8eb643a24a4be2a69a9d35f9a578acadb5" }, "icon": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/icon.tar.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/icon.tar.gz", "hash_type": "sha256", "hash_val": "90f24cd8f20f3932624da95190ce384302261acf0ea15b358f7832e3b6becac0" }, "mednist": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/MedNIST.tar.gz", "hash_type": "sha256", "hash_val": "f2f4881ff8799a170b10a403495f0ce0ad7486491901cde67a647e6627e7f916" }, "Prostate_T2W_AX_1": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Prostate_T2W_AX_1.nii", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/Prostate_T2W_AX_1.nii", "hash_type": "sha256", "hash_val": "a14231f539c0f365a5f83f2a046969a9b9870e56ffd126fd8e7242364d25938a" }, "0000_t2_tse_tra_4": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_t2_tse_tra_4.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/ProstateX-0000_t2_tse_tra_4.nii.gz", "hash_type": "md5", "hash_val": "adb3f1c4db66a6481c3e4a2a3033c7d5" }, "0000_ep2d_diff_tra_7": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_ep2d_diff_tra_7.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/ProstateX-0000_ep2d_diff_tra_7.nii.gz", "hash_type": "md5", "hash_val": "f12a11ad0ebb0b1876e9e010564745d2" }, "ref_avg152T1_LR": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/avg152T1_LR_nifti.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/avg152T1_LR_nifti.nii.gz", "hash_type": "sha256", "hash_val": "c01a50caa7a563158ecda43d93a1466bfc8aa939bc16b06452ac1089c54661c8" }, "ref_avg152T1_RL": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/avg152T1_RL_nifti.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/avg152T1_RL_nifti.nii.gz", "hash_type": "sha256", "hash_val": "8a731128dac4de46ccb2cc60d972b98f75a52f21fb63ddb040ca96f0aed8b51a" }, "MNI152_T1_2mm": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/MNI152_T1_2mm.nii.gz", "hash_type": "sha256", "hash_val": "0585cd056bf5ccfb8bf97a5f6a66082d4e7caad525718fc11e40d80a827fcb92" }, "MNI152_T1_2mm_strucseg": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/MNI152_T1_2mm_strucseg.nii.gz", "hash_type": "sha256", "hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4" }, "copd1_highres_INSP_STD_COPD_img": { - "url": "https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/copd1_highres_INSP_STD_COPD_img.nii.gz", "hash_type": "sha512", "hash_val": "60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9" }, "copd1_highres_EXP_STD_COPD_img": { - "url": "https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/copd1_highres_EXP_STD_COPD_img.nii.gz", "hash_type": "sha512", "hash_val": "841ef303958541474e66c2d1ccdc8b7ed17ba2f2681101307766b979a07979f2ec818ddf13791c3f1ac5a8ec3258d6ea45b692b4b4a838de9188602618972b6d" }, "CT_2D_head_fixed": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_fixed.mha", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CT_2D_head_fixed.mha", "hash_type": "sha256", "hash_val": "06f2ce6fbf6a59f0874c735555fcf71717f631156b1b0697c1752442f7fc1cc5" }, "CT_2D_head_moving": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_moving.mha", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CT_2D_head_moving.mha", "hash_type": "sha256", "hash_val": "a37c5fe388c38b3f4ac564f456277d09d3982eda58c4da05ead8ee2332360f47" }, "DICOM_single": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_DICOM_SINGLE.zip", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CT_DICOM_SINGLE.zip", "hash_type": "sha256", "hash_val": "a41f6e93d2e3d68956144f9a847273041d36441da12377d6a1d5ae610e0a7023" }, @@ -93,69 +93,69 @@ }, "videos": { "endovis": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/d1_im.mp4", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/d1_im.mp4", "hash_type": "md5", "hash_val": "9b103c07326439b0ea376018d7189384" }, "ultrasound": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/example_data_Ultrasound_Q000_04_tu_segmented_ultrasound_256.avi", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/example_data_Ultrasound_Q000_04_tu_segmented_ultrasound_256.avi", "hash_type": "md5", "hash_val": "f0755960cc4a08a958561cda9a79a157" } }, "models": { "senet154-c7b49a05": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/senet154-c7b49a05.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/senet154-c7b49a05.pth", "hash_type": "sha256", "hash_val": "c7b49a056b98b0bed65b0237c27acdead655e599669215573d357ad337460413" }, "se_resnet101-7e38fcc6": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet101-7e38fcc6.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/se_resnet101-7e38fcc6.pth", "hash_type": "sha256", "hash_val": "7e38fcc64eff3225a3ea4e6081efeb6087e8d5a61c204d94edc2ed1aab0b9d70" }, "se_resnet152-d17c99b7": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet152-d17c99b7.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/se_resnet152-d17c99b7.pth", "hash_type": "sha256", "hash_val": "d17c99b703dcca2d2507ddfb68f72625a2f7e23ee64396eb992f1b2cf7e6bdc1" }, "se_resnet50-ce0d4300": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet50-ce0d4300.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/se_resnet50-ce0d4300.pth", "hash_type": "sha256", "hash_val": "ce0d430017d3f4aa6b5658c72209f3bfffb060207fd26a2ef0b203ce592eba01" }, "se_resnext101_32x4d-3b2fe3d8": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext101_32x4d-3b2fe3d8.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/se_resnext101_32x4d-3b2fe3d8.pth", "hash_type": "sha256", "hash_val": "3b2fe3d8acb8de7d5976c4baf518f24a0237509272a69366e816682d3e57b989" }, "se_resnext50_32x4d-a260b3a4": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/se_resnext50_32x4d-a260b3a4.pth", "hash_type": "sha256", "hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e" }, "ssl_pretrained_weights": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" }, "decoder_only_transformer_monai_generative_weights": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/decoder_only_transformer.pth", "hash_type": "sha256", "hash_val": "f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d" }, "diffusion_model_unet_monai_generative_weights": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/diffusion_model_unet.pth", "hash_type": "sha256", "hash_val": "0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee" }, "autoencoderkl_monai_generative_weights": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/autoencoderkl.pth", "hash_type": "sha256", "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" }, "controlnet_monai_generative_weights": { - "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth", + "url": "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/controlnet.pth", "hash_type": "sha256", "hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e" } @@ -167,4 +167,4 @@ "hash_val": "06954cad2cc5d3784e72077ac76f0fc8" } } -} +} \ No newline at end of file From 6296784055e0d0d4bd4cccb199a5046ffd403a7a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Apr 2026 05:13:28 +0000 Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/testing_data/data_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index d05d7735f8..a6f4e90940 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -167,4 +167,4 @@ "hash_val": "06954cad2cc5d3784e72077ac76f0fc8" } } -} \ No newline at end of file +} From 8a6b6f6ca046c5c6e5d4bce6c61bd187d648ba6a Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 10:25:01 +0800 Subject: [PATCH 03/19] Enoch Mok DCO Remediation Commit for Enoch Mok I, Enoch Mok , hereby add my Signed-off-by to this commit: d02da499b348e445cd29f5244052c37065a1a533 Signed-off-by: Enoch Mok --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3639ac1a9e..5aacae7cee 100644 --- a/README.md +++ b/README.md @@ -102,3 +102,4 @@ Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github - conda-forge: - Weekly previews: - Docker Hub: + From 66af96467435d0b935dd79cb7b0c85698a9145d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 18 Apr 2026 02:25:22 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 5aacae7cee..3639ac1a9e 100644 --- a/README.md +++ b/README.md @@ -102,4 +102,3 @@ Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github - conda-forge: - Weekly previews: - Docker Hub: - From 834a8b401f85c5333da5a31e9f80b4c587a0a812 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 12:54:07 +0800 Subject: [PATCH 05/19] Validate downloaded file integrity and raise ValueError on hash mismatch Signed-off-by: Enoch Mok --- monai/apps/utils.py | 8 +++++--- tests/test_utils.py | 7 +++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 856bc64c9e..9607e9a7ef 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -268,9 +268,11 @@ def download_url( pass logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of downloaded file failed: URL={url}, " - f"filepath={filepath}, expected {hash_type}={hash_val}." + raise ValueError( + f"{hash_type} hash check of downloaded file failed: URL={url}, " + f"filepath={filepath}, expected {hash_type}={hash_val}, " + f"The file may be corrupted or tampered with. " + "Please retry the download or verify the source." ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 03fa7abce3..2773b98900 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,7 +79,7 @@ "unexpected EOF", # incomplete download "network issue", "gdown dependency", # gdown not installed - "md5 check", + "hash check", # check hash value of downloaded file "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out @@ -180,8 +180,11 @@ def skip_if_downloading_fails(): err_str = str(rt_e) if any(k in err_str for k in DOWNLOAD_FAIL_MSGS): raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download - + raise rt_e + except ValueError as v_e: + if "hash check" in str(v_e): + raise unittest.SkipTest(f"Hash value error while downloading: {v_e}") from v_e SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" From 51a49b0f6bd4c2fb7fd60b646b49ce3e5e12dbe7 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 12:59:35 +0800 Subject: [PATCH 06/19] Validate downloaded file integrity and raise ValueError on hash mismatch; used ./runtests.sh autofix Signed-off-by: Enoch Mok --- tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2773b98900..014f697ef9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,7 +79,7 @@ "unexpected EOF", # incomplete download "network issue", "gdown dependency", # gdown not installed - "hash check", # check hash value of downloaded file + "hash check", # check hash value of downloaded file "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out @@ -180,7 +180,7 @@ def skip_if_downloading_fails(): err_str = str(rt_e) if any(k in err_str for k in DOWNLOAD_FAIL_MSGS): raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download - + raise rt_e except ValueError as v_e: if "hash check" in str(v_e): From 98ccecc3f8ed93f99081cec7c74aee3855ab58b7 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 13:37:22 +0800 Subject: [PATCH 07/19] amended runtimeerrors -> valueerror Signed-off-by: Enoch Mok --- monai/apps/utils.py | 17 ++++++++--------- tests/test_utils.py | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 9607e9a7ef..ce9873b9e2 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -220,8 +220,7 @@ def download_url( HTTPError: See urllib.request.urlretrieve. ContentTooShortError: See urllib.request.urlretrieve. IOError: See urllib.request.urlretrieve. - RuntimeError: When the hash validation of the ``url`` downloaded file fails. - + ValueError: When the hash validation of the ``url`` downloaded file fails. """ if not filepath: filepath = Path(".", _basename(url)).resolve() @@ -260,6 +259,13 @@ def download_url( raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) + if not check_hash(filepath, hash_val, hash_type): + raise ValueError( + f"{hash_type} hash check of downloaded file failed: URL={url}, " + f"filepath={filepath}, expected {hash_type}={hash_val}, " + f"The file may be corrupted or tampered with. " + "Please retry the download or verify the source." + ) file_dir = filepath.parent if file_dir: os.makedirs(file_dir, exist_ok=True) @@ -267,13 +273,6 @@ def download_url( except (PermissionError, NotADirectoryError): # project-monai/monai issue #3613 #3757 for windows pass logger.info(f"Downloaded: {filepath}") - if not check_hash(filepath, hash_val, hash_type): - raise ValueError( - f"{hash_type} hash check of downloaded file failed: URL={url}, " - f"filepath={filepath}, expected {hash_type}={hash_val}, " - f"The file may be corrupted or tampered with. " - "Please retry the download or verify the source." - ) def _extract_zip(filepath, output_dir): diff --git a/tests/test_utils.py b/tests/test_utils.py index 014f697ef9..f7db852cfa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -199,7 +199,7 @@ def test_download_url(self): """Download a sample TIFF and validate hash handling. Raises: - RuntimeError: When the downloaded file's hash does not match. + ValueError: When the downloaded file's hash does not match. """ with tempfile.TemporaryDirectory() as tempdir: with skip_if_downloading_fails(): @@ -209,7 +209,7 @@ def test_download_url(self): hash_val=SAMPLE_TIFF_HASH, hash_type=SAMPLE_TIFF_HASH_TYPE, ) - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): download_url( url=SAMPLE_TIFF, filepath=os.path.join(tempdir, "model_bad.tiff"), From 29ec2fe7038f4a94d4198c95dee80e397e9a26c9 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 15:59:53 +0800 Subject: [PATCH 08/19] fixing code according to coderabbitai Signed-off-by: Enoch Mok --- monai/apps/utils.py | 6 +++--- tests/test_utils.py | 17 +++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index ce9873b9e2..264425e9b9 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -228,8 +228,8 @@ def download_url( filepath = Path(filepath) if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." + raise ValueError( + f"{hash_type} hash check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." ) logger.info(f"File exists: {filepath}, skipped downloading.") return @@ -259,7 +259,7 @@ def download_url( raise RuntimeError( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) - if not check_hash(filepath, hash_val, hash_type): + if not check_hash(tmp_name, hash_val, hash_type): raise ValueError( f"{hash_type} hash check of downloaded file failed: URL={url}, " f"filepath={filepath}, expected {hash_type}={hash_val}, " diff --git a/tests/test_utils.py b/tests/test_utils.py index f7db852cfa..e5418add64 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -186,6 +186,8 @@ def skip_if_downloading_fails(): if "hash check" in str(v_e): raise unittest.SkipTest(f"Hash value error while downloading: {v_e}") from v_e + raise v_e + SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" @@ -202,20 +204,15 @@ def test_download_url(self): ValueError: When the downloaded file's hash does not match. """ with tempfile.TemporaryDirectory() as tempdir: + model_path = os.path.join(tempdir, "model.tiff") + with skip_if_downloading_fails(): download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model.tiff"), - hash_val=SAMPLE_TIFF_HASH, - hash_type=SAMPLE_TIFF_HASH_TYPE, + url=SAMPLE_TIFF, filepath=model_path, hash_val=SAMPLE_TIFF_HASH, hash_type=SAMPLE_TIFF_HASH_TYPE ) with self.assertRaises(ValueError): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model_bad.tiff"), - hash_val="0" * 64, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) + # checking for wrong hash + download_url(filepath=model_path, hash_val="0" * 64, hash_type=SAMPLE_TIFF_HASH_TYPE) def test_pretrained_networks(network, input_param, device): From 606e7ca6a5839d078e5dae6bc6a944ed52df24e5 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Sat, 18 Apr 2026 17:11:44 +0800 Subject: [PATCH 09/19] addressed bugs and issues raised in previous commit Signed-off-by: Enoch Mok --- tests/apps/test_download_and_extract.py | 6 +++--- tests/test_utils.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index 6d16a72735..9dcd00659f 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -42,10 +42,10 @@ def test_actions(self): with self.assertLogs(logger="monai.apps", level="ERROR"): try: download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, RuntimeError) as e: - if isinstance(e, RuntimeError): + except (ContentTooShortError, HTTPError, ValueError, RuntimeError) as e: + if isinstance(e, ValueError): # FIXME: skip MD5 check as current downloading method may fail - self.assertTrue(str(e).startswith("md5 check")) + self.assertIn("hash check", str(e)) return # skipping this test due the network connection errors try: diff --git a/tests/test_utils.py b/tests/test_utils.py index e5418add64..1816b07458 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -198,10 +198,8 @@ class TestDownloadUrl(unittest.TestCase): """Exercise ``download_url`` success and hash-mismatch paths.""" def test_download_url(self): - """Download a sample TIFF and validate hash handling. - - Raises: - ValueError: When the downloaded file's hash does not match. + """ + Download a sample TIFF and validate hash handling. """ with tempfile.TemporaryDirectory() as tempdir: model_path = os.path.join(tempdir, "model.tiff") @@ -210,9 +208,13 @@ def test_download_url(self): download_url( url=SAMPLE_TIFF, filepath=model_path, hash_val=SAMPLE_TIFF_HASH, hash_type=SAMPLE_TIFF_HASH_TYPE ) + + # Verify file exists before testing hash mismatch + if not os.path.exists(model_path): + self.skipTest("File was not downloaded successfully") + with self.assertRaises(ValueError): - # checking for wrong hash - download_url(filepath=model_path, hash_val="0" * 64, hash_type=SAMPLE_TIFF_HASH_TYPE) + download_url(url=SAMPLE_TIFF, filepath=model_path, hash_val="0" * 64, hash_type=SAMPLE_TIFF_HASH_TYPE) def test_pretrained_networks(network, input_param, device): From 88f83727b3d75e8daf81c66a2a82a153de3d2ce9 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Mon, 20 Apr 2026 15:51:20 +0800 Subject: [PATCH 10/19] addressing Eric's comment in https://github.com/Project-MONAI/MONAI/pull/8833#discussion_r3107633647 Signed-off-by: Enoch Mok --- tests/apps/test_download_and_extract.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index 9dcd00659f..a9329e96f6 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -42,10 +42,11 @@ def test_actions(self): with self.assertLogs(logger="monai.apps", level="ERROR"): try: download_url(url, filepath, wrong_md5) - except (ContentTooShortError, HTTPError, ValueError, RuntimeError) as e: - if isinstance(e, ValueError): - # FIXME: skip MD5 check as current downloading method may fail - self.assertIn("hash check", str(e)) + except ValueError as e: + # FIXME: skip MD5 check as current downloading method may fail + self.assertIn("hash check", str(e)) + return + except (ContentTooShortError, HTTPError, RuntimeError) as e: return # skipping this test due the network connection errors try: From e9861d333413a4f8f079c4d61b01fcc1979ad63f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 07:51:46 +0000 Subject: [PATCH 11/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/apps/test_download_and_extract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index a9329e96f6..ec978f676a 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -46,7 +46,7 @@ def test_actions(self): # FIXME: skip MD5 check as current downloading method may fail self.assertIn("hash check", str(e)) return - except (ContentTooShortError, HTTPError, RuntimeError) as e: + except (ContentTooShortError, HTTPError, RuntimeError): return # skipping this test due the network connection errors try: From 453b23826f044c12b036d6c6dac0020f6b47d5c5 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Mon, 20 Apr 2026 15:56:03 +0800 Subject: [PATCH 12/19] removing duplicate line Signed-off-by: Enoch Mok --- tests/apps/test_download_and_extract.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index ec978f676a..c85da687f4 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -36,7 +36,6 @@ def test_actions(self): hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"] with skip_if_downloading_fails(): download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) wrong_md5 = "0" with self.assertLogs(logger="monai.apps", level="ERROR"): From 64b8e090b5bf999302c732c93f121b05ca90aae7 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Mon, 20 Apr 2026 16:41:05 +0800 Subject: [PATCH 13/19] removed duplicated tests for download_url. updated existing tests for updated download_url function Signed-off-by: Enoch Mok --- monai/apps/utils.py | 13 +++-- tests/apps/test_download_and_extract.py | 77 ++++++++++++++++--------- tests/test_utils.py | 59 ------------------- 3 files changed, 58 insertions(+), 91 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 264425e9b9..cbeed9091c 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -326,10 +326,16 @@ def extractall( be False. Raises: - RuntimeError: When the hash validation of the ``filepath`` compressed file fails. + ValueError: When the hash validation of the ``filepath`` compressed file fails. NotImplementedError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"]. """ + filepath = Path(filepath) + if hash_val and not check_hash(filepath, hash_val, hash_type): + raise ValueError( + f"{hash_type} hash check of compressed file failed: " + f"filepath={filepath}, expected {hash_type}={hash_val}." + ) if has_base: # the extracted files will be in this folder cache_dir = Path(output_dir, _basename(filepath).split(".")[0]) @@ -338,11 +344,6 @@ def extractall( if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None: logger.info(f"Non-empty folder exists in {cache_dir}, skipped extracting.") return - filepath = Path(filepath) - if hash_val and not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( - f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}." - ) logger.info(f"Writing into directory: {output_dir}.") _file_type = file_type.lower().strip() if filepath.name.endswith("zip") or _file_type == "zip": diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index c85da687f4..b5dfe2f7f5 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -26,39 +26,62 @@ @SkipIfNoModule("requests") class TestDownloadAndExtract(unittest.TestCase): + def setUp(self): + self.testing_dir = Path(__file__).parents[1] / "testing_data" + self.config = testing_data_config("images", "mednist") + self.url = self.config["url"] + self.hash_val = self.config["hash_val"] + self.hash_type = self.config["hash_type"] + @skip_if_quick - def test_actions(self): - testing_dir = Path(__file__).parents[1] / "testing_data" - config_dict = testing_data_config("images", "mednist") - url = config_dict["url"] - filepath = Path(testing_dir) / "MedNIST.tar.gz" - output_dir = Path(testing_dir) - hash_val, hash_type = config_dict["hash_val"], config_dict["hash_type"] + def test_download_and_extract_success(self): + """End-to-end: download and extract should succeed with correct hash.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + with skip_if_downloading_fails(): - download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type) + download_and_extract(self.url, filepath, output_dir, hash_val=self.hash_val, hash_type=self.hash_type) - wrong_md5 = "0" - with self.assertLogs(logger="monai.apps", level="ERROR"): - try: - download_url(url, filepath, wrong_md5) - except ValueError as e: - # FIXME: skip MD5 check as current downloading method may fail - self.assertIn("hash check", str(e)) - return - except (ContentTooShortError, HTTPError, RuntimeError): - return # skipping this test due the network connection errors - - try: - extractall(filepath, output_dir, wrong_md5) - except RuntimeError as e: - self.assertTrue(str(e).startswith("md5 check")) + self.assertTrue(filepath.exists(), "Downloaded file does not exist") + self.assertTrue(any(output_dir.iterdir()), "Extraction output is empty") + + @skip_if_quick + def test_download_url_hash_mismatch(self): + """download_url should raise ValueError on hash mismatch.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + + with skip_if_downloading_fails(): + # First ensure file is downloaded correctly + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + # Now test incorrect hash + with self.assertRaises(ValueError) as ctx: + download_url(self.url, filepath, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) @skip_if_quick - @parameterized.expand((("icon", "tar"), ("favicon", "zip"))) - def test_default(self, key, file_type): + def test_extractall_hash_mismatch(self): + """extractall should raise ValueError when hash is incorrect.""" + filepath = self.testing_dir / "MedNIST.tar.gz" + output_dir = self.testing_dir + + with skip_if_downloading_fails(): + download_url(self.url, filepath, hash_val=self.hash_val, hash_type=self.hash_type) + + with self.assertRaises(ValueError) as ctx: + extractall(filepath, output_dir, hash_val="0" * len(self.hash_val), hash_type=self.hash_type) + + self.assertIn("hash check", str(ctx.exception).lower()) + + @skip_if_quick + @parameterized.expand([("icon", "tar"), ("favicon", "zip")]) + def test_download_and_extract_various_formats(self, key, file_type): + """Verify different archive formats download and extract correctly.""" with tempfile.TemporaryDirectory() as tmp_dir: + img_spec = testing_data_config("images", key) + with skip_if_downloading_fails(): - img_spec = testing_data_config("images", key) download_and_extract( img_spec["url"], output_dir=tmp_dir, @@ -67,6 +90,8 @@ def test_default(self, key, file_type): file_type=file_type, ) + self.assertTrue(any(Path(tmp_dir).iterdir()), f"Extraction failed for format: {file_type}") + class TestPathTraversalProtection(unittest.TestCase): """Test cases for path traversal attack protection in extractall function.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index 71af2a5525..02b6a18452 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -189,65 +189,6 @@ def skip_if_downloading_fails(): raise v_e -SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" -SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" -SAMPLE_TIFF_HASH_TYPE = "sha256" - - -class TestDownloadUrl(unittest.TestCase): - """Exercise ``download_url`` success and hash-mismatch paths.""" - - def test_download_url(self): - """ - Download a sample TIFF and validate hash handling. - """ - with tempfile.TemporaryDirectory() as tempdir: - model_path = os.path.join(tempdir, "model.tiff") - - with skip_if_downloading_fails(): - download_url( - url=SAMPLE_TIFF, filepath=model_path, hash_val=SAMPLE_TIFF_HASH, hash_type=SAMPLE_TIFF_HASH_TYPE - ) - - # Verify file exists before testing hash mismatch - if not os.path.exists(model_path): - self.skipTest("File was not downloaded successfully") - - with self.assertRaises(ValueError): - download_url(url=SAMPLE_TIFF, filepath=model_path, hash_val="0" * 64, hash_type=SAMPLE_TIFF_HASH_TYPE) - - -SAMPLE_TIFF = "https://huggingface.co/datasets/MONAI/testing_data/resolve/main/CMU-1.tiff" -SAMPLE_TIFF_HASH = "73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd" -SAMPLE_TIFF_HASH_TYPE = "sha256" - - -class TestDownloadUrl(unittest.TestCase): - """Exercise ``download_url`` success and hash-mismatch paths.""" - - def test_download_url(self): - """Download a sample TIFF and validate hash handling. - - Raises: - RuntimeError: When the downloaded file's hash does not match. - """ - with tempfile.TemporaryDirectory() as tempdir: - with skip_if_downloading_fails(): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model.tiff"), - hash_val=SAMPLE_TIFF_HASH, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) - with self.assertRaises(RuntimeError): - download_url( - url=SAMPLE_TIFF, - filepath=os.path.join(tempdir, "model_bad.tiff"), - hash_val="0" * 64, - hash_type=SAMPLE_TIFF_HASH_TYPE, - ) - - def test_pretrained_networks(network, input_param, device): with skip_if_downloading_fails(): return network(**input_param).to(device) From 096c6d09eabc6a89b1e99ad61c7fc6499d500427 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Apr 2026 08:41:26 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/apps/test_download_and_extract.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/apps/test_download_and_extract.py b/tests/apps/test_download_and_extract.py index b5dfe2f7f5..a1e5381d90 100644 --- a/tests/apps/test_download_and_extract.py +++ b/tests/apps/test_download_and_extract.py @@ -16,7 +16,6 @@ import unittest import zipfile from pathlib import Path -from urllib.error import ContentTooShortError, HTTPError from parameterized import parameterized From 5b0b5f4618c4c67de7c1433b031f00dfa3ec6ee2 Mon Sep 17 00:00:00 2001 From: Enoch Mok <65853622+e-mny@users.noreply.github.com> Date: Wed, 3 Jun 2026 15:40:20 +0800 Subject: [PATCH 15/19] Apply suggestions from code review Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Enoch Mok <65853622+e-mny@users.noreply.github.com> --- monai/apps/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index cbeed9091c..69296e7361 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -229,7 +229,7 @@ def download_url( if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): raise ValueError( - f"{hash_type} hash check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}." + f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type}={hash_val}." ) logger.info(f"File exists: {filepath}, skipped downloading.") return From 3c61005a696befe45873dc57d5d82fcc54b96460 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Thu, 4 Jun 2026 13:19:44 +0800 Subject: [PATCH 16/19] removed hash checking in a function that checks for download fails Signed-off-by: Enoch Mok --- tests/test_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 02b6a18452..0e30574c3d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -182,11 +182,6 @@ def skip_if_downloading_fails(): raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download raise rt_e - except ValueError as v_e: - if "hash check" in str(v_e): - raise unittest.SkipTest(f"Hash value error while downloading: {v_e}") from v_e - - raise v_e def test_pretrained_networks(network, input_param, device): From 8974b4dab2cce66ecd90e7d61a4804ad709c0cab Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Thu, 25 Jun 2026 12:57:27 +0800 Subject: [PATCH 17/19] added HashCheckError Signed-off-by: Enoch Mok --- monai/apps/utils.py | 19 ++++++++++--------- monai/bundle/utils.py | 6 ++++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 69296e7361..960330b4bd 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -48,6 +48,10 @@ SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} +class HashCheckError(ValueError): + pass + + def get_logger( module_name: str = "monai.apps", fmt: str = DEFAULT_FMT, @@ -228,9 +232,7 @@ def download_url( filepath = Path(filepath) if filepath.exists(): if not check_hash(filepath, hash_val, hash_type): - raise ValueError( - f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type}={hash_val}." - ) + raise HashCheckError(f"{hash_type} hash check of existing file failed: {filepath=}, expected {hash_type=}.") logger.info(f"File exists: {filepath}, skipped downloading.") return try: @@ -260,9 +262,9 @@ def download_url( f"Download of file from {url} to {filepath} failed due to network issue or denied permission." ) if not check_hash(tmp_name, hash_val, hash_type): - raise ValueError( - f"{hash_type} hash check of downloaded file failed: URL={url}, " - f"filepath={filepath}, expected {hash_type}={hash_val}, " + raise HashCheckError( + f"{hash_type} hash check of downloaded file failed: {url=}, " + f"{filepath=}, expected {hash_type}={hash_val}, " f"The file may be corrupted or tampered with. " "Please retry the download or verify the source." ) @@ -332,9 +334,8 @@ def extractall( """ filepath = Path(filepath) if hash_val and not check_hash(filepath, hash_val, hash_type): - raise ValueError( - f"{hash_type} hash check of compressed file failed: " - f"filepath={filepath}, expected {hash_type}={hash_val}." + raise HashCheckError( + f"{hash_type} hash check of compressed file failed: " f"{filepath=}, expected {hash_type}={hash_val}." ) if has_base: # the extracted files will be in this folder diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py index d37d7f1c05..53d619f234 100644 --- a/monai/bundle/utils.py +++ b/monai/bundle/utils.py @@ -124,8 +124,10 @@ "run_name": None, # may fill it at runtime "save_execute_config": True, - "is_not_rank0": ("$torch.distributed.is_available() \ - and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"), + "is_not_rank0": ( + "$torch.distributed.is_available() \ + and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0" + ), # MLFlowHandler config for the trainer "trainer": { "_target_": "MLFlowHandler", From 32ebca7cddbc426916f19d85c6928a7441563fdc Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Thu, 25 Jun 2026 15:58:34 +0800 Subject: [PATCH 18/19] refactor(engines): introduce SingleNetworkTrainer and DualNetworkTrainer base classes Signed-off-by: Enoch Mok --- monai/engines/trainer.py | 301 ++++++++++++++++++++++++++++++--------- 1 file changed, 233 insertions(+), 68 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 921d54a59c..609ea215fd 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -36,7 +36,14 @@ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") -__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"] +__all__ = [ + "Trainer", + "SingleNetworkTrainer", + "SupervisedTrainer", + "DualNetworkTrainer", + "GanTrainer", + "AdversarialTrainer", +] class Trainer(Workflow): @@ -51,7 +58,7 @@ def run(self) -> None: # type: ignore[override] If call this function multiple times, it will continuously run from the previous state. """ - self.scaler = torch.cuda.amp.GradScaler() if self.amp else None + self.scaler = torch.amp.GradScaler("cuda") if self.amp else None super().run() def get_stats(self, *vars): @@ -76,10 +83,144 @@ def get_stats(self, *vars): stats[k] = getattr(self.state, k, None) return stats + @staticmethod + def _unpack_batch(batch) -> tuple: + """ + Unpack a prepared batch into ``(inputs, targets, args, kwargs)``. + + If the batch contains exactly 2 elements, ``args`` and ``kwargs`` default to an empty + tuple and dict respectively. Otherwise the batch is expected to already be a 4-tuple. + + Args: + batch: output of ``prepare_batch``, either a 2-tuple ``(inputs, targets)`` + or a 4-tuple ``(inputs, targets, args, kwargs)``. + + Returns: + A 4-tuple ``(inputs, targets, args, kwargs)``. + """ + if len(batch) == 2: + inputs, targets = batch + return inputs, targets, (), {} + return batch + -class SupervisedTrainer(Trainer): +class SingleNetworkTrainer(Trainer): """ - Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``. + Intermediate base class for trainers that operate with a single network and optimizer pair, + inherits from ``Trainer``. + + Centralises the ownership and initialisation of ``network``, ``optimizer``, ``inferer``, + ``optim_set_to_none``, ``accumulation_steps``, and optional ``torch.compile`` wrapping so + that concrete subclasses (e.g. ``SupervisedTrainer``) only need to provide their + task-specific loss and forward logic. + + Provides two helper methods intended for use in ``_iteration`` implementations: + + - :meth:`_accumulation_flags` — returns ``(should_zero_grad, should_step)`` for the + current iteration based on the configured ``accumulation_steps``. + - :meth:`_amp_step` — runs the scaled backward pass and conditional optimizer step, + handling both AMP and non-AMP code paths uniformly. + + Args: + network: network to train, should be regular PyTorch ``torch.nn.Module``. + optimizer: the optimizer associated to the network. + inferer: inference method that executes the model forward on input data. + Defaults to ``SimpleInferer()``. + optim_set_to_none: when calling ``optimizer.zero_grad()``, set grads to ``None`` + instead of zero. Default: ``False``. + accumulation_steps: number of mini-batches over which to accumulate gradients. + Must be a positive integer. Default: ``1`` (no accumulation). + compile: whether to wrap the network with ``torch.compile``. Default: ``False``. + compile_kwargs: keyword arguments forwarded to ``torch.compile()``. + **kwargs: remaining arguments forwarded to ``Trainer`` / ``Workflow``. + """ + + def __init__( + self, + network: torch.nn.Module, + optimizer: Optimizer, + inferer: Inferer | None = None, + optim_set_to_none: bool = False, + accumulation_steps: int = 1, + compile: bool = False, + compile_kwargs: dict | None = None, + **kwargs, + ) -> None: + if accumulation_steps < 1: + raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") + super().__init__(**kwargs) + if compile: + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] + self.network = network + self.compile = compile + self.optimizer = optimizer + self.inferer = SimpleInferer() if inferer is None else inferer + self.optim_set_to_none = optim_set_to_none + self.accumulation_steps = accumulation_steps + + def _accumulation_flags(self, engine: SingleNetworkTrainer) -> tuple[bool, bool]: + """ + Return ``(should_zero_grad, should_step)`` for the current iteration. + + When ``accumulation_steps == 1`` both flags are always ``True``. + Otherwise the flags reflect whether the current iteration is the first or last step + in an accumulation window, with an end-of-epoch flush when ``epoch_length`` is known + and not evenly divisible by ``accumulation_steps``. + + Args: + engine: the trainer engine for the current iteration. + + Returns: + A 2-tuple ``(should_zero_grad, should_step)``. + """ + acc = engine.accumulation_steps + if acc == 1: + return True, True + epoch_length = engine.state.epoch_length + if epoch_length is not None: + local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length + else: + local_iter = engine.state.iteration - 1 # 0-indexed global + should_zero_grad = local_iter % acc == 0 + should_step = (local_iter + 1) % acc == 0 + return should_zero_grad, should_step + + def _amp_step(self, engine: SingleNetworkTrainer, loss: torch.Tensor, *, should_step: bool) -> None: + """ + Run the backward pass and, when ``should_step`` is ``True``, the optimizer step. + + Handles both the AMP (``engine.scaler`` present) and non-AMP code paths, and + divides the loss by ``accumulation_steps`` before calling ``.backward()`` so + gradients accumulate correctly across micro-batches. Fires + ``IterationEvents.BACKWARD_COMPLETED`` after the backward pass in both paths. + + Args: + engine: the trainer engine for the current iteration. + loss: the **unscaled** loss value (as stored in ``engine.state.output``). + should_step: whether to call ``optimizer.step()`` after the backward pass. + """ + acc = engine.accumulation_steps + scaled_loss = loss / acc if acc > 1 else loss + if engine.amp and engine.scaler is not None: + engine.scaler.scale(scaled_loss).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + if should_step: + engine.scaler.step(engine.optimizer) + engine.scaler.update() + else: + scaled_loss.backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + if should_step: + engine.optimizer.step() + + +class SupervisedTrainer(SingleNetworkTrainer): + """ + Standard supervised training method with image and label, inherits from ``SingleNetworkTrainer``, + ``Trainer`` and ``Workflow``. Args: device: an object representing the device on which to run. @@ -168,9 +309,16 @@ def __init__( compile_kwargs: dict | None = None, accumulation_steps: int = 1, ) -> None: - if accumulation_steps < 1: - raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") super().__init__( + # SingleNetworkTrainer args + network=network, + optimizer=optimizer, + inferer=inferer, + optim_set_to_none=optim_set_to_none, + accumulation_steps=accumulation_steps, + compile=compile, + compile_kwargs=compile_kwargs, + # Workflow args forwarded via **kwargs through SingleNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -190,16 +338,7 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - if compile: - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - network = torch.compile(network, **compile_kwargs) # type: ignore[assignment] - self.network = network - self.compile = compile - self.optimizer = optimizer self.loss_function = loss_function - self.inferer = SimpleInferer() if inferer is None else inferer - self.optim_set_to_none = optim_set_to_none - self.accumulation_steps = accumulation_steps def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict: """ @@ -221,12 +360,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tenso if batchdata is None: raise ValueError("Must provide batch data for current iteration.") batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) - if len(batch) == 2: - inputs, targets = batch - args: tuple = () - kwargs: dict = {} - else: - inputs, targets, args, kwargs = batch + inputs, targets, args, kwargs = self._unpack_batch(batch) # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026 if self.compile: inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None @@ -255,21 +389,7 @@ def _compute_pred_loss(): engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) - # Determine gradient accumulation state - acc = engine.accumulation_steps - if acc > 1: - epoch_length = engine.state.epoch_length - if epoch_length is not None: - local_iter = (engine.state.iteration - 1) % epoch_length # 0-indexed within epoch - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length - else: - local_iter = engine.state.iteration - 1 # 0-indexed global - should_zero_grad = local_iter % acc == 0 - should_step = (local_iter + 1) % acc == 0 - else: - should_zero_grad = True - should_step = True + should_zero_grad, should_step = self._accumulation_flags(engine) engine.network.train() if should_zero_grad: @@ -278,19 +398,11 @@ def _compute_pred_loss(): if engine.amp and engine.scaler is not None: with torch.autocast("cuda", **engine.amp_kwargs): _compute_pred_loss() - loss = engine.state.output[Keys.LOSS] - engine.scaler.scale(loss / acc if acc > 1 else loss).backward() - engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - if should_step: - engine.scaler.step(engine.optimizer) - engine.scaler.update() else: _compute_pred_loss() - loss = engine.state.output[Keys.LOSS] - (loss / acc if acc > 1 else loss).backward() - engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - if should_step: - engine.optimizer.step() + + self._amp_step(engine, engine.state.output[Keys.LOSS], should_step=should_step) + # copy back meta info if self.compile: if inputs_meta is not None: @@ -309,10 +421,54 @@ def _compute_pred_loss(): return engine.state.output -class GanTrainer(Trainer): +class DualNetworkTrainer(Trainer): + """ + Intermediate base class for trainers that manage a generator and discriminator network pair, + inherits from ``Trainer``. + + Centralises the ownership and initialisation of the generator (``g_network``, + ``g_optimizer``, ``g_inferer``) and discriminator (``d_network``, ``d_optimizer``, + ``d_inferer``) components, along with ``optim_set_to_none``, so that concrete + subclasses (e.g. ``GanTrainer``, ``AdversarialTrainer``) only need to provide their + task-specific training logic. + + Args: + g_network: generator (G) network architecture. + g_optimizer: G optimizer. + d_network: discriminator (D) network architecture. + d_optimizer: D optimizer. + g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``. + d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``. + optim_set_to_none: when calling ``optimizer.zero_grad()``, set grads to ``None`` + instead of zero. Default: ``False``. + **kwargs: remaining arguments forwarded to ``Trainer`` / ``Workflow``. + """ + + def __init__( + self, + g_network: torch.nn.Module, + g_optimizer: Optimizer, + d_network: torch.nn.Module, + d_optimizer: Optimizer, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, + optim_set_to_none: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.g_network = g_network + self.g_optimizer = g_optimizer + self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer + self.d_network = d_network + self.d_optimizer = d_optimizer + self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + self.optim_set_to_none = optim_set_to_none + + +class GanTrainer(DualNetworkTrainer): """ Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266, - inherits from ``Trainer`` and ``Workflow``. + inherits from ``DualNetworkTrainer``, ``Trainer`` and ``Workflow``. Training Loop: for each batch of data size `m` 1. Generate `m` fakes from random latent codes. @@ -407,6 +563,15 @@ def __init__( # set up Ignite engine and environments super().__init__( + # DualNetworkTrainer args + g_network=g_network, + g_optimizer=g_optimizer, + g_inferer=g_inferer, + d_network=d_network, + d_optimizer=d_optimizer, + d_inferer=d_inferer, + optim_set_to_none=optim_set_to_none, + # Workflow args forwarded via **kwargs through DualNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -423,19 +588,12 @@ def __init__( to_kwargs=to_kwargs, amp_kwargs=amp_kwargs, ) - self.g_network = g_network - self.g_optimizer = g_optimizer self.g_loss_function = g_loss_function - self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer - self.d_network = d_network - self.d_optimizer = d_optimizer self.d_loss_function = d_loss_function - self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer self.d_train_steps = d_train_steps self.latent_shape = latent_shape self.g_prepare_batch = g_prepare_batch self.g_update_latents = g_update_latents - self.optim_set_to_none = optim_set_to_none def _iteration( self, engine: GanTrainer, batchdata: dict | Sequence @@ -498,9 +656,10 @@ def _iteration( } -class AdversarialTrainer(Trainer): +class AdversarialTrainer(DualNetworkTrainer): """ - Standard supervised training workflow for adversarial loss enabled neural networks. + Standard supervised training workflow for adversarial loss enabled neural networks, + inherits from ``DualNetworkTrainer``, ``Trainer`` and ``Workflow``. Args: device: an object representing the device on which to run. @@ -579,6 +738,15 @@ def __init__( amp_kwargs: dict | None = None, ): super().__init__( + # DualNetworkTrainer args + g_network=g_network, + g_optimizer=g_optimizer, + g_inferer=g_inferer, + d_network=d_network, + d_optimizer=d_optimizer, + d_inferer=d_inferer, + optim_set_to_none=optim_set_to_none, + # Workflow args forwarded via **kwargs through DualNetworkTrainer → Trainer → Workflow device=device, max_epochs=max_epochs, data_loader=train_data_loader, @@ -601,22 +769,19 @@ def __init__( self.register_events(*AdversarialIterationEvents) - self.state.g_network = g_network - self.state.g_optimizer = g_optimizer + # Promote network/optimizer references into engine.state for checkpoint saving + self.state.g_network = self.g_network + self.state.g_optimizer = self.g_optimizer self.state.g_loss_function = g_loss_function self.state.recon_loss_function = recon_loss_function - self.state.d_network = d_network - self.state.d_optimizer = d_optimizer + self.state.d_network = self.d_network + self.state.d_optimizer = self.d_optimizer self.state.d_loss_function = d_loss_function - self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer - self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer + self.state.g_scaler = torch.amp.GradScaler("cuda") if self.amp else None + self.state.d_scaler = torch.amp.GradScaler("cuda") if self.amp else None - self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None - self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None - - self.optim_set_to_none = optim_set_to_none self._complete_state_dict_user_keys() def _complete_state_dict_user_keys(self) -> None: From 731e0e28bce490f5d258ab03ade18b6a9ed35583 Mon Sep 17 00:00:00 2001 From: Enoch Mok Date: Thu, 25 Jun 2026 18:06:10 +0800 Subject: [PATCH 19/19] adding according to coderabbit's suggestions Signed-off-by: Enoch Mok --- monai/engines/trainer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 609ea215fd..d314256b95 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -37,12 +37,12 @@ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") __all__ = [ - "Trainer", - "SingleNetworkTrainer", - "SupervisedTrainer", + "AdversarialTrainer", "DualNetworkTrainer", "GanTrainer", - "AdversarialTrainer", + "SingleNetworkTrainer", + "SupervisedTrainer", + "Trainer", ] @@ -98,9 +98,13 @@ def _unpack_batch(batch) -> tuple: Returns: A 4-tuple ``(inputs, targets, args, kwargs)``. """ - if len(batch) == 2: + batch_len = len(batch) + if batch_len == 2: inputs, targets = batch return inputs, targets, (), {} + if batch_len != 4: + raise ValueError(f"Expected prepared batch to have 2 or 4 items, got {batch_len}.") + return batch @@ -146,8 +150,9 @@ def __init__( compile_kwargs: dict | None = None, **kwargs, ) -> None: - if accumulation_steps < 1: + if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1: raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.") + super().__init__(**kwargs) if compile: compile_kwargs = {} if compile_kwargs is None else compile_kwargs