forked from mindspore-Ecosystem/mindspore
!15415 [GraphKernel]adapt for logsoftmax in ascend
From: @wenfangpei Reviewed-by: @gaoxiong1,@ckey_dou,@gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou,@ckey_dou
This commit is contained in:
commit
ed539597c2
|
@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.add_format(DF.DEFAULT)
|
||||
@VLD.check_attrs('axis')
|
||||
class LogSoftmax(Expander):
|
||||
"""LogSoftmax expander"""
|
||||
|
@ -25,10 +25,18 @@ class LogSoftmax(Expander):
|
|||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
axis = self.attrs['axis']
|
||||
processor = self.processor
|
||||
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
|
||||
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
ori_dtype = input_x.dtype
|
||||
if ori_dtype != "float16" and processor == "aicore":
|
||||
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
|
||||
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
|
||||
else:
|
||||
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
||||
data_exp = graph_builder.emit('Exp', [data_sub])
|
||||
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
|
|
|
@ -53,6 +53,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
|||
prim::kPrimBiasAddGrad,
|
||||
prim::kPrimGeLU,
|
||||
prim::kPrimSoftmax,
|
||||
prim::kPrimLogSoftmax,
|
||||
prim::kPrimLogSoftmaxGrad,
|
||||
prim::kPrimTile,
|
||||
#if ENABLE_D
|
||||
prim::kPrimSqrtGrad,
|
||||
|
|
|
@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu():
|
|||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||
test_logsoftmaxgrad()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_logsoftmax_asend():
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
||||
test_logsoftmax()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_logsoftmaxgrad_asend():
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
||||
test_logsoftmaxgrad()
|
||||
|
|
Loading…
Reference in New Issue