forked from mindspore-Ecosystem/mindspore
!17680 Avoid overflow in RealDiv
From: @wenfangpei Reviewed-by: @gaoxiong1, @ckey_dou Signed-off-by: @ckey_dou
commit 7e5d68991a
@@ -15,6 +15,7 @@
 """generate json desc for LambApplyWeightAssign"""
 from ._utils import Expander, ExpanderInfoValidator as VLD
 
 
 @VLD.check_all_formats_same
 class LambApplyWeightAssign(Expander):
     """LambApplyWeightAssign expander"""
@@ -23,28 +24,41 @@ class LambApplyWeightAssign(Expander):
         w_norm, g_norm, input_lr, update, input_param = self.inputs
         # ratio
-        const_zero = graph_builder.value(g_norm.dtype, 0)
-        const_one = graph_builder.value(g_norm.dtype, 1)
-        dtype = update.dtype
+        dtype = g_norm.dtype
+
+        if dtype == "float32":
+            data_min = graph_builder.value(dtype, 2**(-126))
+        elif dtype == "float16":
+            data_min = graph_builder.value(dtype, 2**(-24))
+        else:
+            raise ValueError("Only support float32 and float16, but input type is : {}!".format(dtype))
+
+        const_zero = graph_builder.value(dtype, 0)
+        const_one = graph_builder.value(dtype, 1)
 
         # w_norm >= 0, g_norm >= 0
         # ratio = select(greater(w_norm, 0), select(greater(g_norm, 0), w_norm/g_norm, 1), 1)
         # cal ratio
         g_norm_greater_res = graph_builder.emit('Greater', [g_norm, const_zero])
-        g_norm_greater_res_float = graph_builder.emit('Cast', [g_norm_greater_res], attrs={'dst_type': dtype})
+        g_norm_res = graph_builder.emit('Cast', [g_norm_greater_res], attrs={'dst_type': dtype})
+
+        g_norm = graph_builder.emit('Add', [g_norm, data_min])
         w_norm_g_norm = graph_builder.emit('RealDiv', [w_norm, g_norm])
+
+        g_norm_value_1 = graph_builder.emit('Mul', [g_norm_res, w_norm_g_norm])
         # select
-        g_norm_greater_res_neg = graph_builder.emit('Neg', [g_norm_greater_res_float])
-        g_norm_greater_res_f = graph_builder.emit('Add', [g_norm_greater_res_neg, const_one])
-        g_norm_value_1 = graph_builder.emit('Mul', [g_norm_greater_res_float, w_norm_g_norm])
-        g_norm_value = graph_builder.emit('Add', [g_norm_value_1, g_norm_greater_res_f])
+        g_norm_res_neg = graph_builder.emit('Neg', [g_norm_res])
+        g_norm_res_f = graph_builder.emit('Add', [g_norm_res_neg, const_one])
+        g_norm_value = graph_builder.emit('Add', [g_norm_value_1, g_norm_res_f])
 
         w_norm_greater_res = graph_builder.emit('Greater', [w_norm, const_zero])
-        w_norm_greater_res_float = graph_builder.emit('Cast', [w_norm_greater_res], attrs={'dst_type': dtype})
+        w_norm_res = graph_builder.emit('Cast', [w_norm_greater_res], attrs={'dst_type': dtype})
+
+        w_norm_value_1 = graph_builder.emit('Mul', [w_norm_res, g_norm_value])
         # select
-        w_norm_greater_res_neg = graph_builder.emit('Neg', [w_norm_greater_res_float])
-        w_norm_greater_res_f = graph_builder.emit('Add', [w_norm_greater_res_neg, const_one])
-        w_norm_value_1 = graph_builder.emit('Mul', [w_norm_greater_res_float, g_norm_value])
-        ratio = graph_builder.emit('Add', [w_norm_value_1, w_norm_greater_res_f])
+        w_norm_res_neg = graph_builder.emit('Neg', [w_norm_res])
+        w_norm_res_f = graph_builder.emit('Add', [w_norm_res_neg, const_one])
+        ratio = graph_builder.emit('Add', [w_norm_value_1, w_norm_res_f])
 
         # ratio * input_lr * update
         update_with_ir = graph_builder.emit('Mul', [update, input_lr])
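
What the hunk above builds, in plain terms, is the trust ratio from the comment "ratio = select(greater(w_norm, 0), select(greater(g_norm, 0), w_norm/g_norm, 1), 1)", with the denominator guarded by data_min before the RealDiv. Below is a minimal NumPy sketch of that math, for illustration only: lamb_ratio is a hypothetical name, not the MindSpore graph_builder API, and it assumes w_norm and g_norm are non-negative arrays of the given dtype.

import numpy as np

def lamb_ratio(w_norm, g_norm, dtype=np.float32):
    # guard values as in the diff: smallest normal float32 / smallest positive float16
    data_min = dtype(2.0 ** -126) if dtype is np.float32 else dtype(2.0 ** -24)
    g_mask = (g_norm > 0).astype(dtype)                   # Cast(Greater(g_norm, 0))
    w_mask = (w_norm > 0).astype(dtype)                   # Cast(Greater(w_norm, 0))
    quotient = w_norm / (g_norm + data_min)               # RealDiv on the guarded denominator
    g_value = g_mask * quotient + (dtype(1) - g_mask)     # select(g_norm > 0, quotient, 1)
    return w_mask * g_value + (dtype(1) - w_mask)         # select(w_norm > 0, g_value, 1)

w = np.array([0.02, 3.0], dtype=np.float32)
g = np.array([0.0, 1.5], dtype=np.float32)
print(lamb_ratio(w, g))   # [1. 2.] -> falls back to 1 when g_norm == 0, w_norm/g_norm otherwise

The expander then scales update by input_lr and this ratio, as the "# ratio * input_lr * update" comment says; the remainder of the expander lies outside the hunk shown above.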
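
The repeated Neg/Add/Mul pattern in both branches is an arithmetic select(cond, value, 1): the boolean from Greater is cast to the compute dtype as a 0/1 mask, (1 - mask) is built as Add(Neg(mask), const_one), and the two parts are blended with Mul and Add. A small sketch of that trick under those assumptions (names are illustrative, not part of the commit):

import numpy as np

def select_via_mask(cond, value, dtype=np.float32):
    mask = cond.astype(dtype)              # Cast(Greater(...)) -> 0.0 or 1.0
    one_minus_mask = -mask + dtype(1)      # Add(Neg(mask), const_one)
    return mask * value + one_minus_mask   # Mul + Add, i.e. select(cond, value, 1)

cond = np.array([True, False])
value = np.array([2.5, 2.5], dtype=np.float32)
print(select_via_mask(cond, value))        # [2.5 1. ]

Writing the select this way keeps the expanded graph to simple elementwise ops (Greater, Cast, Neg, Mul, Add) rather than a dedicated Select node.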
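
Why the data_min guard is needed rather than relying on the mask alone: when g_norm is zero, the unguarded RealDiv already produces inf, and inf * 0.0 is nan, so masking the quotient afterwards cannot repair it. 2**(-126) is the smallest normal float32 value and 2**(-24) is the smallest positive (subnormal) float16 value, so adding data_min turns a zero denominator into the smallest representable positive number while leaving ordinary norms essentially unchanged. A quick NumPy demonstration (illustrative only):

import numpy as np

w_norm = np.float32(3.0)
g_norm = np.float32(0.0)
mask = np.float32(g_norm > 0)                                  # Cast(Greater(...)) -> 0.0

with np.errstate(divide="ignore", invalid="ignore"):
    unguarded = w_norm / g_norm * mask                         # inf * 0.0 -> nan
guarded = w_norm / (g_norm + np.float32(2.0 ** -126)) * mask   # finite * 0.0 -> 0.0

print(unguarded, guarded)                                      # nan 0.0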