!9522 Expand Tile and SqrtGrad in graph kernel

From: @looop5
Reviewed-by: @ckey_dou,@gaoxiong1
Signed-off-by: @gaoxiong1
mindspore-ci-bot 2020-12-07 10:55:26 +08:00 committed by Gitee
commit d616ce1a2a
8 changed files with 144 additions and 3 deletions


@@ -32,3 +32,5 @@ from .layernorm_grad import expand_layernormgrad
 from .logsoftmax import expand_logsoftmax
 from .logsoftmax_grad import expand_logsoftmaxgrad
 from .gkdropout import expand_gkdropout
+from .tile import expand_tile
+from .sqrt_grad import expand_sqrtgrad


@@ -0,0 +1,45 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for sqrtgrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_sqrtgrad(expand_info):
"""SqrtGrad expander"""
# cal formula are:
# sqrt_grad(x, dout) = dout / (2 * x)
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
graph_scope.set_input(input_x, input_dout)
# cal result
const_two = graph_builder.value(input_x.dtype, 2, input_x.data_format)
dividend = graph_builder.emit('Mul', [input_x, const_two])
result = graph_builder.emit('RealDiv', [input_dout, dividend])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph
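
The expanded graph computes dout / (2 * x). Assuming, as the *Grad kernel convention suggests, that the first input is the forward output y = sqrt(x), this is exactly the chain-rule gradient of sqrt. A minimal numpy sanity check (numpy is used here only for illustration; it is not part of the change):

import numpy as np

x = np.array([1.0, 4.0, 9.0])
y = np.sqrt(x)                                    # forward output fed to SqrtGrad
dout = np.array([1.0, 1.0, 1.0])                  # incoming gradient
expanded = dout / (2 * y)                         # what the expanded graph computes
eps = 1e-6
numeric = (np.sqrt(x + eps) - np.sqrt(x)) / eps   # finite-difference d(sqrt)/dx
assert np.allclose(expanded, numeric, atol=1e-4)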


@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Tile"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def _get_tile_output_shape(shape, multiples):
"""compute output shape of tile"""
if multiples is None:
return shape
if not isinstance(shape, (list, tuple)):
raise TypeError("Input shape of Tile must be of type list or tuple")
if not isinstance(multiples, (list, tuple)):
raise TypeError("multiples of Tile must be of type list or tuple")
shape = list(shape)
multiples = list(multiples)
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)
shape_compatible = True
output_shape = []
input_reshape = []
output_reshape = []
for sh, mul in list(zip(shape, multiples)):
dim = sh * mul
output_shape.append(dim)
if sh == 1 or mul == 1:
input_reshape.append(sh)
output_reshape.append(dim)
else:
shape_compatible = False
input_reshape.append(1)
input_reshape.append(sh)
output_reshape.append(mul)
output_reshape.append(sh)
return output_shape, input_reshape, output_reshape, shape_compatible
def expand_tile(expand_info):
"""Tile expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
attrs = expand_info['attr']
multiples = None
for item in attrs:
if 'multiples' in item:
multiples = item['multiples']
output_shape, input_reshape, output_reshape, shape_compatible = _get_tile_output_shape(input_desc['shape'],
multiples)
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# create op.
if shape_compatible:
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
else:
input_x_reshape = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_reshape})
reshape_broadcast = graph_builder.emit('BroadcastTo', [input_x_reshape], attrs={'shape': output_reshape})
result = graph_builder.emit('Reshape', [reshape_broadcast], attrs={'shape': output_shape})
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph
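
When an axis has both sh > 1 and mul > 1, Tile is not expressible as a single BroadcastTo, hence the reshape path above. A quick numpy check of that decomposition (numpy only for illustration): for shape [2, 3] and multiples (2, 2), _get_tile_output_shape yields output_shape [4, 6], input_reshape [1, 2, 1, 3] and output_reshape [2, 2, 2, 3]:

import numpy as np

x = np.arange(6).reshape(2, 3)
direct = np.tile(x, (2, 2))                  # reference Tile semantics
via_broadcast = np.broadcast_to(
    x.reshape(1, 2, 1, 3), (2, 2, 2, 3)).reshape(4, 6)
assert (direct == via_broadcast).all()       # the decomposition reproduces Tile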


@@ -31,7 +31,8 @@ class GraphSplitByPattern:
         self.in_relations = dict()  # {area1: relation1, area2: relation2, ...}
         self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
         self.mode = self.MODE_BASIC
-        if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
+        if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
+                (use_poly_reduce and self.pattern == PrimLib.REDUCE):
             self.mode = self.MODE_COMPOSITE
         self.is_output = is_output
         self.output_excluded = set()


@@ -170,6 +170,7 @@ class PrimLib:
         'FlattenGrad': Prim(RESHAPE),
         'Transpose': Prim(TRANSFORM),
         'Tile': Prim(BROADCAST),
+        'BroadcastTo': Prim(BROADCAST),
     }
     default_primtive = Prim(UNKNOWN)


@@ -73,6 +73,7 @@ class OpInfer:
         # add special infer func here
         'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
         'Reshape': lambda inputs, attrs: attrs["shape"],
+        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
     }
     infer_dtype_func = {
         # add special infer func here
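
The added entry makes BroadcastTo's static output shape come directly from its 'shape' attribute instead of being derived from the input tensor. A toy illustration of that table lookup (all names below are illustrative stand-ins, not MindSpore internals):

infer_shape_func = {
    'Reshape': lambda inputs, attrs: attrs["shape"],
    'BroadcastTo': lambda inputs, attrs: attrs["shape"],
}

class _Tensor:  # hypothetical minimal tensor carrying only a shape
    def __init__(self, shape):
        self.shape = shape

out = infer_shape_func['BroadcastTo']([_Tensor([1, 2, 1, 3])], {"shape": [2, 2, 2, 3]})
assert out == [2, 2, 2, 3]  # shape is read from the attr, not from the input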


@@ -248,6 +248,7 @@ class CNodeDecoder {
   {kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
   {kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
   {kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
+  {kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
 };
 PrimitivePtr GetPrimitive(const std::string &op_name) {


@@ -701,11 +701,14 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
-#if ENABLE_GPU
-    prim::kPrimGeluGrad,
+#if ENABLE_D
+    prim::kPrimTile,
+    prim::kPrimSqrtGrad,
+#elif ENABLE_GPU
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
     prim::kPrimGelu,
+    prim::kPrimGeluGrad,
     prim::kPrimFusedAdam,
     prim::kPrimFusedAdamWeightDecay,
     prim::kPrimTanhGrad,