Skip to content

Commit a3dcced

Browse files
Merge pull request #162 from Jasonlee1995/color_aug
Inplace Color Augmentations (RandomBrightness, RandomContrast, RandomSaturation)
2 parents 7022e9b + c2c8ccf commit a3dcced

3 files changed

Lines changed: 142 additions & 1 deletion

File tree

ffcv/.DS_Store

6 KB
Binary file not shown.

ffcv/transforms/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from .translate import RandomTranslate
1010
from .mixup import ImageMixup, LabelMixup, MixupToOneHot
1111
from .module import ModuleWrapper
12+
from .color_jitter import RandomBrightness, RandomContrast, RandomSaturation
1213

1314
__all__ = ['ToTensor', 'ToDevice',
1415
'ToTorchImage', 'NormalizeImage',
1516
'Convert', 'Squeeze', 'View',
1617
'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate',
1718
'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot',
1819
'Poison', 'ReplaceLabel',
19-
'ModuleWrapper']
20+
'ModuleWrapper',
21+
'RandomBrightness', 'RandomContrast', 'RandomSaturation']

ffcv/transforms/color_jitter.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
'''
2+
Random color operations similar to torchvision.transforms.ColorJitter, except that hue adjustment is not supported.
3+
Reference : https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional_tensor.py
4+
'''
5+
6+
import numpy as np
7+
8+
from dataclasses import replace
9+
from ..pipeline.allocation_query import AllocationQuery
10+
from ..pipeline.operation import Operation
11+
from ..pipeline.state import State
12+
from ..pipeline.compiler import Compiler
13+
14+
15+
16+
class RandomBrightness(Operation):
    '''
    Randomly adjust image brightness in place. Operates on raw arrays (not tensors).

    Each image of the batch is, with probability ``p``, multiplied by a factor
    drawn uniformly between max(0, 1-magnitude) and 1+magnitude. The result is
    clipped to [0, 255] and cast back to the dtype of the image.

    Parameters
    ----------
    magnitude : float
        randomly choose brightness enhancement factor on [max(0, 1-magnitude), 1+magnitude]
    p : float
        probability to apply brightness
    '''
    def __init__(self, magnitude: float, p: float = 0.5):
        super().__init__()
        self.p = p
        self.magnitude = magnitude

    def generate_code(self):
        '''
        Build and return the batch kernel ``brightness(images, *_)``.

        The kernel modifies ``images`` in place (one independent draw per
        image along axis 0) and returns the same array.
        '''
        my_range = Compiler.get_iterator()
        p = self.p
        # The sampling bounds never change after construction, so compute them
        # once here instead of on every batch inside the kernel. Using float
        # literals keeps both bounds floating point even when 1-magnitude < 0.
        low = max(0.0, 1.0 - self.magnitude)
        high = 1.0 + self.magnitude

        def brightness(images, *_):
            # Linear interpolation between img1 and img2, clipped to the 8-bit range
            def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype)

            batch_size = images.shape[0]
            apply_bright = np.random.rand(batch_size) < p
            factors = np.random.uniform(low, high, batch_size)
            for i in my_range(batch_size):
                if apply_bright[i]:
                    # Blending with 0 (black) amounts to scaling the image by the factor
                    images[i] = blend(images[i], 0, factors[i])

            return images

        # NOTE(review): presumably tells the pipeline compiler that the batch
        # loop may run in parallel -- confirm against Compiler.
        brightness.is_parallel = True
        return brightness

    def declare_state_and_memory(self, previous_state: State):
        '''
        Return the previous state with ``jit_mode=True`` together with an
        allocation query matching the previous shape and dtype.
        '''
        # NOTE(review): the kernel ignores its extra arguments and writes into
        # `images` directly, so this allocation looks unused -- confirm whether
        # returning None for the allocation is acceptable here.
        return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype))
