!32235 [MD][Offload] Fixed Offload RandomColorAdjust Bug

Merge pull request !32235 from alashkari/fix-color-bug
This commit is contained in:
i-robot 2022-03-30 06:10:14 +00:00 committed by Gitee
commit c03f9c7135
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 41 additions and 42 deletions

View File

@ -104,6 +104,20 @@ def check_input_dims(x_shape, required_dim, offload_op_name):
(offload_op_name, required_dim, input_dim))
def assign_min_max_params(in_params, center=1):
"""
Adjust input parameters for ops.
"""
if isinstance(in_params, (list, tuple)):
min_param = in_params[0]
max_param = in_params[1]
else:
min_param = max(0, center - in_params)
max_param = center + in_params
return min_param, max_param
class ApplyPreTransform(nn.Cell):
"""
Concatenates offload model with network.
@ -215,7 +229,7 @@ class GenerateRandBatch(nn.Cell):
self.ones = P.Ones()
self.reshape = P.Reshape()
def __call__(self, degree_min, degree_max, check_rand, shape):
def construct(self, degree_min, degree_max, check_rand, shape):
bs, h, w, c = shape
rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
@ -235,33 +249,10 @@ class RandomColorAdjust(nn.Cell):
def __init__(self, brightness, contrast, saturation, hue):
super(RandomColorAdjust, self).__init__()
if isinstance(brightness, (list, tuple)):
self.br_min = brightness[0]
self.br_max = brightness[1]
else:
self.br_min = max(0, 1 - brightness)
self.br_max = 1 + brightness
if isinstance(contrast, (list, tuple)):
self.cont_min = contrast[0]
self.cont_max = contrast[1]
else:
self.cont_min = max(0, 1 - contrast)
self.cont_max = 1 + contrast
if isinstance(saturation, (list, tuple)):
self.sa_min = saturation[0]
self.sa_max = saturation[1]
else:
self.sa_min = max(0, 1 - saturation)
self.sa_max = 1 + saturation
if isinstance(hue, (list, tuple)):
self.hue_min = hue[0]
self.hue_max = hue[1]
else:
self.hue_min = max(0, 1 - hue)
self.hue_max = 1 + hue
self.br_min, self.br_max = assign_min_max_params(brightness)
self.cont_min, self.cont_max = assign_min_max_params(contrast)
self.sa_min, self.sa_max = assign_min_max_params(saturation)
self.hue_min, self.hue_max = assign_min_max_params(hue)
self.check_rand_br = Tensor(self.br_min == self.br_max)
self.check_rand_cont = Tensor(self.cont_min == self.cont_max)
@ -280,9 +271,17 @@ class RandomColorAdjust(nn.Cell):
self.argminvalue = P.ArgMinWithValue(axis=3, keep_dims=False)
self.stack = P.Stack(axis=0)
self.epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
self.squeeze_0 = P.Squeeze(axis=0)
self.squeeze = P.Squeeze(axis=0)
self.expand_dims = P.ExpandDims()
self.gatherd = P.GatherD()
self.floor = P.Floor()
self.fmod = P.FloorMod()
self.abs = P.Abs()
self.zeros_like = P.ZerosLike()
self.stack_axis_1 = P.Stack(axis=1)
self.transpose = P.Transpose()
self.ones = P.Ones()
self.reshape = P.Reshape()
self.generate_rand_batch = GenerateRandBatch()
@ -322,37 +321,37 @@ class RandomColorAdjust(nn.Cell):
max_c, max_v = self.argmaxvalue(x)
_, min_v = self.argminvalue(x)
hsv_denum = max_v - min_v + self.epsilon
h1 = self.floormod(((b - g) * 60 / hsv_denum), 360)
h1 = self.fmod(((b - g) * 60 / hsv_denum), 360)
h2 = (g - r) * 60 / hsv_denum + 120
h3 = (r - g) * 60 / hsv_denum + 240
h = self.stack((h1, h2, h3))
h = self.gatherd(h, 0, self.expand_dims(max_c, 0))
h = self.squeeze(h)
s = self.cast((max_v > 0), mstype.float32)
s = s * (1 - min_v / (max_v + self.epsilon))
hue = self.squeeze(self.gatherd(self.stack((h1, h2, h3)), 0, self.expand_dims(max_c, 0)))
s = self.cast((max_v > 0), mstype.float32) * (1 - min_v / (max_v + self.epsilon))
v = self.cast(max_v, mstype.float32)
# Adjust hue
hue_rand_factor = self.generate_rand_batch(self.hue_min, self.hue_max, self.check_rand_hue, x_shape)
h = h + hue_rand_factor * 360.0
hue_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
hue_rand_factor = self.hue_min + (self.hue_max - self.hue_min)*hue_rand_factor
degree_factor = self.hue_min * self.ones((bs, 1), mstype.float32)
hue_rand_factor = (self.check_rand_hue * degree_factor) + (~self.check_rand_hue * hue_rand_factor)
hue_rand_factor = self.reshape(C.repeat_elements(hue_rand_factor, rep=(h*w)), (bs, h, w))
hue = hue + (hue_rand_factor * 360.0)
# Convert tensor from hsv to rgb
h_ = (h - self.floor(h / 360.0) * 360.0) / 60.0
h_ = (hue - self.floor(hue / 360.0) * 360.0) / 60.0
c = self.mul(s, v)
x_ = self.mul(c, (1 - self.abs(self.fmod(h_, 2) - 1)))
zero_tensor = self.zeros_like(c)
y = self.stack((self.stack_axis_1((c, x_, zero_tensor)), self.stack_axis_1((x_, c, zero_tensor)),
self.stack_axis_1((zero_tensor, c, x_)), self.stack_axis_1((zero_tensor, x_, c)),
self.stack_axis_1((x_, zero_tensor, c)), self.stack_axis_1((c, zero_tensor, x_)),
))
index = self.expand_dims(self.floor(h_), 1)
index = self.expand_dims(C.repeat_elements(index, 3, 1), 0)
index = self.cast(index, mstype.int32)
x = self.squeeze(self.gatherd(y, 0, index))
x = x + self.reshape(C.repeat_elements((v - c), rep=(3)), self.shape(x))
x = self.transpose(x, (0, 2, 3, 1)) * 255.0
x = C.clip_by_value(x, 0.0, 255.0)
return x
@ -496,7 +495,7 @@ op_to_model = {
"HWC2CHW": OffloadModel(HwcToChw),
"HwcToChw": OffloadModel(HwcToChw),
"Normalize": OffloadModel(Normalize, ["mean", "std"]),
"RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "saturation"]),
"RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "contrast", "saturation", "hue"]),
"RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]),
"RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]),
"RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]),