forked from mindspore-Ecosystem/mindspore
!12387 【GraphKernel】Refactor GraphKernelExpander (2nd submission)
From: @dayschan Reviewed-by: Signed-off-by:
This commit is contained in:
commit
ddea9d52eb
|
@ -18,6 +18,16 @@ import json.decoder as jd
|
|||
import traceback
|
||||
from mindspore import log as logger
|
||||
import mindspore._extends.graph_kernel.expanders as expanders
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
|
||||
|
||||
|
||||
def create_expander(expand_info):
|
||||
"""Create an expander according to op name"""
|
||||
op_name = str(expand_info['name'])
|
||||
if not hasattr(expanders, op_name):
|
||||
raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name))
|
||||
expander = getattr(expanders, op_name)
|
||||
return expander(expand_info)
|
||||
|
||||
|
||||
def extract_expand_info(kernel_info):
|
||||
|
@ -46,20 +56,8 @@ def get_op_expander(json_str: str):
|
|||
kernel_info = json.loads(json_str)
|
||||
expand_info = extract_expand_info(kernel_info)
|
||||
|
||||
processor = expand_info['process']
|
||||
op_name = str(expand_info['name']).lower()
|
||||
expand_op_func_name = 'expand_' + op_name
|
||||
if not hasattr(expanders, expand_op_func_name):
|
||||
logger.error("Generator do not support op: {}".format(op_name))
|
||||
return None
|
||||
expand_op_func = getattr(expanders, expand_op_func_name)
|
||||
# generate graph desc.
|
||||
graph = expand_op_func(expand_info)
|
||||
if graph is None:
|
||||
logger.error("Failed to generate graph of: {}".format(op_name))
|
||||
return None
|
||||
|
||||
graph.set_processor(processor)
|
||||
expander = create_expander(expand_info)
|
||||
graph = expander.run()
|
||||
|
||||
# dump graph to json desc.
|
||||
desc = graph.dump()
|
||||
|
@ -69,3 +67,6 @@ def get_op_expander(json_str: str):
|
|||
logger.error("Failed to generate graph kernel op")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
except GraphKernelUnsupportedException as e:
|
||||
logger.info(e.message)
|
||||
return ""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,24 +14,24 @@
|
|||
# ============================================================================
|
||||
"""expanders init"""
|
||||
|
||||
from .gelu import expand_gelu
|
||||
from .gelu_grad import expand_gelugrad
|
||||
from .layernorm import expand_layernorm
|
||||
from .softmax import expand_softmax
|
||||
from .square import expand_square
|
||||
from .bias_add import expand_biasadd
|
||||
from .bias_add_grad import expand_biasaddgrad
|
||||
from .fused_adam import expand_fusedadam
|
||||
from .fused_adam_weight_decay import expand_fusedadamweightdecay
|
||||
from .reduce_mean import expand_reducemean
|
||||
from .tanh_grad import expand_tanhgrad
|
||||
from .maximum_grad import expand_maximumgrad
|
||||
from .minimum_grad import expand_minimumgrad
|
||||
from .dropout_grad import expand_dropoutgrad
|
||||
from .layernorm_grad import expand_layernormgrad
|
||||
from .logsoftmax import expand_logsoftmax
|
||||
from .logsoftmax_grad import expand_logsoftmaxgrad
|
||||
from .gkdropout import expand_gkdropout
|
||||
from .tile import expand_tile
|
||||
from .sqrt_grad import expand_sqrtgrad
|
||||
from .clip_by_norm_no_div_sum import expand_clipbynormnodivsum
|
||||
from .bias_add import BiasAdd
|
||||
from .bias_add_grad import BiasAddGrad
|
||||
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
||||
from .dropout_grad import DropoutGrad
|
||||
from .fused_adam import FusedAdam
|
||||
from .fused_adam_weight_decay import FusedAdamWeightDecay
|
||||
from .gelu import GeLU
|
||||
from .gelu_grad import GeLUGrad
|
||||
from .gkdropout import GkDropout
|
||||
from .layernorm import LayerNorm
|
||||
from .layernorm_grad import LayerNormGrad
|
||||
from .logsoftmax import LogSoftmax
|
||||
from .logsoftmax_grad import LogSoftmaxGrad
|
||||
from .maximum_grad import MaximumGrad
|
||||
from .minimum_grad import MinimumGrad
|
||||
from .reduce_mean import ReduceMean
|
||||
from .softmax import Softmax
|
||||
from .sqrt_grad import SqrtGrad
|
||||
from .square import Square
|
||||
from .tanh_grad import TanhGrad
|
||||
from .tile import Tile
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""GraphKernel expander utils"""
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
|
||||
|
||||
class Expander:
|
||||
"""
|
||||
Expander is the base class of expanders.
|
||||
|
||||
The method `_expand` should be overridden to implement the operator detail.
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, expand_info):
|
||||
self.name = expand_info["name"]
|
||||
self.inputs = expand_info["input_desc"]
|
||||
self.outputs = expand_info["output_desc"]
|
||||
self.attrs = expand_info["attr"]
|
||||
self.processor = expand_info["process"]
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Expand the operator to a graph.
|
||||
|
||||
`GraphKernelUnsupportedException` would be raised if check failed.
|
||||
"""
|
||||
self._check()
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope(self.name) as graph_scope:
|
||||
# transform input_desc to Tensor
|
||||
self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
|
||||
graph_scope.set_input(*self.inputs)
|
||||
outputs = self._expand(graph_builder)
|
||||
if isinstance(outputs, (list, tuple)):
|
||||
graph_scope.set_output(*outputs)
|
||||
else:
|
||||
graph_scope.set_output(outputs)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
graph.set_processor(self.processor)
|
||||
return graph
|
||||
|
||||
def _check(self):
|
||||
"""Check inputs"""
|
||||
|
||||
@abstractmethod
|
||||
def _expand(self, graph_builder):
|
||||
"""Expand operator, this function should be overridden in subclass"""
|
||||
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
|
||||
|
||||
|
||||
class ExpanderInfoValidator:
|
||||
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
|
||||
# pylint: disable=W0211
|
||||
@staticmethod
|
||||
def _add_check_function(cls, func):
|
||||
"""
|
||||
Rewrite the function `_check` in class Expander
|
||||
to append the new `func` after the original checks.
|
||||
"""
|
||||
old_check = getattr(cls, "_check")
|
||||
|
||||
def new_check(obj):
|
||||
old_check(obj)
|
||||
func(obj)
|
||||
setattr(cls, "_check", new_check)
|
||||
|
||||
@staticmethod
|
||||
def add_format(*input_format):
|
||||
"""
|
||||
Add new supported format for the operator
|
||||
|
||||
this function will add a list `__supported_formats` into the expander,
|
||||
saving the whitelist of formats that this op supports.
|
||||
it also rewrites the `_check` function to check the formats.
|
||||
"""
|
||||
format_list_name = "__supported_formats"
|
||||
|
||||
def _check_format(obj):
|
||||
inp_formats = [inp['format'] for inp in obj.inputs]
|
||||
for formats in getattr(obj, format_list_name):
|
||||
if len(formats) != len(inp_formats):
|
||||
raise GKException("length of registered format doesn't match with the input of {}".format(obj.name))
|
||||
if all([fmt == inp for fmt, inp in zip(formats, inp_formats)]):
|
||||
return
|
||||
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
|
||||
|
||||
def wrapper(cls):
|
||||
if not issubclass(cls, Expander):
|
||||
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
||||
if not hasattr(cls, format_list_name):
|
||||
setattr(cls, format_list_name, list())
|
||||
ExpanderInfoValidator._add_check_function(cls, _check_format)
|
||||
getattr(cls, format_list_name).append(input_format)
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def check_all_formats_same(cls):
|
||||
"""Check that all formats are the same"""
|
||||
def _check_format(obj):
|
||||
inp_formats = [inp['format'] for inp in obj.inputs]
|
||||
if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
|
||||
return
|
||||
raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
|
||||
','.join(inp_formats), obj.name))
|
||||
|
||||
def wrapper(*args, **kargs):
|
||||
if not issubclass(cls, Expander):
|
||||
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
||||
ExpanderInfoValidator._add_check_function(cls, _check_format)
|
||||
return cls(*args, **kargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def check_attrs(*args):
|
||||
"""Check the attrs exist"""
|
||||
def _check_attr(obj):
|
||||
for a in args:
|
||||
if a not in obj.attrs:
|
||||
raise GKException("attr '{}' does not exist.".format(a))
|
||||
|
||||
def wrapper(cls):
|
||||
if not issubclass(cls, Expander):
|
||||
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
||||
ExpanderInfoValidator._add_check_function(cls, _check_attr)
|
||||
return cls
|
||||
|
||||
return wrapper
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,50 +13,34 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for bias_add"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_biasadd(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.add_format(DF.NCHW, DF.DEFAULT)
|
||||
@VLD.add_format(DF.NHWC, DF.DEFAULT)
|
||||
class BiasAdd(Expander):
|
||||
"""BiasAdd expander"""
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(
|
||||
input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_y = graph_builder.tensor(
|
||||
input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
graph_scope.set_input(input_x, input_y)
|
||||
if input_x.data_format == "NCHW":
|
||||
input_y_expand = graph_builder.emit(
|
||||
'ExpandDims', [input_y], attrs={'axis': 1})
|
||||
input_y_expand = graph_builder.emit(
|
||||
'ExpandDims', [input_y_expand], attrs={'axis': 2})
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_y = self.inputs
|
||||
|
||||
if input_x.data_format == DF.NCHW:
|
||||
input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
|
||||
input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
elif input_x.data_format == "DefaultFormat":
|
||||
elif input_x.data_format == DF.DEFAULT:
|
||||
if len(input_x.shape) == 2:
|
||||
result = graph_builder.emit('Add', [input_x, input_y])
|
||||
elif len(input_x.shape) == 3:
|
||||
input_y_expand = graph_builder.emit(
|
||||
'ExpandDims', [input_y], attrs={'axis': 1})
|
||||
result = graph_builder.emit(
|
||||
'Add', [input_x, input_y_expand])
|
||||
else:
|
||||
input_y_expand = graph_builder.emit(
|
||||
'ExpandDims', [input_y], attrs={'axis': 1})
|
||||
input_y_expand = graph_builder.emit(
|
||||
'ExpandDims', [input_y_expand], attrs={'axis': 2})
|
||||
result = graph_builder.emit(
|
||||
'Add', [input_x, input_y_expand])
|
||||
else:
|
||||
input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
else: # len == 4
|
||||
input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
|
||||
input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
|
||||
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||
else: # NHWC
|
||||
result = graph_builder.emit('Add', [input_x, input_y])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,26 +13,25 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for bias_add"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_biasaddgrad(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT)
|
||||
@VLD.add_format(DF.NHWC)
|
||||
@VLD.add_format(DF.NCHW)
|
||||
class BiasAddGrad(Expander):
|
||||
"""BiasAddGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(
|
||||
input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
graph_scope.set_input(input_x)
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
|
||||
reduce_axis = ()
|
||||
if input_x.data_format == 'NHWC':
|
||||
reduce_axis = (0, 1, 2)
|
||||
elif input_x.data_format == 'NCHW':
|
||||
reduce_axis = (0, 2, 3)
|
||||
# Default format shape's length maybe equal 2 to 4, so different shape's length reduce axis are differnet
|
||||
# DefaultFormat shape's length should be from 2 to 4
|
||||
else:
|
||||
if len(input_x.shape) == 2:
|
||||
reduce_axis = (0,)
|
||||
|
@ -41,8 +40,4 @@ def expand_biasaddgrad(expand_info):
|
|||
else:
|
||||
reduce_axis = (0, 2, 3)
|
||||
result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,27 +13,15 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for ClipByNormNoDivSum"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_clipbynormnodivsum(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class ClipByNormNoDivSum(Expander):
|
||||
"""ClipByNormNoDivSum expander"""
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
input_desc_3 = expand_info['input_desc'][3]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x0 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_x1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_x2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
input_x3 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
|
||||
graph_scope.set_input(input_x0, input_x1, input_x2, input_x3)
|
||||
def _expand(self, graph_builder):
|
||||
input_x0, input_x1, input_x2, input_x3 = self.inputs
|
||||
|
||||
# cal result
|
||||
greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'})
|
||||
|
@ -44,8 +32,4 @@ def expand_clipbynormnodivsum(expand_info):
|
|||
attrs={'fusion': 'SelectGT_000_end'})
|
||||
result = graph_builder.emit('Maximum', [select_res1, input_x3])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,27 +13,18 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for DropoutGrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_dropoutgrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
@VLD.check_attrs('keep_prob')
|
||||
class DropoutGrad(Expander):
|
||||
"""DropoutGrad expander"""
|
||||
# get op info.
|
||||
dy_desc = expand_info['input_desc'][0]
|
||||
mask_desc = expand_info['input_desc'][1]
|
||||
keep_prob = expand_info['attr']['keep_prob']
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
|
||||
input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
|
||||
graph_scope.set_input(input_dy, input_mask)
|
||||
def _expand(self, graph_builder):
|
||||
input_dy, input_mask = self.inputs
|
||||
keep_prob = self.attrs['keep_prob']
|
||||
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
|
||||
# create op.
|
||||
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
|
||||
result = graph_builder.emit('Mul', [result, input_mask])
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,40 +13,16 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for fused_adam"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_fusedadam(expand_info):
|
||||
"""FusedAdma expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
input_desc_3 = expand_info['input_desc'][3]
|
||||
input_desc_4 = expand_info['input_desc'][4]
|
||||
input_desc_5 = expand_info['input_desc'][5]
|
||||
input_desc_6 = expand_info['input_desc'][6]
|
||||
input_desc_7 = expand_info['input_desc'][7]
|
||||
input_desc_8 = expand_info['input_desc'][8]
|
||||
input_desc_9 = expand_info['input_desc'][9]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
@VLD.check_all_formats_same
|
||||
class FusedAdam(Expander):
|
||||
"""FusedAdam expander"""
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
|
||||
eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
|
||||
lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
|
||||
param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
|
||||
m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
|
||||
v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
|
||||
gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
|
||||
graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)
|
||||
def _expand(self, graph_builder):
|
||||
beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
|
||||
|
||||
# compute result
|
||||
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
|
||||
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
|
||||
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
|
||||
|
@ -60,12 +36,9 @@ def expand_fusedadam(expand_info):
|
|||
update_with_lr = graph_builder.emit('Mul', [lr, update])
|
||||
next_para = graph_builder.emit('Sub', [param, update_with_lr])
|
||||
|
||||
param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
param_result = graph_builder.emit(
|
||||
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
|
||||
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(param_result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return param_result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,41 +13,15 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for fused_adam_weight_decay"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_fusedadamweightdecay(expand_info):
|
||||
"""FusedAdmaWeightDecay expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
input_desc_3 = expand_info['input_desc'][3]
|
||||
input_desc_4 = expand_info['input_desc'][4]
|
||||
input_desc_5 = expand_info['input_desc'][5]
|
||||
input_desc_6 = expand_info['input_desc'][6]
|
||||
input_desc_7 = expand_info['input_desc'][7]
|
||||
input_desc_8 = expand_info['input_desc'][8]
|
||||
input_desc_9 = expand_info['input_desc'][9]
|
||||
input_desc_10 = expand_info['input_desc'][10]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
@VLD.check_all_formats_same
|
||||
class FusedAdamWeightDecay(Expander):
|
||||
"""FusedAdamWeightDecay expander"""
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
|
||||
eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
|
||||
lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
|
||||
param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
|
||||
m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
|
||||
v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
|
||||
gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
|
||||
weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
|
||||
graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
|
||||
eps, lr, param, m, v, gradient, weight_decay)
|
||||
def _expand(self, graph_builder):
|
||||
beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs
|
||||
|
||||
# compute result
|
||||
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
|
||||
|
@ -65,12 +39,9 @@ def expand_fusedadamweightdecay(expand_info):
|
|||
update_with_lr = graph_builder.emit('Mul', [lr, update])
|
||||
next_para = graph_builder.emit('Sub', [param, update_with_lr])
|
||||
|
||||
para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
para_result = graph_builder.emit(
|
||||
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
|
||||
para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(para_result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return para_result
|
||||
|
|
|
@ -13,49 +13,36 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for gelu"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
CSVALUE = 0.044715
|
||||
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||
ONE = 1.0
|
||||
HALF = 0.5
|
||||
from ._utils import Expander
|
||||
|
||||
|
||||
def expand_gelu(expand_info):
|
||||
class GeLU(Expander):
|
||||
"""GeLU expander"""
|
||||
# cal formula are:
|
||||
# gelu(x) is 0.5 * x * (1.0 + tanh(y))
|
||||
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||
CSVALUE = 0.044715
|
||||
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
def _expand(self, graph_builder):
|
||||
# cal formula are:
|
||||
# gelu(x) is 0.5 * x * (1.0 + tanh(y))
|
||||
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
graph_scope.set_input(input_x)
|
||||
input_x = self.inputs[0]
|
||||
|
||||
# cal y
|
||||
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
|
||||
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
|
||||
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE)
|
||||
const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
|
||||
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
|
||||
tanh_res = graph_builder.emit('Add', [input_x, mul_1])
|
||||
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI)
|
||||
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
|
||||
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
|
||||
|
||||
# cal gelu(x)
|
||||
tanh_y = graph_builder.emit('Tanh', [y])
|
||||
const_one = graph_builder.value(tanh_y.dtype, ONE)
|
||||
const_half = graph_builder.value(tanh_y.dtype, HALF)
|
||||
const_one = graph_builder.value(tanh_y.dtype, 1)
|
||||
const_half = graph_builder.value(tanh_y.dtype, 0.5)
|
||||
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
|
||||
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
|
||||
result = graph_builder.emit('Mul', [const_half, mul_x])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,43 +13,31 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for gelugrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
CSVALUE = 0.044715
|
||||
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||
CSVALUE_TRI = 0.134141 # CSVALUE * 3
|
||||
ONE = 1.0
|
||||
HALF = 0.5
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_gelugrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class GeLUGrad(Expander):
|
||||
"""GeLUGrad expander"""
|
||||
# cal formula are:
|
||||
# gelu_grad(dy, x) is dy * y'
|
||||
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
||||
# tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||
# mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
|
||||
CSVALUE = 0.044715
|
||||
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||
CSVALUE_TRI = 0.134141 # CSVALUE * 3
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
def _expand(self, graph_builder):
|
||||
# cal formula are:
|
||||
# gelu_grad(dy, x) is dy * y'
|
||||
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
||||
# tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||
# mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_dy, input_x, input_y)
|
||||
input_dy, input_x, _ = self.inputs
|
||||
|
||||
# create some const var
|
||||
const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE)
|
||||
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI)
|
||||
const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI)
|
||||
const_one = graph_builder.value(input_dy.dtype, ONE)
|
||||
const_half = graph_builder.value(input_dy.dtype, HALF)
|
||||
const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
|
||||
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
|
||||
const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
|
||||
const_one = graph_builder.value(input_dy.dtype, 1)
|
||||
const_half = graph_builder.value(input_dy.dtype, 0.5)
|
||||
|
||||
# cal mul_right
|
||||
mul_double = graph_builder.emit('Mul', [input_x, input_x])
|
||||
|
@ -79,8 +67,4 @@ def expand_gelugrad(expand_info):
|
|||
result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
|
||||
result = graph_builder.emit('Mul', [input_dy, result_tmp])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -12,35 +12,29 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for GkDropOut"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
"""generate json desc for GkDropout"""
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_gkdropout(expand_info):
|
||||
"""GkDropOut expander"""
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
maks_desc = expand_info['input_desc'][1]
|
||||
keep_prob = expand_info['attr']['keep_prob']
|
||||
@VLD.check_all_formats_same
|
||||
@VLD.check_attrs('keep_prob')
|
||||
class GkDropout(Expander):
|
||||
"""GkDropout expander"""
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_mask = self.inputs
|
||||
keep_prob = self.attrs['keep_prob']
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format'])
|
||||
graph_scope.set_input(input_x, input_mask)
|
||||
keep_prob_v = graph_builder.value(input_x.dtype, keep_prob)
|
||||
r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
|
||||
keep_prob = graph_builder.value(input_x.dtype, keep_prob)
|
||||
|
||||
if input_mask.dtype != input_x.dtype:
|
||||
input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
|
||||
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type
|
||||
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type
|
||||
mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
|
||||
|
||||
# compute result
|
||||
result = graph_builder.emit('Mul', [r_keep_prob, input_x])
|
||||
result = graph_builder.emit('Mul', [result, mask])
|
||||
# set graph output.
|
||||
graph_scope.set_output(result, mask)
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
|
||||
return result, mask
|
||||
|
|
|
@ -13,38 +13,31 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for LayerNorm"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_layernorm(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
|
||||
class LayerNorm(Expander):
|
||||
"""LayerNorm expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
attrs = expand_info['attr']
|
||||
begin_norm_axis = attrs['begin_norm_axis']
|
||||
epsilon = attrs['epsilon']
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_gamma, input_beta = self.inputs
|
||||
begin_norm_axis = self.attrs['begin_norm_axis']
|
||||
epsilon = self.attrs['epsilon']
|
||||
|
||||
# Calculate the scaling ratio of the average
|
||||
shape_x = input_desc_0['shape']
|
||||
if begin_norm_axis < 0:
|
||||
begin_norm_axis += len(shape_x)
|
||||
begin_norm_axis += len(input_x.shape)
|
||||
reduce_axis = ()
|
||||
for i, _ in enumerate(shape_x):
|
||||
for i, _ in enumerate(input_x.shape):
|
||||
if i > begin_norm_axis or i == begin_norm_axis:
|
||||
reduce_axis = reduce_axis + (i,)
|
||||
|
||||
reduce_elts = 1.0
|
||||
for i in reduce_axis:
|
||||
reduce_elts *= shape_x[i]
|
||||
reduce_elts *= input_x.shape[i]
|
||||
mean_cof = 1.0 / reduce_elts
|
||||
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
|
||||
|
||||
|
@ -70,8 +63,4 @@ def expand_layernorm(expand_info):
|
|||
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
|
||||
res = graph_builder.emit('Add', [scale_mul, input_beta])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(res, mean, variance)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return res, mean, variance
|
||||
|
|
|
@ -13,42 +13,30 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for LayerNormGrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_layernormgrad(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis')
|
||||
class LayerNormGrad(Expander):
|
||||
"""LayerNormGrad expander"""
|
||||
# get op info.
|
||||
x_desc = expand_info['input_desc'][0]
|
||||
dy_desc = expand_info['input_desc'][1]
|
||||
var_desc = expand_info['input_desc'][2]
|
||||
mean_desc = expand_info['input_desc'][3]
|
||||
gamma_desc = expand_info['input_desc'][4]
|
||||
attrs = expand_info['attr']
|
||||
begin_norm_axis = attrs['begin_norm_axis']
|
||||
begin_params_axis = attrs['begin_params_axis']
|
||||
epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11
|
||||
|
||||
shape_x = x_desc['shape']
|
||||
if begin_norm_axis < 0:
|
||||
begin_norm_axis += len(shape_x)
|
||||
if begin_params_axis < 0:
|
||||
begin_params_axis += len(shape_x)
|
||||
norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
|
||||
param_axis = tuple(range(0, begin_params_axis))
|
||||
reduce_size = 1.0
|
||||
for i in norm_axis:
|
||||
reduce_size *= shape_x[i]
|
||||
def _expand(self, graph_builder):
|
||||
x, dy, variance, mean, gamma = self.inputs
|
||||
begin_norm_axis = self.attrs['begin_norm_axis']
|
||||
begin_params_axis = self.attrs['begin_params_axis']
|
||||
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create input tensors.
|
||||
x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format'])
|
||||
dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
|
||||
variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format'])
|
||||
mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format'])
|
||||
gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format'])
|
||||
graph_scope.set_input(x, dy, variance, mean, gamma)
|
||||
if begin_norm_axis < 0:
|
||||
begin_norm_axis += len(x.shape)
|
||||
if begin_params_axis < 0:
|
||||
begin_params_axis += len(x.shape)
|
||||
norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
|
||||
param_axis = tuple(range(0, begin_params_axis))
|
||||
reduce_size = 1.0
|
||||
for i in norm_axis:
|
||||
reduce_size *= x.shape[i]
|
||||
|
||||
# set some constant val.
|
||||
eps = graph_builder.value(x.dtype, epsilon)
|
||||
|
@ -99,8 +87,4 @@ def expand_layernormgrad(expand_info):
|
|||
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
|
||||
dx = graph_builder.emit('Add', [dx_tmp, dx_3])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(dx, dg, db)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return dx, dg, db
|
||||
|
|
|
@ -13,24 +13,21 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for LogSoftmax"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_logsoftmax(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.check_attrs('axis')
|
||||
class LogSoftmax(Expander):
|
||||
"""LogSoftmax expander"""
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
axis = expand_info['attr']['axis']
|
||||
graph_builder = builder.GraphBuilder()
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
graph_scope.set_input(input_x)
|
||||
|
||||
# cal logsoftmax.
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
axis = self.attrs['axis']
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
|
||||
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
||||
data_exp = graph_builder.emit('Exp', [data_sub])
|
||||
|
@ -38,8 +35,4 @@ def expand_logsoftmax(expand_info):
|
|||
log_expsum = graph_builder.emit('Log', [data_expsum])
|
||||
result = graph_builder.emit('Sub', [data_sub, log_expsum])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,34 +13,24 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for LogSoftmaxGrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_logsoftmaxgrad(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||
@VLD.check_attrs('axis')
|
||||
class LogSoftmaxGrad(Expander):
|
||||
"""LogSoftmaxGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
axis = expand_info['attr']['axis']
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_logits = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
graph_scope.set_input(input_logits, input_dy)
|
||||
def _expand(self, graph_builder):
|
||||
input_logits, input_dy = self.inputs
|
||||
axis = self.attrs['axis']
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
|
||||
# cal logsoftmaxgrad.
|
||||
softmax = graph_builder.emit('Exp', [input_logits])
|
||||
dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
|
||||
result = graph_builder.emit('Sub', [input_dy, mul_result])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,40 +13,26 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for maximum_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_maximumgrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class MaximumGrad(Expander):
|
||||
"""MaximumGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
attrs = expand_info['attr']
|
||||
grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
|
||||
grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_x, input_y, input_dout)
|
||||
x_dtype = input_x.dtype
|
||||
# cal result
|
||||
def _check(self):
|
||||
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
|
||||
raise GKException("both grad_x and grad_y are False.")
|
||||
return super()._check()
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_y, input_dout = self.inputs
|
||||
|
||||
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
|
||||
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype})
|
||||
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
|
||||
dx = graph_builder.emit('Mul', [ge_result, input_dout])
|
||||
dy = graph_builder.emit('Sub', [input_dout, dx])
|
||||
|
||||
# set graph output according to grad_x and grad_y
|
||||
if grad_x and grad_y:
|
||||
graph_scope.set_output(dx, dy)
|
||||
if grad_x and not grad_y:
|
||||
graph_scope.set_output(dx)
|
||||
if grad_y and not grad_x:
|
||||
graph_scope.set_output(dy)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
# output two results, regardless of grad_x and grad_y
|
||||
return dx, dy
|
||||
|
|
|
@ -13,41 +13,26 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for minimum_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_minimumgrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class MinimumGrad(Expander):
|
||||
"""MinimumGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
attrs = expand_info['attr']
|
||||
grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
|
||||
grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_x, input_y, input_dout)
|
||||
x_dtype = input_x.dtype
|
||||
def _check(self):
|
||||
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
|
||||
raise GKException("both grad_x and grad_y are False.")
|
||||
return super()._check()
|
||||
|
||||
def _expand(self, graph_builder):
|
||||
input_x, input_y, input_dout = self.inputs
|
||||
|
||||
# cal result
|
||||
le_result = graph_builder.emit('LessEqual', [input_x, input_y])
|
||||
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype})
|
||||
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
|
||||
dx = graph_builder.emit('Mul', [le_result, input_dout])
|
||||
dy = graph_builder.emit('Sub', [input_dout, dx])
|
||||
|
||||
# set graph output according to grad_x and grad_y
|
||||
if grad_x and grad_y:
|
||||
graph_scope.set_output(dx, dy)
|
||||
if grad_x and not grad_y:
|
||||
graph_scope.set_output(dx)
|
||||
if grad_y and not grad_x:
|
||||
graph_scope.set_output(dy)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
# output two results, regardless of grad_x and grad_y
|
||||
return dx, dy
|
||||
|
|
|
@ -13,45 +13,30 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for reduce_mean"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_reducemean(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT)
|
||||
@VLD.check_attrs('axis', 'keep_dims')
|
||||
class ReduceMean(Expander):
|
||||
"""ReduceMean expander"""
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
attrs = expand_info['attr']
|
||||
axis = attrs['axis']
|
||||
keep_dims = attrs['keep_dims']
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
x_shape = input_x.shape
|
||||
graph_scope.set_input(input_x)
|
||||
def _expand(self, graph_builder):
|
||||
x = self.inputs[0]
|
||||
axis = self.attrs['axis']
|
||||
keep_dims = self.attrs['keep_dims']
|
||||
|
||||
# cal reduce_mean, when axis = None, reduce axis are all
|
||||
all_shape = 1.0
|
||||
real_axis = []
|
||||
# cal reduce_mean, when axis is None, reduce all axes.
|
||||
if not axis:
|
||||
for i, shape in enumerate(x_shape):
|
||||
real_axis.append(i)
|
||||
all_shape *= shape
|
||||
else:
|
||||
for idx in axis:
|
||||
all_shape *= x_shape[idx]
|
||||
axis = list(range(len(x.shape)))
|
||||
reduce_size = 1.0
|
||||
for idx in axis:
|
||||
reduce_size *= x.shape[idx]
|
||||
|
||||
all_shape_value = graph_builder.value(input_x.dtype, all_shape)
|
||||
reduce_size_value = graph_builder.value(x.dtype, reduce_size)
|
||||
|
||||
if not axis:
|
||||
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims})
|
||||
else:
|
||||
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
|
||||
result = graph_builder.emit('RealDiv', [sum_x, all_shape_value])
|
||||
sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
|
||||
result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,26 +13,23 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for softmax"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_softmax(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT)
|
||||
@VLD.check_attrs('axis')
|
||||
class Softmax(Expander):
|
||||
"""Softmax expander"""
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
axis = expand_info['attr']['axis']
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
# cal softmax.
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
axis = self.attrs['axis']
|
||||
|
||||
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
||||
data_exp = graph_builder.emit('Exp', [data_sub])
|
||||
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||
result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -13,33 +13,17 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for sqrtgrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_sqrtgrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class SqrtGrad(Expander):
|
||||
"""SqrtGrad expander"""
|
||||
# cal formula are:
|
||||
# sqrt_grad(x, dout) is dout / (2 * x)
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
graph_scope.set_input(input_x, input_dout)
|
||||
|
||||
# cal result
|
||||
const_two = graph_builder.value(input_x.dtype, 2)
|
||||
dividend = graph_builder.emit('Mul', [input_x, const_two])
|
||||
result = graph_builder.emit('RealDiv', [input_dout, dividend])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
def _expand(self, graph_builder):
|
||||
# sqrt_grad(x, dout) = dout / (2 * x)
|
||||
x, dout = self.inputs
|
||||
const_two = graph_builder.value(x.dtype, 2)
|
||||
dividend = graph_builder.emit('Mul', [x, const_two])
|
||||
result = graph_builder.emit('RealDiv', [dout, dividend])
|
||||
return result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,24 +13,13 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for square"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from ._utils import Expander
|
||||
|
||||
|
||||
def expand_square(expand_info):
|
||||
class Square(Expander):
|
||||
"""Square expander"""
|
||||
|
||||
# get op info.
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
# create op.
|
||||
result = graph_builder.emit('Mul', [input_x, input_x])
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
def _expand(self, graph_builder):
|
||||
x = self.inputs[0]
|
||||
result = graph_builder.emit('Mul', [x, x])
|
||||
return result
|
||||
|
|
|
@ -13,34 +13,19 @@
|
|||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for tanh_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
ONE = 1.0
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_tanhgrad(expand_info):
|
||||
@VLD.check_all_formats_same
|
||||
class TanhGrad(Expander):
|
||||
"""TanhGrad expander"""
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
def _expand(self, graph_builder):
|
||||
input_y, input_dy = self.inputs
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
const_one = graph_builder.value(input_y.dtype, ONE)
|
||||
graph_scope.set_input(input_y, input_dy)
|
||||
|
||||
# cal result
|
||||
const_one = graph_builder.value(input_y.dtype, 1)
|
||||
double_y = graph_builder.emit('Mul', [input_y, input_y])
|
||||
one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
|
||||
result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -14,25 +14,22 @@
|
|||
# ===========================================================================
|
||||
"""generate json desc for Tile"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||
|
||||
|
||||
def expand_tile(expand_info):
|
||||
@VLD.add_format(DF.DEFAULT)
|
||||
@VLD.check_attrs('multiples')
|
||||
class Tile(Expander):
|
||||
"""Tile expander"""
|
||||
input_desc = expand_info['input_desc'][0]
|
||||
multiples = expand_info['attr']['multiples']
|
||||
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
|
||||
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||
# create op.
|
||||
def _expand(self, graph_builder):
|
||||
input_x = self.inputs[0]
|
||||
multiples = self.attrs['multiples']
|
||||
|
||||
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
|
||||
if shape_compatible:
|
||||
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
|
||||
else:
|
||||
result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
return result
|
||||
|
|
|
@ -55,6 +55,25 @@ class DataFormat:
|
|||
NDHWC = "NDHWC"
|
||||
|
||||
|
||||
class DataType:
|
||||
"""Data Type"""
|
||||
FLOAT = "float"
|
||||
FLOAT16 = "float16"
|
||||
FLOAT32 = "float32"
|
||||
FLOAT64 = "float64"
|
||||
INT = "int"
|
||||
INT8 = "int8"
|
||||
INT16 = "int16"
|
||||
INT32 = "int32"
|
||||
INT64 = "int64"
|
||||
UINT = "uint"
|
||||
UINT8 = "uint8"
|
||||
UINT16 = "uint16"
|
||||
UINT32 = "uint32"
|
||||
UINT64 = "uint64"
|
||||
BOOL = "bool"
|
||||
|
||||
|
||||
class Config:
|
||||
R0 = 8.0
|
||||
UB_SIZE = 256 * 1024
|
||||
|
@ -508,3 +527,9 @@ class AddControlBuddy(GraphVisitor):
|
|||
for owner in self.buddies:
|
||||
for op in self.buddies[owner]:
|
||||
owner.add_buddy(op.output)
|
||||
|
||||
|
||||
class GraphKernelUnsupportedException(Exception):
|
||||
def __init__(self, message):
|
||||
super().__init__()
|
||||
self.message = message
|
||||
|
|
|
@ -90,7 +90,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
|||
}
|
||||
std::string kernel_desc_str = py::cast<std::string>(ret);
|
||||
if (kernel_desc_str.empty()) {
|
||||
MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope();
|
||||
MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
|
||||
return nullptr;
|
||||
}
|
||||
// decode json to func_graph.
|
||||
|
@ -131,11 +131,8 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func
|
|||
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
|
||||
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
|
||||
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||
std::string graph_kernel_flag;
|
||||
std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) {
|
||||
static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_"));
|
||||
});
|
||||
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag;
|
||||
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
|
||||
<< " with: " << graph_kernel_node->fullname_with_scope();
|
||||
return graph_kernel_node;
|
||||
}
|
||||
|
||||
|
@ -152,14 +149,12 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope();
|
||||
MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
|
||||
auto new_func_graph = CreateExpandFuncGraph(node);
|
||||
if (new_func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
mng->AddFuncGraph(new_func_graph);
|
||||
MS_LOG(DEBUG) << "decode fused nodes success.";
|
||||
|
||||
auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
|
||||
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));
|
||||
|
|
Loading…
Reference in New Issue