diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 049779f606..aadd96763d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -299,7 +299,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep ) if self.ensure_channel_first: - img = EnsureChannelFirst()(img) + img = EnsureChannelFirst()(img, meta_dict=meta_data) if self.image_only: return img return img, img.meta if isinstance(img, MetaTensor) else meta_data diff --git a/tests/transforms/test_load_image.py b/tests/transforms/test_load_image.py index 031e38272e..4a470a624c 100644 --- a/tests/transforms/test_load_image.py +++ b/tests/transforms/test_load_image.py @@ -25,7 +25,7 @@ from monai.apps import download_and_extract from monai.data import NibabelReader, PydicomReader -from monai.data.meta_obj import set_track_meta +from monai.data.meta_obj import get_track_meta, set_track_meta from monai.data.meta_tensor import MetaTensor from monai.transforms import LoadImage from monai.utils import optional_import @@ -497,6 +497,17 @@ def test_correct(self, input_param, expected_shape, track_meta): self.assertNotIsInstance(r, MetaTensor) self.assertFalse(hasattr(r, "affine")) + def test_track_meta_false_ensure_channel_first(self): + _previous_meta = get_track_meta() + try: + set_track_meta(False) + r = LoadImage(image_only=True, ensure_channel_first=True)(self.test_data) + self.assertTupleEqual(r.shape, (1, 128, 128, 128)) + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + finally: + set_track_meta(_previous_meta) + if __name__ == "__main__": unittest.main()