Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 127 additions & 23 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@
def __init__(self, transforms: list[Callable[[Any], Any]]) -> None:
self.transforms = transforms

def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
def __call__(
self, img: Any, target: np.ndarray | dict[str, np.ndarray] | None = None
) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]:
# Pick transformation
transfo = self.transforms[int(random.random() * len(self.transforms))]
# Apply
Expand Down Expand Up @@ -141,7 +143,11 @@
def extra_repr(self) -> str:
    """Return a one-line summary of the wrapped transformation and its probability `p`."""
    summary = f"transform={self.transform}, p={self.p}"
    return summary

def __call__(self, img: Any, target: np.ndarray | None = None) -> Any | tuple[Any, np.ndarray]:
def __call__(
self,
img: Any,
target: np.ndarray | dict[str, np.ndarray] | None = None,
) -> Any | tuple[Any, np.ndarray | dict[str, np.ndarray]]:
if random.random() < self.p:
return self.transform(img) if target is None else self.transform(img, target) # type: ignore[call-arg]
return img if target is None else (img, target)
Expand All @@ -165,12 +171,46 @@
def extra_repr(self) -> str:
    """Return a one-line summary of the `max_angle` and `expand` settings."""
    summary = f"max_angle={self.max_angle}, expand={self.expand}"
    return summary

def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
angle = random.uniform(-self.max_angle, self.max_angle)
def _rotate_array(self, img: Any, target: np.ndarray, angle: float) -> tuple[Any, np.ndarray]:
    """Rotate the image and one array of targets, dropping targets lost in the rotation.

    Args:
        img: image forwarded as-is to `F.rotate_sample`
        target: polygons of shape (N, 4, 2); any other layout is treated as boxes
        angle: rotation angle forwarded to `F.rotate_sample`

    Returns:
        the rotated image and the remaining targets: polygons when polygons were given,
        otherwise boxes rebuilt from the min/max corners of the rotated polygons
    """
    is_polygon = target.shape[1:] == (4, 2)

    # NOTE(review): presumably `F.rotate_sample` always returns (N, 4, 2) polygons - confirm
    r_img, r_polys = F.rotate_sample(img, target, angle, self.expand)

    # Keep only targets whose extent is strictly positive along both x and y
    is_kept = (r_polys.max(1) > r_polys.min(1)).all(axis=1)
    r_polys = r_polys[is_kept]

    if is_polygon:
        return r_img, r_polys

    # Input was boxes: (N, 4, 2) -> (N, 4), min corner followed by max corner
    r_boxes = np.concatenate([r_polys.min(axis=1), r_polys.max(axis=1)], axis=1)
    return r_img, r_boxes

def __call__(
    self, img: Any, target: np.ndarray | dict[str, np.ndarray]
) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]:
    """Rotate the image and its targets by one angle drawn in [-max_angle, max_angle].

    Args:
        img: image to rotate
        target: a single array of targets, or one array per class name

    Returns:
        the rotated image and the rotated targets, in the same container type as `target`;
        with a dict whose arrays are all empty, the image is returned as given
    """
    angle = random.uniform(-self.max_angle, self.max_angle)

    # Single array: nothing to aggregate
    if not isinstance(target, dict):
        return self._rotate_array(img, target, angle)

    out_img = None
    out_targets = {}
    for cls_name, arr in target.items():
        if len(arr) > 0:
            candidate, out_targets[cls_name] = self._rotate_array(img, arr, angle)
            # Every class shares the same angle, so the first rotated image is kept
            if out_img is None:
                out_img = candidate
        else:
            # Nothing to rotate for this class, but still hand back a fresh array
            out_targets[cls_name] = arr.copy()

    return (img if out_img is None else out_img), out_targets


class RandomCrop(NestedObject):
Expand All @@ -188,7 +228,61 @@
def extra_repr(self) -> str:
    """Return a one-line summary of the `scale` and `ratio` settings."""
    summary = f"scale={self.scale}, ratio={self.ratio}"
    return summary

