From bcb7430ff0dd13606ecd203d190ec651d4b469fa Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 15:05:33 +0200 Subject: [PATCH 1/4] Update augmentations --- doctr/transforms/modules/base.py | 150 ++++++++++++--- doctr/transforms/modules/pytorch.py | 125 ++++++++---- tests/common/test_transforms.py | 72 ++++++- tests/pytorch/test_transforms_pt.py | 288 ++++++++++++++++++++++++++++ 4 files changed, 573 insertions(+), 62 deletions(-) diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index 2868f3bde6..e28b104909 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -112,7 +112,9 @@ class OneOf(NestedObject): def __init__(self, transforms: list[Callable[[Any], Any]]) -> None: self.transforms = transforms - def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]: + def __call__( + self, img: Any, target: np.ndarray | dict[str, np.ndarray] | None = None + ) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]: # Pick transformation transfo = self.transforms[int(random.random() * len(self.transforms))] # Apply @@ -141,7 +143,11 @@ def __init__(self, transform: Callable[[Any], Any], p: float = 0.5) -> None: def extra_repr(self) -> str: return f"transform={self.transform}, p={self.p}" - def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]: + def __call__( + self, + img: Any, + target: np.ndarray | dict[str, np.ndarray] | None = None, + ) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]: if random.random() < self.p: return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg] return img if target is None else (img, target) @@ -165,12 +171,46 @@ def __init__(self, max_angle: float = 5.0, expand: bool = False) -> None: def extra_repr(self) -> str: return f"max_angle={self.max_angle}, expand={self.expand}" - def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]: - angle = random.uniform(-self.max_angle, self.max_angle) + def _rotate_array(self, img: Any, target: np.ndarray, angle: float) -> tuple[Any, np.ndarray]: + """Rotate the image and the target, and keep only boxes with at least partial visibility after rotation""" + is_polygon = target.shape[1:] == (4, 2) + r_img, r_polys = F.rotate_sample(img, target, angle, self.expand) - # Removes deleted boxes + is_kept = (r_polys.max(1) > r_polys.min(1)).sum(1) == 2 - return r_img, r_polys[is_kept] + r_polys = r_polys[is_kept] + + # convert back if input was boxes + if not is_polygon: + # (N, 4, 2) -> (N, 4) + x1y1 = r_polys.min(axis=1) + x2y2 = r_polys.max(axis=1) + r_boxes = np.concatenate([x1y1, x2y2], axis=1) + return r_img, r_boxes + + return r_img, r_polys + + def __call__( + self, img: Any, target: np.ndarray | dict[str, np.ndarray] + ) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]: + angle = random.uniform(-self.max_angle, self.max_angle) + + if isinstance(target, dict): + rotated_targets = {} + rotated_img = None + + for cls_name, arr in target.items(): + if len(arr) == 0: + rotated_targets[cls_name] = arr.copy() + continue + + r_img, r_arr = self._rotate_array(img, arr, angle) + if rotated_img is None: + rotated_img = r_img + rotated_targets[cls_name] = r_arr + return rotated_img if rotated_img is not None else img, rotated_targets + + return self._rotate_array(img, target, angle) class RandomCrop(NestedObject): @@ -188,7 +228,61 @@ def __init__(self, scale: tuple[float, float] = (0.08, 1.0), ratio: tuple[float, def extra_repr(self) -> str: return f"scale={self.scale}, ratio={self.ratio}" - def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]: + def _crop_array( + self, + img: Any, + target: np.ndarray, + crop_box: tuple[float, float, float, float], + ) -> tuple[Any, np.ndarray]: + is_polygon = target.shape[1:] == (4, 2) + # For polygons, we need to reproject the coordinates into the cropped frame, + # and keep only those with at least partial visibility + if is_polygon: + cropped_img, _ = F.crop_detection( + img, + np.concatenate( + ( + np.min(target, axis=1), + np.max(target, axis=1), + ), + axis=1, + ), + crop_box, + ) + + cropped_polys = target.copy() + + crop_w = crop_box[2] - crop_box[0] + crop_h = crop_box[3] - crop_box[1] + + # Reproject coordinates into cropped frame + cropped_polys[..., 0] = (cropped_polys[..., 0] - crop_box[0]) / crop_w + cropped_polys[..., 1] = (cropped_polys[..., 1] - crop_box[1]) / crop_h + + # Keep polygons with at least partial visibility + poly_min = np.min(cropped_polys, axis=1) + poly_max = np.max(cropped_polys, axis=1) + is_kept = (poly_max[:, 0] > 0) & (poly_min[:, 0] < 1) & (poly_max[:, 1] > 0) & (poly_min[:, 1] < 1) + cropped_polys = cropped_polys[is_kept] + + if cropped_polys.shape[0] == 0: + return img, target + + return cropped_img, np.clip(cropped_polys, 0, 1) + + # For detection boxes, we can directly crop and clip them + cropped_img, crop_boxes = F.crop_detection(img, target, crop_box) + + if crop_boxes.shape[0] == 0: + return img, target + + return cropped_img, np.clip(crop_boxes, 0, 1) + + def __call__( + self, + img: Any, + target: np.ndarray | dict[str, np.ndarray], + ) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]: scale = random.uniform(self.scale[0], self.scale[1]) ratio = random.uniform(self.ratio[0], self.ratio[1]) @@ -208,19 +302,29 @@ def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]: x = random.randint(0, width - crop_width) y = random.randint(0, height - crop_height) - # relative crop box - crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height) - if target.shape[1:] == (4, 2): - min_xy = np.min(target, axis=1) - max_xy = np.max(target, axis=1) - _target = np.concatenate((min_xy, max_xy), axis=1) - else: - _target = target - - # Crop image and targets - croped_img, crop_boxes = F.crop_detection(img, _target, crop_box) - # hard fallback if no box is kept - if crop_boxes.shape[0] == 0: - return img, target - # clip boxes - return croped_img, np.clip(crop_boxes, 0, 1) + crop_box = ( + x / width, + y / height, + (x + crop_width) / width, + (y + crop_height) / height, + ) + + if isinstance(target, dict): + cropped_targets = {} + cropped_img = None + + for cls_name, arr in target.items(): + if len(arr) == 0: + cropped_targets[cls_name] = arr.copy() + continue + + c_img, c_arr = self._crop_array(img, arr, crop_box) + + if cropped_img is None: + cropped_img = c_img + + cropped_targets[cls_name] = c_arr + + return cropped_img if cropped_img is not None else img, cropped_targets + + return self._crop_array(img, target, crop_box) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 760cb88c9f..1f20a805b7 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -57,6 +57,37 @@ def __init__( self.symmetric_pad = symmetric_pad self.return_padding_mask = return_padding_mask + def _resize_target( + self, + target, + raw_shape, + final_shape, + symmetric_pad=False, + offset=(0, 0), + ): + target = target.copy() + + if target.shape[1:] == (4,): + if symmetric_pad: + target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / final_shape[-1] + target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / final_shape[-2] + else: + target[:, [0, 2]] *= raw_shape[-1] / final_shape[-1] + target[:, [1, 3]] *= raw_shape[-2] / final_shape[-2] + + elif target.shape[1:] == (4, 2): + if symmetric_pad: + target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / final_shape[-1] + target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / final_shape[-2] + else: + target[..., 0] *= raw_shape[-1] / final_shape[-1] + target[..., 1] *= raw_shape[-2] / final_shape[-2] + + else: + raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)") + + return np.clip(target, 0, 1) + def forward( self, img: torch.Tensor, @@ -116,32 +147,36 @@ def forward( # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) if target is not None: if self.symmetric_pad: - offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2] - - if self.preserve_aspect_ratio: - # Get absolute coords - if target.shape[1:] == (4,): - if self.symmetric_pad: - target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1] - target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2] - else: - target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1] - target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2] - elif target.shape[1:] == (4, 2): - if self.symmetric_pad: - target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1] - target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2] - else: - target[..., 0] *= raw_shape[-1] / img.shape[-1] - target[..., 1] *= raw_shape[-2] / img.shape[-2] - else: - raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)") - - target = np.clip(target, 0, 1) - - if self.return_padding_mask: - return img, target, padding_mask - return img, target + offset = ( + half_pad[0] / img.shape[-1], + half_pad[1] / img.shape[-2], + ) + else: + offset = (0, 0) + + if isinstance(target, dict): + target = { # type: ignore[assignment] + cls_name: self._resize_target( + arr, + raw_shape, + img.shape[-2:], + symmetric_pad=self.symmetric_pad, + offset=offset, + ) + for cls_name, arr in target.items() + } + else: + target = self._resize_target( + target, + raw_shape, + img.shape[-2:], + symmetric_pad=self.symmetric_pad, + offset=offset, + ) + if target is not None: + if self.return_padding_mask: + return img, target, padding_mask + return img, target if self.return_padding_mask: return img, padding_mask @@ -234,16 +269,30 @@ def forward(self, img: torch.Tensor) -> torch.Tensor: class RandomHorizontalFlip(T.RandomHorizontalFlip): """Randomly flip the input image horizontally""" - def forward(self, img: torch.Tensor | Image, target: np.ndarray) -> tuple[torch.Tensor | Image, np.ndarray]: + def _flip_array(self, target): + _target = target.copy() + # Changing the relative bbox coordinates + if target.shape[1:] == (4,): + _target[:, ::2] = 1 - target[:, [2, 0]] + else: + _target[..., 0] = 1 - target[..., 0] + + return _target + + def forward( + self, + img: torch.Tensor | Image, + target: np.ndarray | dict[str, np.ndarray], + ) -> tuple[torch.Tensor | Image, np.ndarray | dict[str, np.ndarray]]: + if torch.rand(1) < self.p: _img = F.hflip(img) - _target = target.copy() - # Changing the relative bbox coordinates - if target.shape[1:] == (4,): - _target[:, ::2] = 1 - target[:, [2, 0]] - else: - _target[..., 0] = 1 - target[..., 0] - return _img, _target + + if isinstance(target, dict): + return _img, {cls_name: self._flip_array(arr) for cls_name, arr in target.items()} + + return _img, self._flip_array(target) + return img, target @@ -319,7 +368,11 @@ def __init__( self.p = p self._resize = Resize - def forward(self, img: torch.Tensor, target: np.ndarray) -> tuple[torch.Tensor, np.ndarray]: + def forward( + self, + img: torch.Tensor, + target: np.ndarray | dict[str, np.ndarray], + ) -> tuple[torch.Tensor, np.ndarray | dict[str, np.ndarray]]: if torch.rand(1) < self.p: scale_h = np.random.uniform(*self.scale_range) scale_w = np.random.uniform(*self.scale_range) @@ -329,7 +382,7 @@ def forward(self, img: torch.Tensor, target: np.ndarray) -> tuple[torch.Tensor, new_size, preserve_aspect_ratio=self.preserve_aspect_ratio if isinstance(self.preserve_aspect_ratio, bool) - else bool(torch.rand(1) <= self.symmetric_pad), + else bool(torch.rand(1) <= self.preserve_aspect_ratio), symmetric_pad=self.symmetric_pad if isinstance(self.symmetric_pad, bool) else bool(torch.rand(1) <= self.symmetric_pad), diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py index 5c136c6d09..0ad9fdb138 100644 --- a/tests/common/test_transforms.py +++ b/tests/common/test_transforms.py @@ -21,20 +21,86 @@ def test_oneof(): transfo = T.OneOf(transfos) out = transfo(1) assert out == 0 or out == 11 - # test with target + + # test with ndarray target transfos = [lambda x, y: (1 - x, y), lambda x, y: (x + 10, y)] transfo = T.OneOf(transfos) out = transfo(1, np.array([2])) - assert out == (0, 2) or out == (11, 2) and isinstance(out[1], np.ndarray) + assert out == (0, 2) or out == (11, 2) + assert isinstance(out[1], np.ndarray) + + # test with dict target + dict_target = { + "boxes": np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + "labels": np.array([1], dtype=np.int64), + } + transfos = [ + lambda x, y: (1 - x, y), + lambda x, y: (x + 10, y), + ] + transfo = T.OneOf(transfos) + out = transfo(1, dict_target) + assert out[0] == 0 or out[0] == 11 + assert isinstance(out[1], dict) + assert set(out[1].keys()) == {"boxes", "labels"} + assert isinstance(out[1]["boxes"], np.ndarray) + assert isinstance(out[1]["labels"], np.ndarray) + np.testing.assert_array_equal( + out[1]["boxes"], + np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + ) + np.testing.assert_array_equal( + out[1]["labels"], + np.array([1], dtype=np.int64), + ) def test_randomapply(): transfo = T.RandomApply(lambda x: 1 - x) out = transfo(1) assert out == 0 or out == 1 + + # test with ndarray target transfo = T.RandomApply(lambda x, y: (1 - x, 2 * y)) out = transfo(1, np.array([2])) - assert out == (0, 4) or out == (1, 2) and isinstance(out[1], np.ndarray) + assert out == (0, 4) or out == (1, 2) + assert isinstance(out[1], np.ndarray) + + # test with dict target + dict_target = { + "boxes": np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + "labels": np.array([1], dtype=np.int64), + } + transfo = T.RandomApply( + lambda x, y: ( + 1 - x, + { + "boxes": 2 * y["boxes"], + "labels": y["labels"], + }, + ) + ) + + out = transfo(1, dict_target) + assert out[0] == 0 or out[0] == 1 + assert isinstance(out[1], dict) + assert set(out[1].keys()) == {"boxes", "labels"} + assert isinstance(out[1]["boxes"], np.ndarray) + assert isinstance(out[1]["labels"], np.ndarray) + if out[0] == 0: + np.testing.assert_array_equal( + out[1]["boxes"], + 2 * np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + ) + else: + np.testing.assert_array_equal( + out[1]["boxes"], + np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + ) + np.testing.assert_array_equal( + out[1]["labels"], + np.array([1], dtype=np.int64), + ) assert repr(transfo).endswith(", p=0.5)") diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 3ed262608b..6c89c88fb7 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -1,4 +1,5 @@ import math +import random import numpy as np import pytest @@ -9,12 +10,16 @@ ColorInversion, GaussianBlur, GaussianNoise, + ImageTransform, + OneOf, + RandomApply, RandomCrop, RandomHorizontalFlip, RandomResize, RandomRotate, RandomShadow, Resize, + SampleCompose, ) from doctr.transforms.functional import crop_detection, rotate_sample @@ -129,6 +134,47 @@ def test_resize(): with pytest.raises(AssertionError): transfo(input_t, target) + # Test dict targets + target_dict = { + "boxes": np.array([[0.1, 0.1, 0.9, 0.9]], dtype=np.float32), + "polygons": np.array([[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8], [0.2, 0.8]]], dtype=np.float32), + } + transfo = Resize( + (64, 64), + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + _, new_target = transfo(input_t, target_dict) + assert isinstance(new_target, dict) + assert set(new_target.keys()) == {"boxes", "polygons"} + assert new_target["boxes"].shape == (1, 4) + assert new_target["polygons"].shape == (1, 4, 2) + + # Test return type combinations + transfo = Resize((32, 32)) + out = transfo(input_t) + assert isinstance(out, torch.Tensor) + + transfo = Resize((32, 32), return_padding_mask=True) + out = transfo(input_t) + assert isinstance(out, tuple) + assert len(out) == 2 + + transfo = Resize((32, 32), preserve_aspect_ratio=True) + out = transfo(input_t, target_boxes) + assert isinstance(out, tuple) + assert len(out) == 2 + + transfo = Resize( + (32, 32), + preserve_aspect_ratio=True, + return_padding_mask=True, + ) + + out = transfo(input_t, target_boxes) + assert isinstance(out, tuple) + assert len(out) == 3 + @pytest.mark.parametrize( "rgb_min", @@ -215,6 +261,37 @@ def test_random_rotate(): r_img, _r_boxes = rotator(input_t, boxes) assert r_img.shape != input_t.shape + # Test dict targets + dict_target = { + "boxes": np.array([[15, 20, 35, 30]]), + "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), + } + r_img, r_targets = rotator(input_t, dict_target) + assert isinstance(r_targets, dict) + assert set(r_targets.keys()) == {"boxes", "polygons"} + assert isinstance(r_targets["boxes"], np.ndarray) + assert isinstance(r_targets["polygons"], np.ndarray) + # Check rotated image + assert r_img.ndim == input_t.ndim + # Check boxes + assert np.all(r_targets["boxes"] >= 0) + if len(r_targets["boxes"]) > 0: + assert r_targets["boxes"].shape[1] == 4 + # Check polygons + assert np.all(r_targets["polygons"] >= 0) + if len(r_targets["polygons"]) > 0: + assert r_targets["polygons"].shape[1:] == (4, 2) + + # Empty dict targets + empty_targets = { + "boxes": np.zeros((0, 4), dtype=np.float32), + "polygons": np.zeros((0, 4, 2), dtype=np.float32), + } + r_img, r_targets = rotator(input_t, empty_targets) + assert isinstance(r_targets, dict) + assert r_targets["boxes"].shape == (0, 4) + assert r_targets["polygons"].shape == (0, 4, 2) + # FP16 (only on GPU) if torch.cuda.is_available(): input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() @@ -263,18 +340,55 @@ def test_crop_detection(): def test_random_crop(target): cropper = RandomCrop(scale=(0.5, 1.0), ratio=(0.75, 1.33)) input_t = torch.ones((3, 50, 50), dtype=torch.float32) + img, target = cropper(input_t, target) + # Check the scale assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + # Check aspect ratio assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + # Check the target assert np.all(target >= 0) + if target.ndim == 2: assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) else: assert np.all(target[..., 0] <= img.shape[-1]) and np.all(target[..., 1] <= img.shape[-2]) + # Test dict targets + dict_target = { + "boxes": np.array([[15, 20, 35, 30]]), + "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), + } + + img, cropped_targets = cropper(input_t, dict_target) + + assert isinstance(cropped_targets, dict) + assert set(cropped_targets.keys()) == {"boxes", "polygons"} + + assert isinstance(cropped_targets["boxes"], np.ndarray) + assert isinstance(cropped_targets["polygons"], np.ndarray) + + # Check cropped image properties + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 + + # Check boxes + assert np.all(cropped_targets["boxes"] >= 0) + + if len(cropped_targets["boxes"]) > 0: + assert np.all(cropped_targets["boxes"][:, [0, 2]] <= img.shape[-1]) + assert np.all(cropped_targets["boxes"][:, [1, 3]] <= img.shape[-2]) + + # Check polygons + assert np.all(cropped_targets["polygons"] >= 0) + + if len(cropped_targets["polygons"]) > 0: + assert np.all(cropped_targets["polygons"][..., 0] <= img.shape[-1]) + assert np.all(cropped_targets["polygons"][..., 1] <= img.shape[-2]) + @pytest.mark.parametrize( "input_dtype, input_size", @@ -395,6 +509,45 @@ def test_randomhorizontalflip(p, target): assert np.all(_target == np.array([[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], dtype=np.float32)) assert torch.all(transformed.mean((0, 1)) == torch.tensor([0] * 16 + [1] * 16, dtype=torch.float32)) + # Test dict targets + dict_target = { + "boxes": np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), + "polygons": np.array( + [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], + dtype=np.float32, + ), + } + + transformed, _target = transform(input_t, dict_target) + + assert isinstance(_target, dict) + assert set(_target.keys()) == {"boxes", "polygons"} + + assert _target["boxes"].dtype == np.float32 + assert _target["polygons"].dtype == np.float32 + + if p == 1: + assert np.all(_target["boxes"] == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) + + assert np.all( + _target["polygons"] + == np.array( + [[[0.9, 0.1], [0.7, 0.1], [0.7, 0.4], [0.9, 0.4]]], + dtype=np.float32, + ) + ) + + elif p == 0: + assert np.all(_target["boxes"] == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) + + assert np.all( + _target["polygons"] + == np.array( + [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.4], [0.1, 0.4]]], + dtype=np.float32, + ) + ) + @pytest.mark.parametrize( "input_dtype,input_shape", @@ -452,3 +605,138 @@ def test_random_resize(p, preserve_aspect_ratio, symmetric_pad, target): # Resize is already well tested assert torch.all(out_img == img) if p == 0 else out_img.shape != img.shape assert out_target.shape == target.shape + + +# ---------------------------------------------------------------------------- +# End-to-end tests for SampleCompose with geometric and photometric transforms +# ---------------------------------------------------------------------------- + + +def _make_pipeline(): + return SampleCompose([ + RandomHorizontalFlip(p=1.0), + RandomRotate(max_angle=10.0, expand=False), + RandomCrop(scale=(0.8, 1.0), ratio=(0.9, 1.1)), + RandomResize( + scale_range=(0.8, 1.2), + preserve_aspect_ratio=True, + symmetric_pad=True, + p=1.0, + ), + ImageTransform(ColorInversion(min_val=0.7)), + ImageTransform(GaussianNoise(mean=0.0, std=0.1)), + ImageTransform(ChannelShuffle()), + ImageTransform(RandomShadow((0.2, 0.8))), + RandomApply(RandomHorizontalFlip(p=1.0), p=1.0), + OneOf([ + RandomRotate(max_angle=5.0, expand=False), + RandomCrop(scale=(0.9, 1.0), ratio=(0.95, 1.05)), + ]), + ]) + + +def test_samplecompose_end_to_end_boxes(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + input_t = torch.rand((3, 64, 64), dtype=torch.float32) + targets = { + "boxes": np.array( + [ + [0.1, 0.1, 0.4, 0.4], + [0.5, 0.5, 0.9, 0.9], + ], + dtype=np.float32, + ) + } + + transforms = _make_pipeline() + out_img, out_targets = transforms(input_t, targets) + + # image checks + assert isinstance(out_img, torch.Tensor) + assert out_img.ndim == 3 + assert out_img.shape[0] == 3 + assert torch.all((out_img >= 0) & (out_img <= 1)) + + # target checks + assert isinstance(out_targets, dict) + assert "boxes" in out_targets + boxes = out_targets["boxes"] + assert isinstance(boxes, np.ndarray) + + if len(boxes) > 0: + # must stay boxes + assert boxes.ndim == 2 + assert boxes.shape[1] == 4 + # geometry validity + assert np.all(boxes[:, 2] >= boxes[:, 0]) + assert np.all(boxes[:, 3] >= boxes[:, 1]) + assert np.all(np.isfinite(boxes)) + assert np.all((boxes >= 0) & (boxes <= 1)) + + # immutability check + np.testing.assert_array_equal( + targets["boxes"], + np.array( + [ + [0.1, 0.1, 0.4, 0.4], + [0.5, 0.5, 0.9, 0.9], + ], + dtype=np.float32, + ), + ) + + +def test_samplecompose_end_to_end_polygons(): + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + input_t = torch.rand((3, 64, 64), dtype=torch.float32) + targets = { + "polygons": np.array( + [ + [[0.1, 0.1], [0.4, 0.1], [0.4, 0.4], [0.1, 0.4]], + [[0.5, 0.5], [0.9, 0.5], [0.9, 0.9], [0.5, 0.9]], + ], + dtype=np.float32, + ) + } + + transforms = _make_pipeline() + out_img, out_targets = transforms(input_t, targets) + + # image checks + assert isinstance(out_img, torch.Tensor) + assert out_img.ndim == 3 + assert out_img.shape[0] == 3 + assert torch.all((out_img >= 0) & (out_img <= 1)) + + # target checks + assert isinstance(out_targets, dict) + assert "polygons" in out_targets + polys = out_targets["polygons"] + assert isinstance(polys, np.ndarray) + + if len(polys) > 0: + assert polys.ndim == 3 + assert polys.shape[1:] == (4, 2) + # geometry validity + assert np.all(np.isfinite(polys)) + assert np.all((polys >= 0) & (polys <= 1)) + # ensure valid polygon structure (non-degenerate) + assert np.all(np.linalg.norm(polys[:, 1] - polys[:, 0], axis=1) > 0) + + # immutability check + np.testing.assert_array_equal( + targets["polygons"], + np.array( + [ + [[0.1, 0.1], [0.4, 0.1], [0.4, 0.4], [0.1, 0.4]], + [[0.5, 0.5], [0.9, 0.5], [0.9, 0.9], [0.5, 0.9]], + ], + dtype=np.float32, + ), + ) From 0f7465e06510e8b6d8bd74ec137f16e3803235b3 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 15:25:26 +0200 Subject: [PATCH 2/4] mypy --- doctr/transforms/modules/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 1f20a805b7..27977ce825 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -155,7 +155,7 @@ def forward( offset = (0, 0) if isinstance(target, dict): - target = { # type: ignore[assignment] + target = { cls_name: self._resize_target( arr, raw_shape, From d64c4f86eb5127372a6e40c8c3687babf59bf26e Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 15:29:20 +0200 Subject: [PATCH 3/4] Update types and docstring --- doctr/transforms/modules/pytorch.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 27977ce825..22f3e3ddac 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -4,6 +4,7 @@ # See LICENSE or go to for full license details. import math +from collections.abc import Sequence import numpy as np import torch @@ -59,12 +60,13 @@ def __init__( def _resize_target( self, - target, - raw_shape, - final_shape, - symmetric_pad=False, - offset=(0, 0), - ): + target: np.ndarray, + raw_shape: Sequence[int], + final_shape: Sequence[int], + symmetric_pad: bool = False, + offset: tuple[int, int] = (0, 0), + ) -> np.ndarray: + """Resize the target boxes according to the resizing of the image and the padding if needed""" target = target.copy() if target.shape[1:] == (4,): From 2cac7bd828e366d8844cce206a66d6639ce9823e Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 12 May 2026 15:33:28 +0200 Subject: [PATCH 4/4] style --- tests/pytorch/test_transforms_pt.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index 6c89c88fb7..657e1561fc 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -342,16 +342,12 @@ def test_random_crop(target): input_t = torch.ones((3, 50, 50), dtype=torch.float32) img, target = cropper(input_t, target) - # Check the scale assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] - # Check aspect ratio assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 - # Check the target assert np.all(target >= 0) - if target.ndim == 2: assert np.all(target[:, [0, 2]] <= img.shape[-1]) and np.all(target[:, [1, 3]] <= img.shape[-2]) else: @@ -362,29 +358,21 @@ def test_random_crop(target): "boxes": np.array([[15, 20, 35, 30]]), "polygons": np.array([[[15, 20], [35, 20], [35, 30], [15, 30]]]), } - img, cropped_targets = cropper(input_t, dict_target) - assert isinstance(cropped_targets, dict) assert set(cropped_targets.keys()) == {"boxes", "polygons"} - assert isinstance(cropped_targets["boxes"], np.ndarray) assert isinstance(cropped_targets["polygons"], np.ndarray) - # Check cropped image properties assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] assert 0.65 <= img.shape[-2] / img.shape[-1] <= 1.6 - # Check boxes assert np.all(cropped_targets["boxes"] >= 0) - if len(cropped_targets["boxes"]) > 0: assert np.all(cropped_targets["boxes"][:, [0, 2]] <= img.shape[-1]) assert np.all(cropped_targets["boxes"][:, [1, 3]] <= img.shape[-2]) - # Check polygons assert np.all(cropped_targets["polygons"] >= 0) - if len(cropped_targets["polygons"]) > 0: assert np.all(cropped_targets["polygons"][..., 0] <= img.shape[-1]) assert np.all(cropped_targets["polygons"][..., 1] <= img.shape[-2]) @@ -519,16 +507,12 @@ def test_randomhorizontalflip(p, target): } transformed, _target = transform(input_t, dict_target) - assert isinstance(_target, dict) assert set(_target.keys()) == {"boxes", "polygons"} - assert _target["boxes"].dtype == np.float32 assert _target["polygons"].dtype == np.float32 - if p == 1: assert np.all(_target["boxes"] == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) - assert np.all( _target["polygons"] == np.array( @@ -536,10 +520,8 @@ def test_randomhorizontalflip(p, target): dtype=np.float32, ) ) - elif p == 0: assert np.all(_target["boxes"] == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) - assert np.all( _target["polygons"] == np.array(