!15741 [GraphKernel] batchnorm expander supports the case when the first input is float16

From: @looop5
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
This commit is contained in:
mindspore-ci-bot 2021-04-27 16:27:29 +08:00 committed by Gitee
commit 52e7f51970
1 changed file with 12 additions and 1 deletion

View File

@ -33,6 +33,14 @@ class BatchNorm(Expander):
input_variance = self.inputs[4]
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
input_x_ori_type = input_x.dtype
input_x_new_type = input_x.dtype
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
input_mean.dtype == "float32" and input_variance.dtype == "float32":
input_x_new_type = "float32"
if input_x_new_type != input_x_ori_type:
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
if self.attrs['is_training']:
reduce_axis = ()
shape_x = input_x.shape
@ -109,7 +117,8 @@ class BatchNorm(Expander):
variance_res = graph_builder.emit(
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
attrs={'fake_output': True})
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
# infer mode
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@ -128,4 +137,6 @@ class BatchNorm(Expander):
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
res_y = graph_builder.emit('Add', [input_offset, x_div])
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, var_add, var_add, var_add, var_add