def __call__(self, img: Any, target: np.ndarray) -> tuple[Any, np.ndarray]:
def _crop_array(
    self,
    img: Any,
    target: np.ndarray,
    crop_box: tuple[float, float, float, float],
) -> tuple[Any, np.ndarray]:
    """Crop the image and express one array of targets in the cropped frame.

    Args:
        img: image handed over to `F.crop_detection`
        target: polygons of shape (N, 4, 2); any other layout is passed to
            `F.crop_detection` as boxes
        crop_box: crop region as (xmin, ymin, xmax, ymax)

    Returns:
        the cropped image with the surviving targets clipped to [0, 1], or the untouched
        `(img, target)` pair when no target survives the crop
    """
    if target.shape[1:] != (4, 2):
        # Boxes: the functional helper crops both the image and the boxes
        new_img, kept_boxes = F.crop_detection(img, target, crop_box)
        if kept_boxes.shape[0] == 0:
            return img, target
        return new_img, np.clip(kept_boxes, 0, 1)

    # Polygons: only the image is taken from the helper, which is fed the enclosing boxes
    enclosing = np.concatenate((np.min(target, axis=1), np.max(target, axis=1)), axis=1)
    new_img, _ = F.crop_detection(img, enclosing, crop_box)

    x_origin, y_origin = crop_box[0], crop_box[1]
    box_w = crop_box[2] - x_origin
    box_h = crop_box[3] - y_origin

    # Reproject the coordinates into the cropped frame
    polys = target.copy()
    polys[..., 0] = (polys[..., 0] - x_origin) / box_w
    polys[..., 1] = (polys[..., 1] - y_origin) / box_h

    # A polygon stays when its extent overlaps the unit square on both axes
    lower, upper = polys.min(axis=1), polys.max(axis=1)
    visible = (upper > 0).all(axis=1) & (lower < 1).all(axis=1)
    polys = polys[visible]

    if polys.shape[0] == 0:
        return img, target
    return new_img, np.clip(polys, 0, 1)

def __call__(
self,
img: Any,
target: np.ndarray | dict[str, np.ndarray],
) -> tuple[Any, np.ndarray | dict[str, np.ndarray]]:
scale = random.uniform(self.scale[0], self.scale[1])
ratio = random.uniform(self.ratio[0], self.ratio[1])

Expand All @@ -208,19 +302,29 @@
x = random.randint(0, width - crop_width)
y = random.randint(0, height - crop_height)

# relative crop box
crop_box = (x / width, y / height, (x + crop_width) / width, (y + crop_height) / height)
if target.shape[1:] == (4, 2):
min_xy = np.min(target, axis=1)
max_xy = np.max(target, axis=1)
_target = np.concatenate((min_xy, max_xy), axis=1)
else:
_target = target

# Crop image and targets
croped_img, crop_boxes = F.crop_detection(img, _target, crop_box)
# hard fallback if no box is kept
if crop_boxes.shape[0] == 0:
return img, target
# clip boxes
return croped_img, np.clip(crop_boxes, 0, 1)
crop_box = (
x / width,
y / height,
(x + crop_width) / width,
(y + crop_height) / height,
)

if isinstance(target, dict):
cropped_targets = {}
cropped_img = None

for cls_name, arr in target.items():
if len(arr) == 0:
cropped_targets[cls_name] = arr.copy()
continue

c_img, c_arr = self._crop_array(img, arr, crop_box)

if cropped_img is None:
cropped_img = c_img

cropped_targets[cls_name] = c_arr

return cropped_img if cropped_img is not None else img, cropped_targets

return self._crop_array(img, target, crop_box)
127 changes: 91 additions & 36 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from collections.abc import Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -57,6 +58,38 @@
self.symmetric_pad = symmetric_pad
self.return_padding_mask = return_padding_mask

def _resize_target(
self,
target: np.ndarray,
raw_shape: Sequence[int],
final_shape: Sequence[int],
symmetric_pad: bool = False,
offset: tuple[int, int] = (0, 0),
) -> np.ndarray:
"""Resize the target boxes according to the resizing of the image and the padding if needed"""

Check notice on line 69 in doctr/transforms/modules/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/transforms/modules/pytorch.py#L69

First line should end with a period, question mark, or exclamation point (not 'd') (D415)
target = target.copy()

if target.shape[1:] == (4,):
if symmetric_pad:
target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / final_shape[-1]
target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / final_shape[-2]
else:
target[:, [0, 2]] *= raw_shape[-1] / final_shape[-1]
target[:, [1, 3]] *= raw_shape[-2] / final_shape[-2]

elif target.shape[1:] == (4, 2):
if symmetric_pad:
target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / final_shape[-1]
target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / final_shape[-2]
else:
target[..., 0] *= raw_shape[-1] / final_shape[-1]
target[..., 1] *= raw_shape[-2] / final_shape[-2]

else:
raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")

return np.clip(target, 0, 1)