53+
54+
55+
56+
class RandomContrast(Operation):
    '''
    Randomly adjust image contrast in place. Operates on raw arrays (not tensors).

    Each image of the batch is, with probability ``p``, blended with the mean
    of its grayscale version using a factor drawn uniformly between
    max(0, 1-magnitude) and 1+magnitude. The result is clipped to [0, 255] and
    cast back to the dtype of the image.

    The kernel indexes ``images[i, :, :, c]`` for c in 0..2, so the batch must
    be a rank-4 array with at least three channels on the last axis.

    Parameters
    ----------
    magnitude : float
        randomly choose contrast enhancement factor on [max(0, 1-magnitude), 1+magnitude]
    p : float
        probability to apply contrast
    '''
    def __init__(self, magnitude: float, p: float = 0.5):
        super().__init__()
        self.p = p
        self.magnitude = magnitude

    def generate_code(self):
        '''
        Build and return the batch kernel ``contrast(images, *_)``.

        The kernel modifies ``images`` in place (one independent draw per
        image along axis 0) and returns the same array.
        '''
        my_range = Compiler.get_iterator()
        p = self.p
        # The sampling bounds never change after construction, so compute them
        # once here instead of on every batch inside the kernel. Using float
        # literals keeps both bounds floating point even when 1-magnitude < 0.
        low = max(0.0, 1.0 - self.magnitude)
        high = 1.0 + self.magnitude

        def contrast(images, *_):
            # Linear interpolation between img1 and img2, clipped to the 8-bit range
            def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype)

            batch_size = images.shape[0]
            apply_contrast = np.random.rand(batch_size) < p
            factors = np.random.uniform(low, high, batch_size)
            for i in my_range(batch_size):
                if apply_contrast[i]:
                    r, g, b = images[i,:,:,0], images[i,:,:,1], images[i,:,:,2]
                    # Weighted grayscale, cast to the image dtype *before* the
                    # mean is taken (so integer images truncate first).
                    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).astype(images[i].dtype)
                    images[i] = blend(images[i], l_img.mean(), factors[i])

            return images

        # NOTE(review): presumably tells the pipeline compiler that the batch
        # loop may run in parallel -- confirm against Compiler.
        contrast.is_parallel = True
        return contrast

    def declare_state_and_memory(self, previous_state: State):
        '''
        Return the previous state with ``jit_mode=True`` together with an
        allocation query matching the previous shape and dtype.
        '''
        # NOTE(review): the kernel ignores its extra arguments and writes into
        # `images` directly, so this allocation looks unused -- confirm whether
        # returning None for the allocation is acceptable here.
        return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype))
95+
96+
97+
98+
class RandomSaturation(Operation):
    '''
    Randomly adjust image saturation (color balance) in place. Operates on raw arrays (not tensors).

    Each image of the batch is, with probability ``p``, blended with its own
    grayscale version using a factor drawn uniformly between
    max(0, 1-magnitude) and 1+magnitude. The result is clipped to [0, 255] and
    cast back to the dtype of the image.

    The kernel indexes ``images[i, :, :, c]`` for c in 0..2, so the batch must
    be a rank-4 array with at least three channels on the last axis.

    Parameters
    ----------
    magnitude : float
        randomly choose saturation enhancement factor on [max(0, 1-magnitude), 1+magnitude]
    p : float
        probability to apply saturation
    '''
    def __init__(self, magnitude: float, p: float = 0.5):
        super().__init__()
        self.p = p
        self.magnitude = magnitude

    def generate_code(self):
        '''
        Build and return the batch kernel ``saturation(images, *_)``.

        The kernel modifies ``images`` in place (one independent draw per
        image along axis 0) and returns the same array.
        '''
        my_range = Compiler.get_iterator()
        p = self.p
        # The sampling bounds never change after construction, so compute them
        # once here instead of on every batch inside the kernel. Using float
        # literals keeps both bounds floating point even when 1-magnitude < 0.
        low = max(0.0, 1.0 - self.magnitude)
        high = 1.0 + self.magnitude

        def saturation(images, *_):
            # Linear interpolation between img1 and img2, clipped to the 8-bit range
            def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype)

            batch_size = images.shape[0]
            apply_saturation = np.random.rand(batch_size) < p
            factors = np.random.uniform(low, high, batch_size)
            for i in my_range(batch_size):
                if apply_saturation[i]:
                    r, g, b = images[i,:,:,0], images[i,:,:,1], images[i,:,:,2]
                    # Weighted grayscale, cast to the image dtype
                    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).astype(images[i].dtype)
                    # Replicate the grayscale plane over every channel so it
                    # has the same shape as the image it is blended with.
                    l_img3 = np.zeros_like(images[i])
                    for j in my_range(images[i].shape[-1]):
                        l_img3[:,:,j] = l_img
                    images[i] = blend(images[i], l_img3, factors[i])

            return images

        # NOTE(review): presumably tells the pipeline compiler that the batch
        # loop may run in parallel -- confirm against Compiler.
        saturation.is_parallel = True
        return saturation

    def declare_state_and_memory(self, previous_state: State):
        '''
        Return the previous state with ``jit_mode=True`` together with an
        allocation query matching the previous shape and dtype.
        '''
        # NOTE(review): the kernel ignores its extra arguments and writes into
        # `images` directly, so this allocation looks unused -- confirm whether
        # returning None for the allocation is acceptable here.
        return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype))

0 commit comments

Comments
 (0)