forked from mindspore-Ecosystem/mindspore
declare fp32 and then cast to fp16 in expander
parent 1bd0ed450c
commit 80f071e9fa
akg
@@ -1 +1 @@
-Subproject commit ef1a6b06781035540023819afead4bdfbd49af81
+Subproject commit 6d01f5e364224da0d58b8b20761b9af67587950b
@@ -36,9 +36,6 @@ def expand_gelu(expand_info):
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
         graph_scope.set_input(input_x)
-        dtype = input_x.dtype
-        if dtype == 'float16':
-            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})

         # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
@@ -58,8 +55,6 @@ def expand_gelu(expand_info):
         mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
         result = graph_builder.emit('Mul', [const_half, mul_x])

-        if dtype == 'float16':
-            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
         graph_scope.set_output(result)

@@ -43,9 +43,6 @@ def expand_gelugrad(expand_info):
         input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
         input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
         graph_scope.set_input(input_dy, input_x, input_y)
-        dtype = input_dy.dtype
-        if dtype == 'float16':
-            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})

         # create some const var
         const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
@@ -83,8 +80,6 @@ def expand_gelugrad(expand_info):
         result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
         result = graph_builder.emit('Mul', [input_dy, result_tmp])

-        if dtype == 'float16':
-            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
         graph_scope.set_output(result)

@@ -149,7 +149,15 @@ class GraphBuilder:
         """Create a new Value"""
         if name in (None, ''):
             name = self._alloc_tensor_name()
-        return Value(name, dtype, value, data_format)
+
+        if dtype == "float16":
+            # For float16 value, it will be changed to float32 wrongly. And there is no good solution for now.
+            # So instead just declare float32 value and then cast it to float16.
+            v_fp32 = Value(name, "float32", value, data_format)
+            v = self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"})
+        else:
+            v = Value(name, dtype, value, data_format)
+        return v

     def op(self, prim, output, inputs, attrs=None):
         """Insert an operator into graph"""
@@ -166,9 +174,9 @@ class GraphBuilder:
         """Emit a new operation"""
         if attrs is None:
             attrs = {}
-        if isinstance(inputs, Tensor):
+        if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
-        tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
+        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
         out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
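Below is a minimal, self-contained sketch of the pattern the GraphBuilder.value hunk introduces, written against a toy builder rather than the real MindSpore API: when a float16 constant is requested, a float32 value is created first and an explicit Cast back to float16 is emitted, since a float16 value declared directly is wrongly turned into float32. The names ToyBuilder, ToyValue and ToyOp are hypothetical and only illustrate the idea; the emit() change above, which now also accepts Value inputs, is what lets the Cast consume the freshly created constant.

from dataclasses import dataclass
from typing import List


@dataclass
class ToyValue:
    """A scalar constant node in a toy graph."""
    name: str
    dtype: str
    data: float


@dataclass
class ToyOp:
    """An operator node in a toy graph."""
    prim: str
    inputs: List[object]
    attrs: dict
    output: str


class ToyBuilder:
    """Collects constants and ops, mimicking the declare-fp32-then-cast workaround."""

    def __init__(self):
        self.ops: List[ToyOp] = []
        self._count = 0

    def _new_name(self) -> str:
        self._count += 1
        return "t{}".format(self._count)

    def emit(self, prim: str, inputs: List[object], attrs: dict) -> ToyOp:
        op = ToyOp(prim, inputs, attrs, self._new_name())
        self.ops.append(op)
        return op

    def value(self, dtype: str, data: float) -> object:
        # Workaround sketch: a requested float16 constant is declared as float32
        # and then cast to float16 with an explicit Cast node, instead of being
        # created as float16 directly.
        if dtype == "float16":
            v_fp32 = ToyValue(self._new_name(), "float32", data)
            return self.emit("Cast", [v_fp32], attrs={"dst_type": "float16"})
        return ToyValue(self._new_name(), dtype, data)


if __name__ == "__main__":
    builder = ToyBuilder()
    const_half = builder.value("float16", 0.5)
    print(const_half)   # a Cast op whose input is the float32 constant
    print(builder.ops)  # the graph now carries the extra Cast node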