forked from mindspore-Ecosystem/mindspore
!9522 Expand Tile and SqrtGrad in graph kernel
From: @looop5 Reviewed-by: @ckey_dou,@gaoxiong1 Signed-off-by: @gaoxiong1
This commit is contained in:
commit
d616ce1a2a
|
@ -32,3 +32,5 @@ from .layernorm_grad import expand_layernormgrad
|
|||
from .logsoftmax import expand_logsoftmax
|
||||
from .logsoftmax_grad import expand_logsoftmaxgrad
|
||||
from .gkdropout import expand_gkdropout
|
||||
from .tile import expand_tile
|
||||
from .sqrt_grad import expand_sqrtgrad
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for sqrtgrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_sqrtgrad(expand_info):
    """SqrtGrad expander"""
    # Backward of Sqrt: sqrt_grad(x, dout) = dout / (2 * x)
    # (here x is the forward output y = sqrt(input), per SqrtGrad's contract).
    x_desc = expand_info['input_desc'][0]
    dout_desc = expand_info['input_desc'][1]
    g_builder = builder.GraphBuilder()

    # Build the replacement graph for the fused kernel.
    with g_builder.graph_scope('main') as scope:
        x = g_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
        dout = g_builder.tensor(dout_desc['shape'], dout_desc['data_type'], dout_desc['format'])
        scope.set_input(x, dout)

        # dout / (2 * x), expressed with the primitive ops Mul and RealDiv.
        two = g_builder.value(x.dtype, 2, x.data_format)
        doubled_x = g_builder.emit('Mul', [x, two])
        grad = g_builder.emit('RealDiv', [dout, doubled_x])

        scope.set_output(grad)

    return g_builder.get()[0]
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for Tile"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def _get_tile_output_shape(shape, multiples):
|
||||
"""compute output shape of tile"""
|
||||
|
||||
if multiples is None:
|
||||
return shape
|
||||
if not isinstance(shape, (list, tuple)):
|
||||
raise TypeError("Input shape of Tile must be of type list or tuple")
|
||||
if not isinstance(multiples, (list, tuple)):
|
||||
raise TypeError("multiples of Tile must be of type list or tuple")
|
||||
|
||||
shape = list(shape)
|
||||
multiples = list(multiples)
|
||||
diff_len = len(multiples) - len(shape)
|
||||
if diff_len < 0:
|
||||
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
|
||||
if diff_len > 0:
|
||||
for _ in range(diff_len):
|
||||
shape.insert(0, 1)
|
||||
|
||||
shape_compatible = True
|
||||
output_shape = []
|
||||
input_reshape = []
|
||||
output_reshape = []
|
||||
for sh, mul in list(zip(shape, multiples)):
|
||||
dim = sh * mul
|
||||
output_shape.append(dim)
|
||||
if sh == 1 or mul == 1:
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(dim)
|
||||
else:
|
||||
shape_compatible = False
|
||||
input_reshape.append(1)
|
||||
input_reshape.append(sh)
|
||||
output_reshape.append(mul)
|
||||
output_reshape.append(sh)
|
||||
|
||||
return output_shape, input_reshape, output_reshape, shape_compatible
|
||||
|
||||
|
||||
def expand_tile(expand_info):
    """Tile expander"""
    # Extract op info from the json description.
    desc = expand_info['input_desc'][0]
    multiples = None
    for attr in expand_info['attr']:
        if 'multiples' in attr:
            multiples = attr['multiples']
    shapes = _get_tile_output_shape(desc['shape'], multiples)
    out_shape, in_reshape, out_reshape, compatible = shapes
    gb = builder.GraphBuilder()

    # Build the replacement graph for the fused kernel.
    with gb.graph_scope('main') as scope:
        tensor_in = gb.tensor(desc['shape'], desc['data_type'], desc['format'])
        if compatible:
            # Every tiled dim is broadcastable: one BroadcastTo suffices.
            result = gb.emit('BroadcastTo', [tensor_in], attrs={'shape': out_shape})
        else:
            # Split incompatible dims, broadcast, then merge back.
            reshaped = gb.emit('Reshape', [tensor_in], attrs={'shape': in_reshape})
            broadcast = gb.emit('BroadcastTo', [reshaped], attrs={'shape': out_reshape})
            result = gb.emit('Reshape', [broadcast], attrs={'shape': out_shape})
        scope.set_output(result)

    return gb.get()[0]
|
|
@ -31,7 +31,8 @@ class GraphSplitByPattern:
|
|||
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
|
||||
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
|
||||
self.mode = self.MODE_BASIC
|
||||
if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
|
||||
if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
|
||||
(use_poly_reduce and self.pattern == PrimLib.REDUCE):
|
||||
self.mode = self.MODE_COMPOSITE
|
||||
self.is_output = is_output
|
||||
self.output_excluded = set()
|
||||
|
|
|
@ -170,6 +170,7 @@ class PrimLib:
|
|||
'FlattenGrad': Prim(RESHAPE),
|
||||
'Transpose': Prim(TRANSFORM),
|
||||
'Tile': Prim(BROADCAST),
|
||||
'BroadcastTo': Prim(BROADCAST),
|
||||
}
|
||||
|
||||
default_primtive = Prim(UNKNOWN)
|
||||
|
|
|
@ -73,6 +73,7 @@ class OpInfer:
|
|||
# add special infer func here
|
||||
'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
|
||||
'Reshape': lambda inputs, attrs: attrs["shape"],
|
||||
'BroadcastTo': lambda inputs, attrs: attrs["shape"],
|
||||
}
|
||||
infer_dtype_func = {
|
||||
# add special infer func here
|
||||
|
|
|
@ -248,6 +248,7 @@ class CNodeDecoder {
|
|||
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||
{kBroadcastToOpName, std::vector<std::string>{kAttrShape}},
|
||||
};
|
||||
|
||||
PrimitivePtr GetPrimitive(const std::string &op_name) {
|
||||
|
|
|
@ -701,11 +701,14 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
|
|||
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||
std::unordered_set<PrimitivePtr> expand_ops = {
|
||||
prim::kPrimSquare,
|
||||
#if ENABLE_GPU
|
||||
prim::kPrimGeluGrad,
|
||||
#if ENABLE_D
|
||||
prim::kPrimTile,
|
||||
prim::kPrimSqrtGrad,
|
||||
#elif ENABLE_GPU
|
||||
prim::kPrimBiasAdd,
|
||||
prim::kPrimBiasAddGrad,
|
||||
prim::kPrimGelu,
|
||||
prim::kPrimGeluGrad,
|
||||
prim::kPrimFusedAdam,
|
||||
prim::kPrimFusedAdamWeightDecay,
|
||||
prim::kPrimTanhGrad,
|
||||
|
|
Loading…
Reference in New Issue