diff --git a/Tests/test_imagecms.py b/Tests/test_imagecms.py index d73ab34282b..3382a34e47c 100644 --- a/Tests/test_imagecms.py +++ b/Tests/test_imagecms.py @@ -196,6 +196,10 @@ def test_exceptions() -> None: psRGB = ImageCms.createProfile("sRGB") pLab = ImageCms.createProfile("LAB") t = ImageCms.buildTransform(pLab, psRGB, "LAB", "RGB") + with pytest.raises(ValueError, match="mode mismatch"): + t.apply(hopper("RGBA")) + with pytest.raises(ValueError, match="mode mismatch"): + t.apply(hopper("LAB"), hopper("RGBA")) with pytest.raises(ValueError, match="mode mismatch"): t.apply_in_place(hopper("RGBA")) diff --git a/src/PIL/ImageCms.py b/src/PIL/ImageCms.py index 513e28acf33..388a9296f8b 100644 --- a/src/PIL/ImageCms.py +++ b/src/PIL/ImageCms.py @@ -317,19 +317,21 @@ def point(self, im: Image.Image) -> Image.Image: return self.apply(im) def apply(self, im: Image.Image, imOut: Image.Image | None = None) -> Image.Image: - if imOut is None: + if im.mode != self.input_mode: + msg = "mode mismatch" + raise ValueError(msg) + if imOut is not None: + if imOut.mode != self.output_mode: + msg = "mode mismatch" + raise ValueError(msg) + else: imOut = Image.new(self.output_mode, im.size, None) self.transform.apply(im.getim(), imOut.getim()) imOut.info["icc_profile"] = self.output_profile.tobytes() return imOut def apply_in_place(self, im: Image.Image) -> Image.Image: - if im.mode != self.output_mode: - msg = "mode mismatch" - raise ValueError(msg) # wrong output mode - self.transform.apply(im.getim(), im.getim()) - im.info["icc_profile"] = self.output_profile.tobytes() - return im + return self.apply(im, im) def get_display_profile(handle: SupportsInt | None = None) -> ImageCmsProfile | None: