forked from mindspore-Ecosystem/mindspore
!32235 [MD][Offload] Fixed Offload RandomColorAdjust Bug
Merge pull request !32235 from alashkari/fix-color-bug
This commit is contained in:
commit
c03f9c7135
|
@ -104,6 +104,20 @@ def check_input_dims(x_shape, required_dim, offload_op_name):
|
|||
(offload_op_name, required_dim, input_dim))
|
||||
|
||||
|
||||
def assign_min_max_params(in_params, center=1):
|
||||
"""
|
||||
Adjust input parameters for ops.
|
||||
"""
|
||||
if isinstance(in_params, (list, tuple)):
|
||||
min_param = in_params[0]
|
||||
max_param = in_params[1]
|
||||
else:
|
||||
min_param = max(0, center - in_params)
|
||||
max_param = center + in_params
|
||||
|
||||
return min_param, max_param
|
||||
|
||||
|
||||
class ApplyPreTransform(nn.Cell):
|
||||
"""
|
||||
Concatenates offload model with network.
|
||||
|
@ -215,7 +229,7 @@ class GenerateRandBatch(nn.Cell):
|
|||
self.ones = P.Ones()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def __call__(self, degree_min, degree_max, check_rand, shape):
|
||||
def construct(self, degree_min, degree_max, check_rand, shape):
|
||||
|
||||
bs, h, w, c = shape
|
||||
rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||
|
@ -235,33 +249,10 @@ class RandomColorAdjust(nn.Cell):
|
|||
def __init__(self, brightness, contrast, saturation, hue):
|
||||
super(RandomColorAdjust, self).__init__()
|
||||
|
||||
if isinstance(brightness, (list, tuple)):
|
||||
self.br_min = brightness[0]
|
||||
self.br_max = brightness[1]
|
||||
else:
|
||||
self.br_min = max(0, 1 - brightness)
|
||||
self.br_max = 1 + brightness
|
||||
|
||||
if isinstance(contrast, (list, tuple)):
|
||||
self.cont_min = contrast[0]
|
||||
self.cont_max = contrast[1]
|
||||
else:
|
||||
self.cont_min = max(0, 1 - contrast)
|
||||
self.cont_max = 1 + contrast
|
||||
|
||||
if isinstance(saturation, (list, tuple)):
|
||||
self.sa_min = saturation[0]
|
||||
self.sa_max = saturation[1]
|
||||
else:
|
||||
self.sa_min = max(0, 1 - saturation)
|
||||
self.sa_max = 1 + saturation
|
||||
|
||||
if isinstance(hue, (list, tuple)):
|
||||
self.hue_min = hue[0]
|
||||
self.hue_max = hue[1]
|
||||
else:
|
||||
self.hue_min = max(0, 1 - hue)
|
||||
self.hue_max = 1 + hue
|
||||
self.br_min, self.br_max = assign_min_max_params(brightness)
|
||||
self.cont_min, self.cont_max = assign_min_max_params(contrast)
|
||||
self.sa_min, self.sa_max = assign_min_max_params(saturation)
|
||||
self.hue_min, self.hue_max = assign_min_max_params(hue)
|
||||
|
||||
self.check_rand_br = Tensor(self.br_min == self.br_max)
|
||||
self.check_rand_cont = Tensor(self.cont_min == self.cont_max)
|
||||
|
@ -280,9 +271,17 @@ class RandomColorAdjust(nn.Cell):
|
|||
self.argminvalue = P.ArgMinWithValue(axis=3, keep_dims=False)
|
||||
self.stack = P.Stack(axis=0)
|
||||
self.epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)
|
||||
self.squeeze_0 = P.Squeeze(axis=0)
|
||||
self.squeeze = P.Squeeze(axis=0)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.gatherd = P.GatherD()
|
||||
self.floor = P.Floor()
|
||||
self.fmod = P.FloorMod()
|
||||
self.abs = P.Abs()
|
||||
self.zeros_like = P.ZerosLike()
|
||||
self.stack_axis_1 = P.Stack(axis=1)
|
||||
self.transpose = P.Transpose()
|
||||
self.ones = P.Ones()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
self.generate_rand_batch = GenerateRandBatch()
|
||||
|
||||
|
@ -322,37 +321,37 @@ class RandomColorAdjust(nn.Cell):
|
|||
max_c, max_v = self.argmaxvalue(x)
|
||||
_, min_v = self.argminvalue(x)
|
||||
hsv_denum = max_v - min_v + self.epsilon
|
||||
h1 = self.floormod(((b - g) * 60 / hsv_denum), 360)
|
||||
h1 = self.fmod(((b - g) * 60 / hsv_denum), 360)
|
||||
h2 = (g - r) * 60 / hsv_denum + 120
|
||||
h3 = (r - g) * 60 / hsv_denum + 240
|
||||
h = self.stack((h1, h2, h3))
|
||||
h = self.gatherd(h, 0, self.expand_dims(max_c, 0))
|
||||
h = self.squeeze(h)
|
||||
s = self.cast((max_v > 0), mstype.float32)
|
||||
s = s * (1 - min_v / (max_v + self.epsilon))
|
||||
hue = self.squeeze(self.gatherd(self.stack((h1, h2, h3)), 0, self.expand_dims(max_c, 0)))
|
||||
s = self.cast((max_v > 0), mstype.float32) * (1 - min_v / (max_v + self.epsilon))
|
||||
v = self.cast(max_v, mstype.float32)
|
||||
|
||||
# Adjust hue
|
||||
hue_rand_factor = self.generate_rand_batch(self.hue_min, self.hue_max, self.check_rand_hue, x_shape)
|
||||
h = h + hue_rand_factor * 360.0
|
||||
hue_rand_factor = Tensor(np.random.uniform(size=(bs, 1)), dtype=mstype.float32)
|
||||
hue_rand_factor = self.hue_min + (self.hue_max - self.hue_min)*hue_rand_factor
|
||||
degree_factor = self.hue_min * self.ones((bs, 1), mstype.float32)
|
||||
hue_rand_factor = (self.check_rand_hue * degree_factor) + (~self.check_rand_hue * hue_rand_factor)
|
||||
hue_rand_factor = self.reshape(C.repeat_elements(hue_rand_factor, rep=(h*w)), (bs, h, w))
|
||||
hue = hue + (hue_rand_factor * 360.0)
|
||||
|
||||
# Convert tensor from hsv to rgb
|
||||
h_ = (h - self.floor(h / 360.0) * 360.0) / 60.0
|
||||
h_ = (hue - self.floor(hue / 360.0) * 360.0) / 60.0
|
||||
c = self.mul(s, v)
|
||||
x_ = self.mul(c, (1 - self.abs(self.fmod(h_, 2) - 1)))
|
||||
zero_tensor = self.zeros_like(c)
|
||||
|
||||
y = self.stack((self.stack_axis_1((c, x_, zero_tensor)), self.stack_axis_1((x_, c, zero_tensor)),
|
||||
self.stack_axis_1((zero_tensor, c, x_)), self.stack_axis_1((zero_tensor, x_, c)),
|
||||
self.stack_axis_1((x_, zero_tensor, c)), self.stack_axis_1((c, zero_tensor, x_)),
|
||||
))
|
||||
|
||||
index = self.expand_dims(self.floor(h_), 1)
|
||||
index = self.expand_dims(C.repeat_elements(index, 3, 1), 0)
|
||||
index = self.cast(index, mstype.int32)
|
||||
|
||||
x = self.squeeze(self.gatherd(y, 0, index))
|
||||
x = x + self.reshape(C.repeat_elements((v - c), rep=(3)), self.shape(x))
|
||||
x = self.transpose(x, (0, 2, 3, 1)) * 255.0
|
||||
x = C.clip_by_value(x, 0.0, 255.0)
|
||||
|
||||
return x
|
||||
|
||||
|
@ -496,7 +495,7 @@ op_to_model = {
|
|||
"HWC2CHW": OffloadModel(HwcToChw),
|
||||
"HwcToChw": OffloadModel(HwcToChw),
|
||||
"Normalize": OffloadModel(Normalize, ["mean", "std"]),
|
||||
"RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "saturation"]),
|
||||
"RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "contrast", "saturation", "hue"]),
|
||||
"RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]),
|
||||
"RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]),
|
||||
"RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]),
|
||||
|
|
Loading…
Reference in New Issue