forked from mindspore-Ecosystem/mindspore
!15415 [GraphKernel]adapt for logsoftmax in ascend
From: @wenfangpei Reviewed-by: @gaoxiong1,@ckey_dou,@gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou,@ckey_dou
This commit is contained in:
commit
ed539597c2
|
@ -17,7 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.DEFAULT)
|
||||||
@VLD.check_attrs('axis')
|
@VLD.check_attrs('axis')
|
||||||
class LogSoftmax(Expander):
|
class LogSoftmax(Expander):
|
||||||
"""LogSoftmax expander"""
|
"""LogSoftmax expander"""
|
||||||
|
@ -25,9 +25,17 @@ class LogSoftmax(Expander):
|
||||||
def _expand(self, graph_builder):
|
def _expand(self, graph_builder):
|
||||||
input_x = self.inputs[0]
|
input_x = self.inputs[0]
|
||||||
axis = self.attrs['axis']
|
axis = self.attrs['axis']
|
||||||
|
processor = self.processor
|
||||||
|
|
||||||
if isinstance(axis, int):
|
if isinstance(axis, int):
|
||||||
axis = (axis,)
|
axis = (axis,)
|
||||||
|
|
||||||
|
ori_dtype = input_x.dtype
|
||||||
|
if ori_dtype != "float16" and processor == "aicore":
|
||||||
|
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
|
||||||
|
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||||
|
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
|
||||||
|
else:
|
||||||
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||||
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
||||||
data_exp = graph_builder.emit('Exp', [data_sub])
|
data_exp = graph_builder.emit('Exp', [data_sub])
|
||||||
|
|
|
@ -53,6 +53,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
|
||||||
prim::kPrimBiasAddGrad,
|
prim::kPrimBiasAddGrad,
|
||||||
prim::kPrimGeLU,
|
prim::kPrimGeLU,
|
||||||
prim::kPrimSoftmax,
|
prim::kPrimSoftmax,
|
||||||
|
prim::kPrimLogSoftmax,
|
||||||
|
prim::kPrimLogSoftmaxGrad,
|
||||||
prim::kPrimTile,
|
prim::kPrimTile,
|
||||||
#if ENABLE_D
|
#if ENABLE_D
|
||||||
prim::kPrimSqrtGrad,
|
prim::kPrimSqrtGrad,
|
||||||
|
|
|
@ -106,12 +106,18 @@ def test_logsoftmaxgrad_gpu():
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||||
test_logsoftmaxgrad()
|
test_logsoftmaxgrad()
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
def test_logsoftmax_asend():
|
def test_logsoftmax_asend():
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
||||||
test_logsoftmax()
|
test_logsoftmax()
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
def test_logsoftmaxgrad_asend():
|
def test_logsoftmaxgrad_asend():
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
|
||||||
test_logsoftmaxgrad()
|
test_logsoftmaxgrad()
|
||||||
|
|
Loading…
Reference in New Issue