forked from mindspore-Ecosystem/mindspore
!15741 [GraphKernel] batchnorm expander supports when first input is float16
From: @looop5 Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @dylangeng
This commit is contained in:
commit
52e7f51970
|
@ -33,6 +33,14 @@ class BatchNorm(Expander):
|
|||
input_variance = self.inputs[4]
|
||||
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
|
||||
|
||||
input_x_ori_type = input_x.dtype
|
||||
input_x_new_type = input_x.dtype
|
||||
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
|
||||
input_mean.dtype == "float32" and input_variance.dtype == "float32":
|
||||
input_x_new_type = "float32"
|
||||
if input_x_new_type != input_x_ori_type:
|
||||
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
|
||||
|
||||
if self.attrs['is_training']:
|
||||
reduce_axis = ()
|
||||
shape_x = input_x.shape
|
||||
|
@ -109,7 +117,8 @@ class BatchNorm(Expander):
|
|||
variance_res = graph_builder.emit(
|
||||
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
|
||||
attrs={'fake_output': True})
|
||||
|
||||
if input_x_new_type != input_x_ori_type:
|
||||
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
||||
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
||||
# infer mode
|
||||
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
||||
|
@ -128,4 +137,6 @@ class BatchNorm(Expander):
|
|||
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
|
||||
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
|
||||
res_y = graph_builder.emit('Add', [input_offset, x_div])
|
||||
if input_x_new_type != input_x_ori_type:
|
||||
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
||||
return res_y, var_add, var_add, var_add, var_add
|
||||
|
|
Loading…
Reference in New Issue