!31320 [MD][Offload] Offload RandomColorAdjust Op Update

Merge pull request !31320 from alashkari/update-color-op
This commit is contained in:
i-robot 2022-03-16 20:50:38 +00:00 committed by Gitee
commit 62a2b7c41c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 107 additions and 10 deletions

View File

@ -204,12 +204,35 @@ class RandomVerticalFlip(nn.Cell):
return x
class GenerateRandBatch(nn.Cell):
"""
Generate batch with random values uniformly selected from [degree_min, degree_max].
"""
def __init__(self):
super(GenerateRandBatch, self).__init__()
self.ones = P.Ones()
self.reshape = P.Reshape()
def __call__(self, degree_min, degree_max, check_rand, shape):
bs, h, w, c = shape
rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
rand_factor = degree_min + (degree_max - degree_min)*rand_factor
degree_factor = degree_min * self.ones((bs, 1), mstype.float32)
rand_factor = (check_rand * degree_factor) + (~check_rand * rand_factor)
rand_factor = self.reshape(C.repeat_elements(rand_factor, rep=(h*w*c)), (bs, h, w, c))
return rand_factor
class RandomColorAdjust(nn.Cell):
"""
Applies Random Color Adjust transform on given input tensors.
"""
def __init__(self, brightness, saturation):
def __init__(self, brightness, contrast, saturation, hue):
super(RandomColorAdjust, self).__init__()
if isinstance(brightness, (list, tuple)):
@ -219,6 +242,13 @@ class RandomColorAdjust(nn.Cell):
self.br_min = max(0, 1 - brightness)
self.br_max = 1 + brightness
if isinstance(contrast, (list, tuple)):
self.cont_min = contrast[0]
self.cont_max = contrast[1]
else:
self.cont_min = max(0, 1 - contrast)
self.cont_max = 1 + contrast
if isinstance(saturation, (list, tuple)):
self.sa_min = saturation[0]
self.sa_max = saturation[1]
@ -226,6 +256,18 @@ class RandomColorAdjust(nn.Cell):
self.sa_min = max(0, 1 - saturation)
self.sa_max = 1 + saturation
if isinstance(hue, (list, tuple)):
self.hue_min = hue[0]
self.hue_max = hue[1]
else:
self.hue_min = max(0, 1 - hue)
self.hue_max = 1 + hue
self.check_rand_br = Tensor(self.br_min == self.br_max)
self.check_rand_cont = Tensor(self.cont_min == self.cont_max)
self.check_rand_sa = Tensor(self.sa_min == self.sa_max)
self.check_rand_hue = Tensor(self.hue_min == self.hue_max)
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
@ -233,6 +275,17 @@ class RandomColorAdjust(nn.Cell):
self.expand_dims = P.ExpandDims()
self.mul = P.Mul()
self.mean = P.ReduceMean()
self.argmaxvalue = P.ArgMaxWithValue(axis=3, keep_dims=False)
self.argminvalue = P.ArgMinWithValue(axis=3, keep_dims=False)
self.stack = P.Stack(axis=0)
self.epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
self.squeeze_0 = P.Squeeze(axis=0)
self.expand_dims = P.ExpandDims()
self.gatherd = P.GatherD()
self.generate_rand_batch = GenerateRandBatch()
def construct(self, x):
x = self.cast(x, mstype.float32)
@ -240,23 +293,67 @@ class RandomColorAdjust(nn.Cell):
check_input_dims(x_shape, 4, 'RandomColorAdjust')
bs, h, w, c = x_shape
br_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
br_rand_factor = self.br_min + (self.br_max - self.br_min)*br_rand_factor
br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=(h*w*c)), (bs, h, w, c))
sa_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
sa_rand_factor = self.sa_min + (self.sa_max - self.sa_min)*sa_rand_factor
sa_rand_factor = self.reshape(C.repeat_elements(sa_rand_factor, rep=(h*w*c)), (bs, h, w, c))
br_rand_factor = self.generate_rand_batch(self.br_min, self.br_max, self.check_rand_br, x_shape)
cont_rand_factor = self.generate_rand_batch(self.cont_min, self.cont_max, self.check_rand_cont, x_shape)
sat_rand_factor = self.generate_rand_batch(self.sa_min, self.sa_max, self.check_rand_sa, x_shape)
r, g, b = self.unstack(x)
x_gray = C.repeat_elements(self.expand_dims((0.2989 * r + 0.587 * g + 0.114 * b), -1), rep=c, axis=-1)
x_gray = 0.2989 * r + 0.587 * g + 0.114 * b
x_gray_mean = self.expand_dims(self.mean(x_gray, (1, 2)) + 0.5, -1)
x_gray_mean = self.reshape(C.repeat_elements(x_gray_mean, rep=(h*w*c)), (bs, h, w, c))
x_gray = C.repeat_elements(self.expand_dims(x_gray, -1), rep=c, axis=-1)
# Apply brightness
x = self.mul(x, br_rand_factor)
x = C.clip_by_value(x, 0.0, 255.0)
x = self.mul(x, sa_rand_factor) + self.mul((1 - sa_rand_factor), x_gray)
# Apply contrast
x = self.mul(x, cont_rand_factor) + self.mul((1 - cont_rand_factor), x_gray_mean)
x = C.clip_by_value(x, 0.0, 255.0)
# Apply saturation
x = self.mul(x, sat_rand_factor) + self.mul((1 - sat_rand_factor), x_gray)
x = C.clip_by_value(x, 0.0, 255.0)
# Apply Hue Transform
# Convert tensor from rgb to hsv
r, g, b = self.unstack(x)
max_c, max_v = self.argmaxvalue(x)
_, min_v = self.argminvalue(x)
hsv_denum = max_v - min_v + self.epsilon
h1 = self.floormod(((b - g) * 60 / hsv_denum), 360)
h2 = (g - r) * 60 / hsv_denum + 120
h3 = (r - g) * 60 / hsv_denum + 240
h = self.stack((h1, h2, h3))
h = self.gatherd(h, 0, self.expand_dims(max_c, 0))
h = self.squeeze(h)
s = self.cast((max_v > 0), mstype.float32)
s = s * (1 - min_v / (max_v + self.epsilon))
v = self.cast(max_v, mstype.float32)
# Adjust hue
hue_rand_factor = self.generate_rand_batch(self.hue_min, self.hue_max, self.check_rand_hue, x_shape)
h = h + hue_rand_factor * 360.0
# Convert tensor from hsv to rgb
h_ = (h - self.floor(h / 360.0) * 360.0) / 60.0
c = self.mul(s, v)
x_ = self.mul(c, (1 - self.abs(self.fmod(h_, 2) - 1)))
zero_tensor = self.zeros_like(c)
y = self.stack((self.stack_axis_1((c, x_, zero_tensor)), self.stack_axis_1((x_, c, zero_tensor)),
self.stack_axis_1((zero_tensor, c, x_)), self.stack_axis_1((zero_tensor, x_, c)),
self.stack_axis_1((x_, zero_tensor, c)), self.stack_axis_1((c, zero_tensor, x_)),
))
index = self.expand_dims(self.floor(h_), 1)
index = self.expand_dims(C.repeat_elements(index, 3, 1), 0)
index = self.cast(index, mstype.int32)
x = self.squeeze(self.gatherd(y, 0, index))
x = x + self.reshape(C.repeat_elements((v - c), rep=(3)), self.shape(x))
return x