!15458 [GraphKernel]expanders of some fusion ops

From: @wenfangpei
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-05-07 09:31:29 +08:00 committed by Gitee
commit 71ab230aa0
5 changed files with 116 additions and 0 deletions

View File

@ -49,3 +49,5 @@ from .tanh_grad import TanhGrad
from .tile import Tile
from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
from .lamb_apply_weight_assign import LambApplyWeightAssign
from .softmax_grad_ext import SoftmaxGradExt
from .square_sum_v1 import SquareSumV1

View File

@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SoftmaxGradExt"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class SoftmaxGradExt(Expander):
"""SoftmaxGradExt expander"""
def _expand(self, graph_builder):
x, y, z = self.inputs
axis = self.attrs['axis']
ori_shape = x.shape
if x.data_format == DF.FRAC_NZ:
ori_shape = infer_shape_from_fractalNz(ori_shape)
if not axis:
axis = []
for i, _ in enumerate(ori_shape):
axis.append(i)
else:
if isinstance(axis, int):
axis = [axis]
for i, _ in enumerate(list(axis)):
if axis[i] < 0:
axis[i] += len(ori_shape)
ori_reduced_shape = ori_shape
if x.data_format == DF.FRAC_NZ:
ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
axis = to_frac_z_axis(ori_shape, axis)
data_mul = graph_builder.emit('Mul', [x, y])
data_sum = graph_builder.emit('ReduceSum', [data_mul],
attrs={'reduce_axis': axis, 'keep_dims': True, 'reduce_output_fuse': True})
if x.data_format == DF.FRAC_NZ:
data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape})
data_sub = graph_builder.emit('Sub', [x, data_sum])
data_mul2 = graph_builder.emit('Mul', [data_sub, y])
result = graph_builder.emit('Mul', [data_mul2, z])
return result

View File

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for SquareSumV1"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
@VLD.add_format(DF.FRAC_NZ)
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class SquareSumV1(Expander):
"""Square expander"""
def _expand(self, graph_builder):
x = self.inputs[0]
axis = self.attrs['axis']
ori_shape = x.shape
if x.data_format == DF.FRAC_NZ:
ori_shape = infer_shape_from_fractalNz(ori_shape)
if not axis:
axis = []
for i, _ in enumerate(ori_shape):
axis.append(i)
else:
if isinstance(axis, int):
axis = [axis]
for i, _ in enumerate(list(axis)):
if axis[i] < 0:
axis[i] += len(ori_shape)
ori_reduced_shape = ori_shape
if x.data_format == DF.FRAC_NZ:
ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
axis = to_frac_z_axis(ori_shape, axis)
square_res = graph_builder.emit('Mul', [x, x])
result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': axis, 'keep_dims': True})
if x.data_format == DF.FRAC_NZ:
result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape})
return result

View File

@ -59,6 +59,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
prim::kPrimClipByNormNoDivSum,
prim::kLambApplyOptimizerAssign,
prim::kLambApplyWeightAssign,
prim::kSoftmaxGradExt,
prim::kSquareSumV1,
#elif ENABLE_GPU
prim::kPrimBiasAdd,
prim::kPrimFusedAdam,

View File

@ -314,6 +314,8 @@ inline const PrimitivePtr kPrimL2Normalize = std::make_shared<Primitive>("L2Norm
inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitive>("CustomExtractFeatures");
inline const PrimitivePtr kLambApplyOptimizerAssign = std::make_shared<Primitive>("LambApplyOptimizerAssign");
inline const PrimitivePtr kLambApplyWeightAssign = std::make_shared<Primitive>("LambApplyWeightAssign");
inline const PrimitivePtr kSoftmaxGradExt = std::make_shared<Primitive>("SoftmaxGradExt");
inline const PrimitivePtr kSquareSumV1 = std::make_shared<Primitive>("SquareSumV1");
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");