!15415 [GraphKernel]adapt for logsoftmax in ascend

From: @wenfangpei
Reviewed-by: @gaoxiong1,@ckey_dou,@gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou,@ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-07 11:19:58 +08:00 committed by Gitee
commit ed539597c2
3 changed files with 20 additions and 4 deletions

View File

@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis') @VLD.check_attrs('axis')
class LogSoftmax(Expander): class LogSoftmax(Expander):
"""LogSoftmax expander""" """LogSoftmax expander"""
@ -25,9 +25,17 @@ class LogSoftmax(Expander):
def _expand(self, graph_builder): def _expand(self, graph_builder):
input_x = self.inputs[0] input_x = self.inputs[0]
axis = self.attrs['axis'] axis = self.attrs['axis']
processor = self.processor
if isinstance(axis, int): if isinstance(axis, int):
axis = (axis,) axis = (axis,)
ori_dtype = input_x.dtype
if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else:
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x]) data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub]) data_exp = graph_builder.emit('Exp', [data_sub])

View File

@ -53,6 +53,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimBiasAddGrad, prim::kPrimBiasAddGrad,
prim::kPrimGeLU, prim::kPrimGeLU,
prim::kPrimSoftmax, prim::kPrimSoftmax,
prim::kPrimLogSoftmax,
prim::kPrimLogSoftmaxGrad,
prim::kPrimTile, prim::kPrimTile,
#if ENABLE_D #if ENABLE_D
prim::kPrimSqrtGrad, prim::kPrimSqrtGrad,

View File

@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_logsoftmaxgrad() test_logsoftmaxgrad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmax_asend(): def test_logsoftmax_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmax() test_logsoftmax()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad_asend(): def test_logsoftmaxgrad_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmaxgrad() test_logsoftmaxgrad()