def forward(
self,
img: torch.Tensor,
Expand Down Expand Up @@ -116,32 +149,36 @@
# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.symmetric_pad:
offset = half_pad[0] / img.shape[-1], half_pad[1] / img.shape[-2]

if self.preserve_aspect_ratio:
# Get absolute coords
if target.shape[1:] == (4,):
if self.symmetric_pad:
target[:, [0, 2]] = offset[0] + target[:, [0, 2]] * raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] = offset[1] + target[:, [1, 3]] * raw_shape[-2] / img.shape[-2]
else:
target[:, [0, 2]] *= raw_shape[-1] / img.shape[-1]
target[:, [1, 3]] *= raw_shape[-2] / img.shape[-2]
elif target.shape[1:] == (4, 2):
if self.symmetric_pad:
target[..., 0] = offset[0] + target[..., 0] * raw_shape[-1] / img.shape[-1]
target[..., 1] = offset[1] + target[..., 1] * raw_shape[-2] / img.shape[-2]
else:
target[..., 0] *= raw_shape[-1] / img.shape[-1]
target[..., 1] *= raw_shape[-2] / img.shape[-2]
else:
raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")

target = np.clip(target, 0, 1)

if self.return_padding_mask:
return img, target, padding_mask
return img, target
offset = (
half_pad[0] / img.shape[-1],
half_pad[1] / img.shape[-2],
)
else:
offset = (0, 0)

if isinstance(target, dict):
target = {
cls_name: self._resize_target(
arr,
raw_shape,
img.shape[-2:],
symmetric_pad=self.symmetric_pad,
offset=offset,
)
for cls_name, arr in target.items()
}
else:
target = self._resize_target(
target,
raw_shape,
img.shape[-2:],
symmetric_pad=self.symmetric_pad,
offset=offset,
)
if target is not None:
if self.return_padding_mask:
return img, target, padding_mask
return img, target

if self.return_padding_mask:
return img, padding_mask
Expand Down Expand Up @@ -234,16 +271,30 @@
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Randomly flip the input image horizontally."""

    def _flip_array(self, target: np.ndarray) -> np.ndarray:
        """Mirror relative x coordinates around the vertical axis.

        Args:
            target: boxes of shape (N, 4) whose columns 0 and 2 hold the x coordinates;
                any other layout is handled as points whose last axis starts with x

        Returns:
            a flipped copy of `target`; the input is left untouched
        """
        _target = target.copy()
        if target.shape[1:] == (4,):
            # Once mirrored, the former xmax becomes the new xmin and vice versa
            _target[:, ::2] = 1 - target[:, [2, 0]]
        else:
            _target[..., 0] = 1 - target[..., 0]

        return _target

    def forward(
        self,
        img: torch.Tensor | Image,
        target: np.ndarray | dict[str, np.ndarray],
    ) -> tuple[torch.Tensor | Image, np.ndarray | dict[str, np.ndarray]]:
        """Flip the image and its targets with probability `p`.

        Args:
            img: image to flip
            target: a single array of targets, or one array per class name

        Returns:
            the flipped image and flipped copies of the targets, or the inputs as given
            when no flip is drawn
        """
        if torch.rand(1) < self.p:
            _img = F.hflip(img)

            # Dict targets are flipped class by class
            if isinstance(target, dict):
                return _img, {cls_name: self._flip_array(arr) for cls_name, arr in target.items()}

            return _img, self._flip_array(target)

        return img, target


Expand Down Expand Up @@ -319,7 +370,11 @@
self.p = p
self._resize = Resize

def forward(self, img: torch.Tensor, target: np.ndarray) -> tuple[torch.Tensor, np.ndarray]:
def forward(
self,
img: torch.Tensor,
target: np.ndarray | dict[str, np.ndarray],
) -> tuple[torch.Tensor, np.ndarray | dict[str, np.ndarray]]:
if torch.rand(1) < self.p:
scale_h = np.random.uniform(*self.scale_range)
scale_w = np.random.uniform(*self.scale_range)
Expand All @@ -329,7 +384,7 @@
new_size,
preserve_aspect_ratio=self.preserve_aspect_ratio
if isinstance(self.preserve_aspect_ratio, bool)
else bool(torch.rand(1) <= self.symmetric_pad),
else bool(torch.rand(1) <= self.preserve_aspect_ratio),
symmetric_pad=self.symmetric_pad
if isinstance(self.symmetric_pad, bool)
else bool(torch.rand(1) <= self.symmetric_pad),
Expand Down
Loading
Loading