commit 3b946d4eb2
!8678 expand logsoftmax and grad, delete cast in softmax and fix layernorm compute dsl
From: @zengzitao
Reviewed-by: @gaoxiong1, @ryanww
Signed-off-by: @ryanww
@@ -29,3 +29,5 @@ from .maximum_grad import expand_maximumgrad
 from .minimum_grad import expand_minimumgrad
 from .dropout_grad import expand_dropoutgrad
 from .layernorm_grad import expand_layernormgrad
+from .logsoftmax import expand_logsoftmax
+from .logsoftmax_grad import expand_logsoftmaxgrad
@@ -18,7 +18,6 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
-
 def expand_layernorm(expand_info):
     """LayerNorm expander"""

     # get op info.
     input_desc_0 = expand_info['input_desc'][0]
     input_desc_1 = expand_info['input_desc'][1]
@@ -70,11 +69,8 @@ def expand_layernorm(expand_info):
         normalize_sub = graph_builder.emit('Sub', [input_x, mean])
         epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
         normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
-        normalize_log = graph_builder.emit('Log', [normalize_add])
-        input_y = graph_builder.value(input_x.dtype, -0.5, input_x.data_format)
-        normalize_log_mul = graph_builder.emit('Mul', [normalize_log, input_y])
-        normalize_exp = graph_builder.emit('Exp', [normalize_log_mul])
-        normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_exp])
+        normalize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
+        normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_rsqrt])

         # Calculate scale and translate
         scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
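The layernorm fix above: the old DSL obtained 1/sqrt(variance + epsilon) indirectly, emitting Log, a Mul by a -0.5 constant, and Exp; the new DSL emits a single Rsqrt, which computes the same quantity since exp(-0.5 * log(v)) = v**-0.5. A minimal NumPy sketch of the equivalence (illustrative shapes and epsilon, not the expander DSL itself):

import numpy as np

x = np.random.rand(2, 3).astype(np.float32)   # hypothetical input block
mean = x.mean(axis=-1, keepdims=True)
variance = x.var(axis=-1, keepdims=True)
epsilon = 1e-7                                # assumed value; the real one comes from the op attrs

old = (x - mean) * np.exp(np.log(variance + epsilon) * -0.5)   # Log -> Mul(-0.5) -> Exp
new = (x - mean) * (1.0 / np.sqrt(variance + epsilon))         # single Rsqrt
assert np.allclose(old, new, rtol=1e-5)

Besides being the direct form, this drops two emitted ops and one constant from the fused graph.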
@@ -0,0 +1,49 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for LogSoftmax"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_logsoftmax(expand_info):
+    """LogSoftmax expander"""
+    # get op info.
+    input_desc = expand_info['input_desc'][0]
+    attrs = expand_info['attr']
+    axis = None
+    for item in attrs:
+        if 'axis' in item:
+            axis = item['axis']
+    graph_builder = builder.GraphBuilder()
+    if isinstance(axis, int):
+        axis = (axis,)
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
+        graph_scope.set_input(input_x)
+
+        # cal logsoftmax.
+        max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
+        data_sub = graph_builder.emit('Sub', [input_x, max_x])
+        data_exp = graph_builder.emit('Exp', [data_sub])
+        data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
+        log_expsum = graph_builder.emit('Log', [data_expsum])
+        result = graph_builder.emit('Sub', [data_sub, log_expsum])
+
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
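This new expander decomposes LogSoftmax into primitive ops via the numerically stable identity logsoftmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))): subtracting the row max first keeps Exp from overflowing without changing the result. A quick NumPy check of the identity (illustrative values only):

import numpy as np

x = np.array([[-0.08, -0.13, -0.47, -0.05]], dtype=np.float32)
m = x.max(axis=1, keepdims=True)
stable = (x - m) - np.log(np.exp(x - m).sum(axis=1, keepdims=True))   # the expander's form
naive = np.log(np.exp(x) / np.exp(x).sum(axis=1, keepdims=True))      # textbook form
assert np.allclose(stable, naive, rtol=1e-5)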
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for LogSoftmaxGrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+
+def expand_logsoftmaxgrad(expand_info):
+    """LogSoftmaxGrad expander"""
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    attrs = expand_info['attr']
+    axis = None
+    for item in attrs:
+        if 'axis' in item:
+            axis = item['axis']
+    graph_builder = builder.GraphBuilder()
+
+    if isinstance(axis, int):
+        axis = (axis,)
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_logits = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        graph_scope.set_input(input_logits, input_dy)
+
+        # cal logsoftmaxgrad.
+        softmax = graph_builder.emit('Exp', [input_logits])
+        dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
+        mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
+        result = graph_builder.emit('Sub', [input_dy, mul_result])
+
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
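This grad expander implements the analytic backward rule dx = dy - softmax(x) * sum(dy) over the reduce axis. Note that input_logits is evidently the forward output y (the log-probabilities), since softmax(x) is recovered as Exp(y) in the first emitted op. A NumPy sketch of the rule (illustrative values):

import numpy as np

x = np.array([[0.5, -1.0, 2.0]], dtype=np.float32)
m = x.max(axis=1, keepdims=True)
y = (x - m) - np.log(np.exp(x - m).sum(axis=1, keepdims=True))  # forward log-probabilities
dy = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)              # incoming gradient

# Exp -> ReduceSum -> Mul -> Sub, in the same order the expander emits them
dx = dy - np.exp(y) * dy.sum(axis=1, keepdims=True)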
@@ -18,7 +18,6 @@ from mindspore._extends.graph_kernel.model import model_builder as builder
-
 def expand_softmax(expand_info):
     """Softmax expander"""

     # get op info.
     input_desc = expand_info['input_desc'][0]
     attrs = expand_info['attr']
@@ -33,13 +32,7 @@ def expand_softmax(expand_info):
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
         # cal softmax.
-        if input_x.dtype == 'float32':
-            input_x_cast = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
-            max_x = graph_builder.emit('ReduceMax', [input_x_cast], attrs={'reduce_axis': axis, 'keep_dims': True})
-            max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': 'float32'})
-        else:
-            max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
-
+        max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
         data_sub = graph_builder.emit('Sub', [input_x, max_x])
         data_exp = graph_builder.emit('Exp', [data_sub])
         data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
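The softmax change drops the float16 round trip: the old DSL cast float32 inputs to float16 for ReduceMax and cast the result back, so the subtracted shift was rounded to fp16 resolution; the reduction now stays in the input dtype. An illustration of the rounding the old path introduced (values chosen to show the effect):

import numpy as np

x = np.array([1000.0001, 1000.0002], dtype=np.float32)
old_max = x.astype(np.float16).max().astype(np.float32)  # fp16 rounds both values to 1000.0
new_max = x.max()                                        # fp32 keeps ~1000.0002
print(old_max, new_max)

Softmax is shift-invariant, so the fp16 max still stabilized the Exp; the gain is two fewer Cast ops in the fused graph and an exact shift.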
@@ -0,0 +1,125 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+
+class LogSoftmax(nn.Cell):
+    def __init__(self, axis=1):
+        super(LogSoftmax, self).__init__()
+        self.logsoftmax = P.LogSoftmax(axis)
+
+    def construct(self, x):
+        return self.logsoftmax(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = C.GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    def construct(self, input_data, sens):
+        gout = self.grad(self.network)(input_data, sens)
+        return gout
+
+
+def test_logsoftmax():
+    x = np.array([[-0.08082921, -0.13706027, -0.4711177, -0.05606057],
+                  [-0.46082982, 1.1761844, -1.016654, -1.743829],
+                  [-1.5062045, 0.6910976, 0.4839723, 1.1502692]]).astype(np.float32)
+    expect = np.array([[-1.2939762, -1.3502073, -1.6842647, -1.2692076],
+                       [-1.9445671, -0.3075528, -2.5003912, -3.2275662],
+                       [-3.452001, -1.2546989, -1.4618242, -0.79552734]]).astype(np.float32)
+    logSoftmax = LogSoftmax()
+    output = logSoftmax(Tensor(x))
+    assert np.allclose(output.asnumpy(), expect)
+
+
+def test_logsoftmaxgrad():
+    x = np.array([[-0.47705367, 0.48267725, -1.0453935, 1.574488, 0.20362134, 0.4435456, -0.23984082, -0.43684655,
+                   -0.7725506, 1.4481013],
+                  [1.1012247, 1.7069651, 0.55062026, 0.3361901, -1.1082426, -0.5001939, -0.3255393, -0.7972024,
+                   -0.27965206, -0.702805],
+                  [0.19450496, 0.87596166, 0.6467245, -1.044987, 0.5248943, -2.6166635, 1.6719198, 0.06600758,
+                   -0.4099178, 1.1861311],
+                  [1.1305193, -1.97308, 2.1047623, -1.5105937, 0.93052036, 1.2467804, 0.5310002, 0.7084912, -1.3681422,
+                   -0.9686862],
+                  [1.871408, 0.14219497, -0.41050452, -0.749807, 1.4900619, -1.8172716, -0.73839617, 0.17565694,
+                   -0.4553867, -1.5423119]]).astype(np.float32)
+    dy = np.array([[1.516363, -0.15196544, 0.598733, 0.64357865, 0.16265012, -1.3521105, 0.22621834, 0.7168259,
+                    -0.6709239, 0.79757756],
+                   [-0.32457778, 1.2831115, 1.1211495, -0.02665559, 1.9170904, -1.3397789, 1.4124829, -1.4298155,
+                    0.758519, -0.25322974],
+                   [-0.24226122, -1.2555921, 0.6492511, -0.34847677, 0.19916506, 0.628554, -0.19658111, 0.44939864,
+                    -0.11677749, -1.2131723],
+                   [0.24267715, 0.28106326, 1.1075432, -0.29006946, 0.31335673, 0.8833154, 0.13152207, 1.5482179,
+                    0.29770762, -0.16246222],
+                   [0.02145994, 0.80424, -0.95061, 1.5875458, -0.00308682, 0.17964548, 0.49912593, 0.46977136,
+                    0.2151897, 0.30908248]]).astype(np.float32)
+    expect = np.array([[1.4219905, -0.39837134, 0.5452743, -0.09062839, -0.02375537, -1.5890603, 0.10658137, 0.6185817,
+                        -0.7411523, 0.15054005],
+                       [-0.94926417, 0.13830578, 0.7609547, -0.31733334, 1.8485254, -1.4657221, 1.2625053, -1.523396,
+                        0.601499, -0.35607445],
+                       [-0.14447737, -1.0622973, 0.80294746, -0.32016528, 0.33523226, 0.63443416, 0.23186903,
+                        0.53539133, -0.0633494, -0.9495847],
+                       [-0.36894822, 0.253609, -0.5127511, -0.33366728, -0.18740037, 0.19628316, -0.20430653, 1.1471655,
+                        0.24743511, -0.23741922],
+                       [-1.2582518, 0.57718843, -1.0812542, 1.4944922, -0.8770549, 0.1476463, 0.40500447, 0.23499368,
+                        0.09027944, 0.26695627]]).astype(np.float32)
+    net = LogSoftmax()
+    dx = Grad(net)(Tensor(x), Tensor(dy))
+    assert np.allclose(dx[0].asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_logsoftmax_gpu():
+    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+    test_logsoftmax()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_logsoftmaxgrad_gpu():
+    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
+    test_logsoftmaxgrad()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_logsoftmax_ascend():
+    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
+    test_logsoftmax()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_logsoftmaxgrad_ascend():
+    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
+    test_logsoftmaxgrad()
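These tests drive the expanders end to end: with enable_graph_kernel=True the framework replaces LogSoftmax and LogSoftmaxGrad with the expanded graphs above and checks the outputs against precomputed values. A sketch of an A/B check against the unfused kernels (assumes a GPU build and that each freshly constructed cell recompiles under the updated context; not part of the commit):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.logsoftmax = P.LogSoftmax(1)

    def construct(self, x):
        return self.logsoftmax(x)

x = Tensor(np.random.rand(3, 4).astype(np.float32))
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_graph_kernel=False)
base = Net()(x).asnumpy()     # built-in kernel
context.set_context(enable_graph_kernel=True)
fused = Net()(x).asnumpy()    # expanded and fused by graph kernel
assert np.allclose(base, fused, rtol=1e-5)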