!17680 Avoid overflow of in realdiv

From: @wenfangpei
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-06-08 10:39:33 +08:00 committed by Gitee
commit 7e5d68991a
1 changed files with 27 additions and 13 deletions

View File

@ -15,6 +15,7 @@
"""generate json desc for LambApplyWeightAssign"""
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class LambApplyWeightAssign(Expander):
"""LambApplyWeightAssign expander"""
@ -23,28 +24,41 @@ class LambApplyWeightAssign(Expander):
w_norm, g_norm, input_lr, update, input_param = self.inputs
# ratio
const_zero = graph_builder.value(g_norm.dtype, 0)
const_one = graph_builder.value(g_norm.dtype, 1)
dtype = update.dtype
dtype = g_norm.dtype
if dtype == "float32":
data_min = graph_builder.value(dtype, 2**(-126))
elif dtype == "float16":
data_min = graph_builder.value(dtype, 2*(-24))
else:
raise ValueError("Only support float32 and float16, but input type is : {}!".format(dtype))
const_zero = graph_builder.value(dtype, 0)
const_one = graph_builder.value(dtype, 1)
# w_norm >= 0, g_norm >= 0
# ratio = select(greater(w_norm, 0), select(greater(g_norm, 0), w_norm/g_norm, 1), 1)
# cal ratio
g_norm_greater_res = graph_builder.emit('Greater', [g_norm, const_zero])
g_norm_greater_res_float = graph_builder.emit('Cast', [g_norm_greater_res], attrs={'dst_type': dtype})
g_norm_res = graph_builder.emit('Cast', [g_norm_greater_res], attrs={'dst_type': dtype})
g_norm = graph_builder.emit('Add', [g_norm, data_min])
w_norm_g_norm = graph_builder.emit('RealDiv', [w_norm, g_norm])
g_norm_value_1 = graph_builder.emit('Mul', [g_norm_res, w_norm_g_norm])
# select
g_norm_greater_res_neg = graph_builder.emit('Neg', [g_norm_greater_res_float])
g_norm_greater_res_f = graph_builder.emit('Add', [g_norm_greater_res_neg, const_one])
g_norm_value_1 = graph_builder.emit('Mul', [g_norm_greater_res_float, w_norm_g_norm])
g_norm_value = graph_builder.emit('Add', [g_norm_value_1, g_norm_greater_res_f])
g_norm_res_neg = graph_builder.emit('Neg', [g_norm_res])
g_norm_res_f = graph_builder.emit('Add', [g_norm_res_neg, const_one])
g_norm_value = graph_builder.emit('Add', [g_norm_value_1, g_norm_res_f])
w_norm_greater_res = graph_builder.emit('Greater', [w_norm, const_zero])
w_norm_greater_res_float = graph_builder.emit('Cast', [w_norm_greater_res], attrs={'dst_type': dtype})
w_norm_res = graph_builder.emit('Cast', [w_norm_greater_res], attrs={'dst_type': dtype})
w_norm_value_1 = graph_builder.emit('Mul', [w_norm_res, g_norm_value])
# select
w_norm_greater_res_neg = graph_builder.emit('Neg', [w_norm_greater_res_float])
w_norm_greater_res_f = graph_builder.emit('Add', [w_norm_greater_res_neg, const_one])
w_norm_value_1 = graph_builder.emit('Mul', [w_norm_greater_res_float, g_norm_value])
ratio = graph_builder.emit('Add', [w_norm_value_1, w_norm_greater_res_f])
w_norm_res_neg = graph_builder.emit('Neg', [w_norm_res])
w_norm_res_f = graph_builder.emit('Add', [w_norm_res_neg, const_one])
ratio = graph_builder.emit('Add', [w_norm_value_1, w_norm_res_f])
# ratio * input_lr * update
update_with_ir = graph_builder.emit('Mul', [update, input_lr])