adapt for logsoftmax in ascend

This commit is contained in:
wenfangpei 2021-04-20 10:27:10 +08:00
parent 1827697642
commit db8256e61f
3 changed files with 20 additions and 4 deletions

View File

@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmax(Expander):
"""LogSoftmax expander"""
@ -25,10 +25,18 @@ class LogSoftmax(Expander):
def _expand(self, graph_builder):
input_x = self.inputs[0]
axis = self.attrs['axis']
processor = self.processor
if isinstance(axis, int):
axis = (axis,)
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
ori_dtype = input_x.dtype
if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else:
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})

View File

@ -53,6 +53,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimBiasAddGrad,
prim::kPrimGeLU,
prim::kPrimSoftmax,
prim::kPrimLogSoftmax,
prim::kPrimLogSoftmaxGrad,
prim::kPrimTile,
#if ENABLE_D
prim::kPrimSqrtGrad,

View File

@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
test_logsoftmaxgrad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmax_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmax()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_logsoftmaxgrad_asend():
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
test_logsoftmaxgrad()