-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimg_transformation.py
More file actions
56 lines (47 loc) · 2.42 KB
/
img_transformation.py
File metadata and controls
56 lines (47 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import random
import warnings
import numpy as np
import torch
from torchvision import transforms
def img_transformation(data, prob=1, gamma=3, random_erase_greyscale=True, seed=None, *, size=32):
    """Apply a random augmentation pipeline to an image or a batch of images.

    Args:
        data (torch.Tensor): A 3D tensor (single image) or a 4D tensor (batch;
            images are taken along the first dimension).
        prob (float, 0 <= prob <= 1): Probability that any given image is
            transformed. Images that are not selected are passed through
            unchanged. Defaults to 1.
        gamma (float): Transformation intensity. Values outside [0, 10] are
            clamped to that range and a ``UserWarning`` is emitted.
            Defaults to 3.
        random_erase_greyscale (bool, optional): Whether to include the random
            grayscale transformation and, when gamma >= 1, random erasing.
            Defaults to True.
        seed (int, optional): If given, seeds Python's ``random``, NumPy's
            global RNG and PyTorch (``torch.manual_seed``) before the pipeline
            is built. Note that this mutates global RNG state. Defaults to None.
        size (int or (h, w) sequence, optional): Output spatial size used by
            the resized crop and the final center crop. Keyword-only.
            Defaults to 32 (the previously hard-coded value).

    Returns:
        torch.Tensor: For a 3D input, the transformed image, or ``data``
        itself when the image is not selected. For a 4D input, a new tensor
        allocated with ``torch.empty_like(data)``. Batch images must therefore
        already have spatial size ``size`` so that transformed and untouched
        images fit the same slots.

    Raises:
        ValueError: If ``data`` is neither 3D nor 4D.
    """
    # Clamp gamma to [0, 10] and tell the caller if their value was changed.
    clamped_gamma = min(max(gamma, 0), 10)
    if clamped_gamma != gamma:
        warnings.warn(
            f"gamma={gamma} is outside [0, 10]; clamping to {clamped_gamma}.",
            stacklevel=2,
        )
    gamma = clamped_gamma

    # Fail fast, before touching global RNG state or building the pipeline.
    if data.ndim not in (3, 4):
        raise ValueError("Input tensor must be 3D (single image) or 4D (batch of images).")

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Common transformations, scaled by gamma.
    transform_list = [
        # NOTE(review): for gamma > 0 the scale upper bound exceeds 1.0 (full
        # image area); such samples presumably cannot be satisfied by
        # RandomResizedCrop and fall back to its default crop. Confirm intended.
        transforms.RandomResizedCrop(size, scale=(1 - 0.1 * gamma, 1 + 0.2 * gamma), antialias=True),
        # Mirror-pad before the affine warp (presumably so rotated/sheared
        # corners are filled with image content), then crop back to `size`.
        transforms.Pad(padding=12, fill=0, padding_mode='symmetric'),
        transforms.RandomAffine(degrees=9 * gamma, shear=3.3 * gamma),
        transforms.CenterCrop(size),
        # torchvision requires hue <= 0.5.
        # NOTE(review): hue jumps from 0.3 to 0.5 once gamma exceeds 5 — confirm intended.
        transforms.ColorJitter(
            brightness=0.08 * gamma,
            contrast=0.08 * gamma,
            saturation=0.1 * gamma,
            hue=(0.06 * gamma if gamma <= 5 else 0.5),
        ),
    ]

    # Additional transformations only for higher gamma values.
    if gamma >= 1:
        transform_list.append(transforms.RandomHorizontalFlip())
        # Odd kernel size: 1 for gamma < 4, 3 for gamma < 8, 5 otherwise.
        transform_list.append(transforms.GaussianBlur(kernel_size=2 * int(gamma // 4) + 1))

    if random_erase_greyscale:
        transform_list.append(transforms.RandomGrayscale(p=0.04 * gamma))
        if gamma >= 1:
            # NOTE(review): p grows to 0.9 at gamma == 6, then jumps to 1.
            transform_list.append(
                transforms.RandomErasing(
                    p=(0.15 * gamma if gamma <= 6 else 1),
                    scale=(0.01 * gamma, 0.05 * gamma),
                )
            )

    transform = transforms.Compose(transform_list)

    # Single image: transform with probability `prob`, otherwise pass through.
    if data.ndim == 3:
        return transform(data) if random.random() < prob else data

    # Batch: decide independently for each image.
    out_imgs = torch.empty_like(data)
    for i, img in enumerate(data):
        out_imgs[i] = transform(img) if random.random() < prob else img
    return out_imgs