forked from mindspore-Ecosystem/mindspore
!31320 [MD][Offload] Offload RandomColorAdjust Op Update
Merge pull request !31320 from alashkari/update-color-op
This commit is contained in:
commit
62a2b7c41c
|
@ -204,12 +204,35 @@ class RandomVerticalFlip(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class GenerateRandBatch(nn.Cell):
|
||||
"""
|
||||
Generate batch with random values uniformly selected from [degree_min, degree_max].
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(GenerateRandBatch, self).__init__()
|
||||
|
||||
self.ones = P.Ones()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def __call__(self, degree_min, degree_max, check_rand, shape):
|
||||
|
||||
bs, h, w, c = shape
|
||||
rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||
rand_factor = degree_min + (degree_max - degree_min)*rand_factor
|
||||
degree_factor = degree_min * self.ones((bs, 1), mstype.float32)
|
||||
rand_factor = (check_rand * degree_factor) + (~check_rand * rand_factor)
|
||||
rand_factor = self.reshape(C.repeat_elements(rand_factor, rep=(h*w*c)), (bs, h, w, c))
|
||||
|
||||
return rand_factor
|
||||
|
||||
|
||||
class RandomColorAdjust(nn.Cell):
|
||||
"""
|
||||
Applies Random Color Adjust transform on given input tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, brightness, saturation):
|
||||
def __init__(self, brightness, contrast, saturation, hue):
|
||||
super(RandomColorAdjust, self).__init__()
|
||||
|
||||
if isinstance(brightness, (list, tuple)):
|
||||
|
@ -219,6 +242,13 @@ class RandomColorAdjust(nn.Cell):
|
|||
self.br_min = max(0, 1 - brightness)
|
||||
self.br_max = 1 + brightness
|
||||
|
||||
if isinstance(contrast, (list, tuple)):
|
||||
self.cont_min = contrast[0]
|
||||
self.cont_max = contrast[1]
|
||||
else:
|
||||
self.cont_min = max(0, 1 - contrast)
|
||||
self.cont_max = 1 + contrast
|
||||
|
||||
if isinstance(saturation, (list, tuple)):
|
||||
self.sa_min = saturation[0]
|
||||
self.sa_max = saturation[1]
|
||||
|
@ -226,6 +256,18 @@ class RandomColorAdjust(nn.Cell):
|
|||
self.sa_min = max(0, 1 - saturation)
|
||||
self.sa_max = 1 + saturation
|
||||
|
||||
if isinstance(hue, (list, tuple)):
|
||||
self.hue_min = hue[0]
|
||||
self.hue_max = hue[1]
|
||||
else:
|
||||
self.hue_min = max(0, 1 - hue)
|
||||
self.hue_max = 1 + hue
|
||||
|
||||
self.check_rand_br = Tensor(self.br_min == self.br_max)
|
||||
self.check_rand_cont = Tensor(self.cont_min == self.cont_max)
|
||||
self.check_rand_sa = Tensor(self.sa_min == self.sa_max)
|
||||
self.check_rand_hue = Tensor(self.hue_min == self.hue_max)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
|
@ -233,6 +275,17 @@ class RandomColorAdjust(nn.Cell):
|
|||
self.expand_dims = P.ExpandDims()
|
||||
self.mul = P.Mul()
|
||||
|
||||
self.mean = P.ReduceMean()
|
||||
self.argmaxvalue = P.ArgMaxWithValue(axis=3, keep_dims=False)
|
||||
self.argminvalue = P.ArgMinWithValue(axis=3, keep_dims=False)
|
||||
self.stack = P.Stack(axis=0)
|
||||
self.epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
|
||||
self.squeeze_0 = P.Squeeze(axis=0)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.gatherd = P.GatherD()
|
||||
|
||||
self.generate_rand_batch = GenerateRandBatch()
|
||||
|
||||
def construct(self, x):
|
||||
|
||||
x = self.cast(x, mstype.float32)
|
||||
|
@ -240,23 +293,67 @@ class RandomColorAdjust(nn.Cell):
|
|||
check_input_dims(x_shape, 4, 'RandomColorAdjust')
|
||||
bs, h, w, c = x_shape
|
||||
|
||||
br_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||
br_rand_factor = self.br_min + (self.br_max - self.br_min)*br_rand_factor
|
||||
br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=(h*w*c)), (bs, h, w, c))
|
||||
|
||||
sa_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||
sa_rand_factor = self.sa_min + (self.sa_max - self.sa_min)*sa_rand_factor
|
||||
sa_rand_factor = self.reshape(C.repeat_elements(sa_rand_factor, rep=(h*w*c)), (bs, h, w, c))
|
||||
br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape)
|
||||
cont_rand_factor = self.generate_rand_batch(self.cont_min, self.cont_max, self.check_rand_cont, x_shape)
|
||||
sat_rand_factor = self.generate_rand_batch(self.sa_min, self.sa_max, self.check_rand_sa, x_shape)
|
||||
|
||||
r, g, b = self.unstack(x)
|
||||
x_gray = C.repeat_elements(self.expand_dims((0.2989 * r + 0.587 * g + 0.114 * b), -1), rep=c, axis=-1)
|
||||
|
||||
x_gray = 0.2989 * r + 0.587 * g + 0.114 * b
|
||||
x_gray_mean = self.expand_dims(self.mean(x_gray, (1, 2)) + 0.5, -1)
|
||||
x_gray_mean = self.reshape(C.repeat_elements(x_gray_mean, rep=(h*w*c)), (bs, h, w, c))
|
||||
x_gray = C.repeat_elements(self.expand_dims(x_gray, -1), rep=c, axis=-1)
|
||||
|
||||
# Apply brightness
|
||||
x = self.mul(x, br_rand_factor)
|
||||
x = C.clip_by_value(x, 0.0, 255.0)
|
||||
|
||||
x = self.mul(x, sa_rand_factor) + self.mul((1 - sa_rand_factor), x_gray)
|
||||
# Apply contrast
|
||||
x = self.mul(x, cont_rand_factor) + self.mul((1 - cont_rand_factor), x_gray_mean)
|
||||
x = C.clip_by_value(x, 0.0, 255.0)
|
||||
|
||||
# Apply saturation
|
||||
x = self.mul(x, sat_rand_factor) + self.mul((1 - sat_rand_factor), x_gray)
|
||||
x = C.clip_by_value(x, 0.0, 255.0)
|
||||
|
||||
# Apply Hue Transform
|
||||
# Convert tensor from rgb to hsv
|
||||
r, g, b = self.unstack(x)
|
||||
max_c, max_v = self.argmaxvalue(x)
|
||||
_, min_v = self.argminvalue(x)
|
||||
hsv_denum = max_v - min_v + self.epsilon
|
||||
h1 = self.floormod(((b - g) * 60 / hsv_denum), 360)
|
||||
h2 = (g - r) * 60 / hsv_denum + 120
|
||||
h3 = (r - g) * 60 / hsv_denum + 240
|
||||
h = self.stack((h1, h2, h3))
|
||||
h = self.gatherd(h, 0, self.expand_dims(max_c, 0))
|
||||
h = self.squeeze(h)
|
||||
s = self.cast((max_v > 0), mstype.float32)
|
||||
s = s * (1 - min_v / (max_v + self.epsilon))
|
||||
v = self.cast(max_v, mstype.float32)
|
||||
|
||||
# Adjust hue
|
||||
hue_rand_factor = self.generate_rand_batch(self.hue_min, self.hue_max, self.check_rand_hue, x_shape)
|
||||
h = h + hue_rand_factor * 360.0
|
||||
|
||||
# Convert tensor from hsv to rgb
|
||||
h_ = (h - self.floor(h / 360.0) * 360.0) / 60.0
|
||||
c = self.mul(s, v)
|
||||
x_ = self.mul(c, (1 - self.abs(self.fmod(h_, 2) - 1)))
|
||||
zero_tensor = self.zeros_like(c)
|
||||
|
||||
y = self.stack((self.stack_axis_1((c, x_, zero_tensor)), self.stack_axis_1((x_, c, zero_tensor)),
|
||||
self.stack_axis_1((zero_tensor, c, x_)), self.stack_axis_1((zero_tensor, x_, c)),
|
||||
self.stack_axis_1((x_, zero_tensor, c)), self.stack_axis_1((c, zero_tensor, x_)),
|
||||
))
|
||||
|
||||
index = self.expand_dims(self.floor(h_), 1)
|
||||
index = self.expand_dims(C.repeat_elements(index, 3, 1), 0)
|
||||
index = self.cast(index, mstype.int32)
|
||||
|
||||
x = self.squeeze(self.gatherd(y, 0, index))
|
||||
x = x + self.reshape(C.repeat_elements((v - c), rep=(3)), self.shape(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue