!12387 【GraphKernel】Refactor GraphKernelExpander (2nd submission)

From: @dayschan
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-23 09:28:35 +08:00 committed by Gitee
commit ddea9d52eb
26 changed files with 477 additions and 583 deletions

View File

@ -18,6 +18,16 @@ import json.decoder as jd
import traceback import traceback
from mindspore import log as logger from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders import mindspore._extends.graph_kernel.expanders as expanders
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
def create_expander(expand_info):
"""Create an expander according to op name"""
op_name = str(expand_info['name'])
if not hasattr(expanders, op_name):
raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name))
expander = getattr(expanders, op_name)
return expander(expand_info)
def extract_expand_info(kernel_info): def extract_expand_info(kernel_info):
@ -46,20 +56,8 @@ def get_op_expander(json_str: str):
kernel_info = json.loads(json_str) kernel_info = json.loads(json_str)
expand_info = extract_expand_info(kernel_info) expand_info = extract_expand_info(kernel_info)
processor = expand_info['process'] expander = create_expander(expand_info)
op_name = str(expand_info['name']).lower() graph = expander.run()
expand_op_func_name = 'expand_' + op_name
if not hasattr(expanders, expand_op_func_name):
logger.error("Generator do not support op: {}".format(op_name))
return None
expand_op_func = getattr(expanders, expand_op_func_name)
# generate graph desc.
graph = expand_op_func(expand_info)
if graph is None:
logger.error("Failed to generate graph of: {}".format(op_name))
return None
graph.set_processor(processor)
# dump graph to json desc. # dump graph to json desc.
desc = graph.dump() desc = graph.dump()
@ -69,3 +67,6 @@ def get_op_expander(json_str: str):
logger.error("Failed to generate graph kernel op") logger.error("Failed to generate graph kernel op")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return None return None
except GraphKernelUnsupportedException as e:
logger.info(e.message)
return ""

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,24 +14,24 @@
# ============================================================================ # ============================================================================
"""expanders init""" """expanders init"""
from .gelu import expand_gelu from .bias_add import BiasAdd
from .gelu_grad import expand_gelugrad from .bias_add_grad import BiasAddGrad
from .layernorm import expand_layernorm from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
from .softmax import expand_softmax from .dropout_grad import DropoutGrad
from .square import expand_square from .fused_adam import FusedAdam
from .bias_add import expand_biasadd from .fused_adam_weight_decay import FusedAdamWeightDecay
from .bias_add_grad import expand_biasaddgrad from .gelu import GeLU
from .fused_adam import expand_fusedadam from .gelu_grad import GeLUGrad
from .fused_adam_weight_decay import expand_fusedadamweightdecay from .gkdropout import GkDropout
from .reduce_mean import expand_reducemean from .layernorm import LayerNorm
from .tanh_grad import expand_tanhgrad from .layernorm_grad import LayerNormGrad
from .maximum_grad import expand_maximumgrad from .logsoftmax import LogSoftmax
from .minimum_grad import expand_minimumgrad from .logsoftmax_grad import LogSoftmaxGrad
from .dropout_grad import expand_dropoutgrad from .maximum_grad import MaximumGrad
from .layernorm_grad import expand_layernormgrad from .minimum_grad import MinimumGrad
from .logsoftmax import expand_logsoftmax from .reduce_mean import ReduceMean
from .logsoftmax_grad import expand_logsoftmaxgrad from .softmax import Softmax
from .gkdropout import expand_gkdropout from .sqrt_grad import SqrtGrad
from .tile import expand_tile from .square import Square
from .sqrt_grad import expand_sqrtgrad from .tanh_grad import TanhGrad
from .clip_by_norm_no_div_sum import expand_clipbynormnodivsum from .tile import Tile

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel expander utils"""
from abc import ABCMeta, abstractmethod
from mindspore._extends.graph_kernel.model import model_builder as builder
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
class Expander:
"""
Expander is the base class of expanders.
The method `_expand` should be overridden to implement the operator detail.
"""
__metaclass__ = ABCMeta
def __init__(self, expand_info):
self.name = expand_info["name"]
self.inputs = expand_info["input_desc"]
self.outputs = expand_info["output_desc"]
self.attrs = expand_info["attr"]
self.processor = expand_info["process"]
def run(self):
"""
Expand the operator to a graph.
`GraphKernelUnsupportedException` would be raised if check failed.
"""
self._check()
graph_builder = builder.GraphBuilder()
with graph_builder.graph_scope(self.name) as graph_scope:
# transform input_desc to Tensor
self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
graph_scope.set_input(*self.inputs)
outputs = self._expand(graph_builder)
if isinstance(outputs, (list, tuple)):
graph_scope.set_output(*outputs)
else:
graph_scope.set_output(outputs)
graph = graph_builder.get()[0]
graph.set_processor(self.processor)
return graph
def _check(self):
"""Check inputs"""
@abstractmethod
def _expand(self, graph_builder):
"""Expand operator, this function should be overridden in subclass"""
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
class ExpanderInfoValidator:
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
# pylint: disable=W0211
@staticmethod
def _add_check_function(cls, func):
"""
Rewrite the function `_check` in class Expander
to append the new `func` after the original checks.
"""
old_check = getattr(cls, "_check")
def new_check(obj):
old_check(obj)
func(obj)
setattr(cls, "_check", new_check)
@staticmethod
def add_format(*input_format):
"""
Add new supported format for the operator
this function will add a list `__supported_formats` into the expander,
saving the whitelist of formats that this op supports.
it also rewrites the `_check` function to check the formats.
"""
format_list_name = "__supported_formats"
def _check_format(obj):
inp_formats = [inp['format'] for inp in obj.inputs]
for formats in getattr(obj, format_list_name):
if len(formats) != len(inp_formats):
raise GKException("length of registered format doesn't match with the input of {}".format(obj.name))
if all([fmt == inp for fmt, inp in zip(formats, inp_formats)]):
return
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
def wrapper(cls):
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
if not hasattr(cls, format_list_name):
setattr(cls, format_list_name, list())
ExpanderInfoValidator._add_check_function(cls, _check_format)
getattr(cls, format_list_name).append(input_format)
return cls
return wrapper
@staticmethod
def check_all_formats_same(cls):
"""Check that all formats are the same"""
def _check_format(obj):
inp_formats = [inp['format'] for inp in obj.inputs]
if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]):
return
raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
','.join(inp_formats), obj.name))
def wrapper(*args, **kargs):
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format)
return cls(*args, **kargs)
return wrapper
@staticmethod
def check_attrs(*args):
"""Check the attrs exist"""
def _check_attr(obj):
for a in args:
if a not in obj.attrs:
raise GKException("attr '{}' does not exist.".format(a))
def wrapper(cls):
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_attr)
return cls
return wrapper

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,50 +13,34 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for bias_add""" """generate json desc for bias_add"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_biasadd(expand_info): @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.DEFAULT)
class BiasAdd(Expander):
"""BiasAdd expander""" """BiasAdd expander"""
# get op info. def _expand(self, graph_builder):
input_desc_0 = expand_info['input_desc'][0] input_x, input_y = self.inputs
input_desc_1 = expand_info['input_desc'][1]
graph_builder = builder.GraphBuilder() if input_x.data_format == DF.NCHW:
# generate a graph. input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
with graph_builder.graph_scope('main') as graph_scope: input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
# create tensor input.
input_x = graph_builder.tensor(
input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_y = graph_builder.tensor(
input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
graph_scope.set_input(input_x, input_y)
if input_x.data_format == "NCHW":
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y], attrs={'axis': 1})
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y_expand], attrs={'axis': 2})
result = graph_builder.emit('Add', [input_x, input_y_expand]) result = graph_builder.emit('Add', [input_x, input_y_expand])
elif input_x.data_format == "DefaultFormat": elif input_x.data_format == DF.DEFAULT:
if len(input_x.shape) == 2: if len(input_x.shape) == 2:
result = graph_builder.emit('Add', [input_x, input_y]) result = graph_builder.emit('Add', [input_x, input_y])
elif len(input_x.shape) == 3: elif len(input_x.shape) == 3:
input_y_expand = graph_builder.emit( input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
'ExpandDims', [input_y], attrs={'axis': 1}) result = graph_builder.emit('Add', [input_x, input_y_expand])
result = graph_builder.emit( else: # len == 4
'Add', [input_x, input_y_expand]) input_y_expand = graph_builder.emit('ExpandDims', [input_y], attrs={'axis': 1})
else: input_y_expand = graph_builder.emit('ExpandDims', [input_y_expand], attrs={'axis': 2})
input_y_expand = graph_builder.emit( result = graph_builder.emit('Add', [input_x, input_y_expand])
'ExpandDims', [input_y], attrs={'axis': 1}) else: # NHWC
input_y_expand = graph_builder.emit(
'ExpandDims', [input_y_expand], attrs={'axis': 2})
result = graph_builder.emit(
'Add', [input_x, input_y_expand])
else:
result = graph_builder.emit('Add', [input_x, input_y]) result = graph_builder.emit('Add', [input_x, input_y])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,26 +13,25 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for bias_add""" """generate json desc for bias_add"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_biasaddgrad(expand_info): @VLD.add_format(DF.DEFAULT)
@VLD.add_format(DF.NHWC)
@VLD.add_format(DF.NCHW)
class BiasAddGrad(Expander):
"""BiasAddGrad expander""" """BiasAddGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0] def _expand(self, graph_builder):
graph_builder = builder.GraphBuilder() input_x = self.inputs[0]
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(
input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
graph_scope.set_input(input_x)
reduce_axis = () reduce_axis = ()
if input_x.data_format == 'NHWC': if input_x.data_format == 'NHWC':
reduce_axis = (0, 1, 2) reduce_axis = (0, 1, 2)
elif input_x.data_format == 'NCHW': elif input_x.data_format == 'NCHW':
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3)
# Default format shape's length maybe equal 2 to 4, so different shape's length reduce axis are differnet # DefaultFormat shape's length should be from 2 to 4
else: else:
if len(input_x.shape) == 2: if len(input_x.shape) == 2:
reduce_axis = (0,) reduce_axis = (0,)
@ -41,8 +40,4 @@ def expand_biasaddgrad(expand_info):
else: else:
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3)
result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,27 +13,15 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for ClipByNormNoDivSum""" """generate json desc for ClipByNormNoDivSum"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_clipbynormnodivsum(expand_info): @VLD.check_all_formats_same
class ClipByNormNoDivSum(Expander):
"""ClipByNormNoDivSum expander""" """ClipByNormNoDivSum expander"""
# get op info. def _expand(self, graph_builder):
input_desc_0 = expand_info['input_desc'][0] input_x0, input_x1, input_x2, input_x3 = self.inputs
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
input_desc_3 = expand_info['input_desc'][3]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x0 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_x1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_x2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
input_x3 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
graph_scope.set_input(input_x0, input_x1, input_x2, input_x3)
# cal result # cal result
greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'}) greater_res = graph_builder.emit('Greater', [input_x0, input_x1], attrs={'fusion': 'SelectGT_000'})
@ -44,8 +32,4 @@ def expand_clipbynormnodivsum(expand_info):
attrs={'fusion': 'SelectGT_000_end'}) attrs={'fusion': 'SelectGT_000_end'})
result = graph_builder.emit('Maximum', [select_res1, input_x3]) result = graph_builder.emit('Maximum', [select_res1, input_x3])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,27 +13,18 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for DropoutGrad""" """generate json desc for DropoutGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_dropoutgrad(expand_info): @VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class DropoutGrad(Expander):
"""DropoutGrad expander""" """DropoutGrad expander"""
# get op info.
dy_desc = expand_info['input_desc'][0]
mask_desc = expand_info['input_desc'][1]
keep_prob = expand_info['attr']['keep_prob']
graph_builder = builder.GraphBuilder() def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: input_dy, input_mask = self.inputs
# create tensor input. keep_prob = self.attrs['keep_prob']
input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
graph_scope.set_input(input_dy, input_mask)
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob) r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
# create op.
result = graph_builder.emit('Mul', [input_dy, r_keep_prob]) result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
result = graph_builder.emit('Mul', [result, input_mask]) result = graph_builder.emit('Mul', [result, input_mask])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,40 +13,16 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for fused_adam""" """generate json desc for fused_adam"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_fusedadam(expand_info): @VLD.check_all_formats_same
"""FusedAdma expander""" class FusedAdam(Expander):
# get op info. """FusedAdam expander"""
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
input_desc_3 = expand_info['input_desc'][3]
input_desc_4 = expand_info['input_desc'][4]
input_desc_5 = expand_info['input_desc'][5]
input_desc_6 = expand_info['input_desc'][6]
input_desc_7 = expand_info['input_desc'][7]
input_desc_8 = expand_info['input_desc'][8]
input_desc_9 = expand_info['input_desc'][9]
graph_builder = builder.GraphBuilder()
# generate a graph. def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
# create tensor input.
beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient)
# compute result
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
@ -60,12 +36,9 @@ def expand_fusedadam(expand_info):
update_with_lr = graph_builder.emit('Mul', [lr, update]) update_with_lr = graph_builder.emit('Mul', [lr, update])
next_para = graph_builder.emit('Sub', [param, update_with_lr]) next_para = graph_builder.emit('Sub', [param, update_with_lr])
param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) param_result = graph_builder.emit(
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True}) param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True}) param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
# set graph output. return param_result
graph_scope.set_output(param_result)
graph = graph_builder.get()[0]
return graph

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,41 +13,15 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for fused_adam_weight_decay""" """generate json desc for fused_adam_weight_decay"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_fusedadamweightdecay(expand_info): @VLD.check_all_formats_same
"""FusedAdmaWeightDecay expander""" class FusedAdamWeightDecay(Expander):
# get op info. """FusedAdamWeightDecay expander"""
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
input_desc_3 = expand_info['input_desc'][3]
input_desc_4 = expand_info['input_desc'][4]
input_desc_5 = expand_info['input_desc'][5]
input_desc_6 = expand_info['input_desc'][6]
input_desc_7 = expand_info['input_desc'][7]
input_desc_8 = expand_info['input_desc'][8]
input_desc_9 = expand_info['input_desc'][9]
input_desc_10 = expand_info['input_desc'][10]
graph_builder = builder.GraphBuilder()
# generate a graph. def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs
# create tensor input.
beta_1 = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
one_sub_beta_1 = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
beta_2 = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
one_sub_beta_2 = graph_builder.tensor(input_desc_3['shape'], input_desc_3['data_type'], input_desc_3['format'])
eps = graph_builder.tensor(input_desc_4['shape'], input_desc_4['data_type'], input_desc_4['format'])
lr = graph_builder.tensor(input_desc_5['shape'], input_desc_5['data_type'], input_desc_5['format'])
param = graph_builder.tensor(input_desc_6['shape'], input_desc_6['data_type'], input_desc_6['format'])
m = graph_builder.tensor(input_desc_7['shape'], input_desc_7['data_type'], input_desc_7['format'])
v = graph_builder.tensor(input_desc_8['shape'], input_desc_8['data_type'], input_desc_8['format'])
gradient = graph_builder.tensor(input_desc_9['shape'], input_desc_9['data_type'], input_desc_9['format'])
weight_decay = graph_builder.tensor(input_desc_10['shape'], input_desc_10['data_type'], input_desc_10['format'])
graph_scope.set_input(beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
eps, lr, param, m, v, gradient, weight_decay)
# compute result # compute result
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
@ -65,12 +39,9 @@ def expand_fusedadamweightdecay(expand_info):
update_with_lr = graph_builder.emit('Mul', [lr, update]) update_with_lr = graph_builder.emit('Mul', [lr, update])
next_para = graph_builder.emit('Sub', [param, update_with_lr]) next_para = graph_builder.emit('Sub', [param, update_with_lr])
para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) para_result = graph_builder.emit(
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
# set graph output. return para_result
graph_scope.set_output(para_result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,49 +13,36 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for gelu""" """generate json desc for gelu"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander
CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
ONE = 1.0
HALF = 0.5
def expand_gelu(expand_info): class GeLU(Expander):
"""GeLU expander""" """GeLU expander"""
# cal formula are: CSVALUE = 0.044715
# gelu(x) is 0.5 * x * (1.0 + tanh(y)) CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# get op info. def _expand(self, graph_builder):
input_desc = expand_info['input_desc'][0] # cal formula are:
graph_builder = builder.GraphBuilder() # gelu(x) is 0.5 * x * (1.0 + tanh(y))
# y is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# generate a graph. input_x = self.inputs[0]
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
graph_scope.set_input(input_x)
# cal y # cal y
mul_0 = graph_builder.emit('Mul', [input_x, input_x]) mul_0 = graph_builder.emit('Mul', [input_x, input_x])
pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE) const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
tanh_res = graph_builder.emit('Add', [input_x, mul_1]) tanh_res = graph_builder.emit('Add', [input_x, mul_1])
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI) const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
# cal gelu(x) # cal gelu(x)
tanh_y = graph_builder.emit('Tanh', [y]) tanh_y = graph_builder.emit('Tanh', [y])
const_one = graph_builder.value(tanh_y.dtype, ONE) const_one = graph_builder.value(tanh_y.dtype, 1)
const_half = graph_builder.value(tanh_y.dtype, HALF) const_half = graph_builder.value(tanh_y.dtype, 0.5)
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
result = graph_builder.emit('Mul', [const_half, mul_x]) result = graph_builder.emit('Mul', [const_half, mul_x])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,43 +13,31 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for gelugrad""" """generate json desc for gelugrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
CSVALUE_TRI = 0.134141 # CSVALUE * 3
ONE = 1.0
HALF = 0.5
def expand_gelugrad(expand_info): @VLD.check_all_formats_same
class GeLUGrad(Expander):
"""GeLUGrad expander""" """GeLUGrad expander"""
# cal formula are: CSVALUE = 0.044715
# gelu_grad(dy, x) is dy * y' CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right CSVALUE_TRI = 0.134141 # CSVALUE * 3
# tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
# get op info. def _expand(self, graph_builder):
input_desc_0 = expand_info['input_desc'][0] # cal formula are:
input_desc_1 = expand_info['input_desc'][1] # gelu_grad(dy, x) is dy * y'
input_desc_2 = expand_info['input_desc'][2] # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
graph_builder = builder.GraphBuilder() # tanh_para is sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# mul_right is sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
# generate a graph. input_dy, input_x, _ = self.inputs
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
graph_scope.set_input(input_dy, input_x, input_y)
# create some const var # create some const var
const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE) const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI) const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI) const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
const_one = graph_builder.value(input_dy.dtype, ONE) const_one = graph_builder.value(input_dy.dtype, 1)
const_half = graph_builder.value(input_dy.dtype, HALF) const_half = graph_builder.value(input_dy.dtype, 0.5)
# cal mul_right # cal mul_right
mul_double = graph_builder.emit('Mul', [input_x, input_x]) mul_double = graph_builder.emit('Mul', [input_x, input_x])
@ -79,8 +67,4 @@ def expand_gelugrad(expand_info):
result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final]) result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
result = graph_builder.emit('Mul', [input_dy, result_tmp]) result = graph_builder.emit('Mul', [input_dy, result_tmp])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -12,35 +12,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for GkDropOut""" """generate json desc for GkDropout"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_gkdropout(expand_info): @VLD.check_all_formats_same
"""GkDropOut expander""" @VLD.check_attrs('keep_prob')
# get op info. class GkDropout(Expander):
input_desc = expand_info['input_desc'][0] """GkDropout expander"""
maks_desc = expand_info['input_desc'][1]
keep_prob = expand_info['attr']['keep_prob'] def _expand(self, graph_builder):
input_x, input_mask = self.inputs
keep_prob = self.attrs['keep_prob']
graph_builder = builder.GraphBuilder()
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
input_mask = graph_builder.tensor(maks_desc['shape'], maks_desc['data_type'], maks_desc['format'])
graph_scope.set_input(input_x, input_mask)
keep_prob_v = graph_builder.value(input_x.dtype, keep_prob)
r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob) r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
keep_prob = graph_builder.value(input_x.dtype, keep_prob)
if input_mask.dtype != input_x.dtype: if input_mask.dtype != input_x.dtype:
input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype}) input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob_v]) # output is bool type mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type
mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype}) mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
# compute result # compute result
result = graph_builder.emit('Mul', [r_keep_prob, input_x]) result = graph_builder.emit('Mul', [r_keep_prob, input_x])
result = graph_builder.emit('Mul', [result, mask]) result = graph_builder.emit('Mul', [result, mask])
# set graph output.
graph_scope.set_output(result, mask) return result, mask
graph = graph_builder.get()[0]
return graph

View File

@ -13,38 +13,31 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for LayerNorm""" """generate json desc for LayerNorm"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_layernorm(expand_info): @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
"""LayerNorm expander""" """LayerNorm expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
begin_norm_axis = attrs['begin_norm_axis']
epsilon = attrs['epsilon']
graph_builder = builder.GraphBuilder() def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: input_x, input_gamma, input_beta = self.inputs
# create tensor input. begin_norm_axis = self.attrs['begin_norm_axis']
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) epsilon = self.attrs['epsilon']
input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
# Calculate the scaling ratio of the average # Calculate the scaling ratio of the average
shape_x = input_desc_0['shape']
if begin_norm_axis < 0: if begin_norm_axis < 0:
begin_norm_axis += len(shape_x) begin_norm_axis += len(input_x.shape)
reduce_axis = () reduce_axis = ()
for i, _ in enumerate(shape_x): for i, _ in enumerate(input_x.shape):
if i > begin_norm_axis or i == begin_norm_axis: if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,) reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0 reduce_elts = 1.0
for i in reduce_axis: for i in reduce_axis:
reduce_elts *= shape_x[i] reduce_elts *= input_x.shape[i]
mean_cof = 1.0 / reduce_elts mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)
@ -70,8 +63,4 @@ def expand_layernorm(expand_info):
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul]) scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
res = graph_builder.emit('Add', [scale_mul, input_beta]) res = graph_builder.emit('Add', [scale_mul, input_beta])
# set graph output. return res, mean, variance
graph_scope.set_output(res, mean, variance)
graph = graph_builder.get()[0]
return graph

View File

@ -13,42 +13,30 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for LayerNormGrad""" """generate json desc for LayerNormGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_layernormgrad(expand_info): @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis')
class LayerNormGrad(Expander):
"""LayerNormGrad expander""" """LayerNormGrad expander"""
# get op info.
x_desc = expand_info['input_desc'][0]
dy_desc = expand_info['input_desc'][1]
var_desc = expand_info['input_desc'][2]
mean_desc = expand_info['input_desc'][3]
gamma_desc = expand_info['input_desc'][4]
attrs = expand_info['attr']
begin_norm_axis = attrs['begin_norm_axis']
begin_params_axis = attrs['begin_params_axis']
epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-11
shape_x = x_desc['shape'] def _expand(self, graph_builder):
if begin_norm_axis < 0: x, dy, variance, mean, gamma = self.inputs
begin_norm_axis += len(shape_x) begin_norm_axis = self.attrs['begin_norm_axis']
if begin_params_axis < 0: begin_params_axis = self.attrs['begin_params_axis']
begin_params_axis += len(shape_x) epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
norm_axis = tuple(range(begin_norm_axis, len(shape_x)))
param_axis = tuple(range(0, begin_params_axis))
reduce_size = 1.0
for i in norm_axis:
reduce_size *= shape_x[i]
graph_builder = builder.GraphBuilder() if begin_norm_axis < 0:
with graph_builder.graph_scope('main') as graph_scope: begin_norm_axis += len(x.shape)
# create input tensors. if begin_params_axis < 0:
x = graph_builder.tensor(x_desc['shape'], x_desc['data_type'], x_desc['format']) begin_params_axis += len(x.shape)
dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format']) norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
variance = graph_builder.tensor(var_desc['shape'], var_desc['data_type'], var_desc['format']) param_axis = tuple(range(0, begin_params_axis))
mean = graph_builder.tensor(mean_desc['shape'], mean_desc['data_type'], mean_desc['format']) reduce_size = 1.0
gamma = graph_builder.tensor(gamma_desc['shape'], gamma_desc['data_type'], gamma_desc['format']) for i in norm_axis:
graph_scope.set_input(x, dy, variance, mean, gamma) reduce_size *= x.shape[i]
# set some constant val. # set some constant val.
eps = graph_builder.value(x.dtype, epsilon) eps = graph_builder.value(x.dtype, epsilon)
@ -99,8 +87,4 @@ def expand_layernormgrad(expand_info):
dx_tmp = graph_builder.emit('Add', [dx_1, dx_2]) dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
dx = graph_builder.emit('Add', [dx_tmp, dx_3]) dx = graph_builder.emit('Add', [dx_tmp, dx_3])
# set graph output. return dx, dg, db
graph_scope.set_output(dx, dg, db)
graph = graph_builder.get()[0]
return graph

View File

@ -13,24 +13,21 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for LogSoftmax""" """generate json desc for LogSoftmax"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_logsoftmax(expand_info): @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmax(Expander):
"""LogSoftmax expander""" """LogSoftmax expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
axis = expand_info['attr']['axis']
graph_builder = builder.GraphBuilder()
if isinstance(axis, int):
axis = (axis,)
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
graph_scope.set_input(input_x)
# cal logsoftmax. def _expand(self, graph_builder):
input_x = self.inputs[0]
axis = self.attrs['axis']
if isinstance(axis, int):
axis = (axis,)
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x]) data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub]) data_exp = graph_builder.emit('Exp', [data_sub])
@ -38,8 +35,4 @@ def expand_logsoftmax(expand_info):
log_expsum = graph_builder.emit('Log', [data_expsum]) log_expsum = graph_builder.emit('Log', [data_expsum])
result = graph_builder.emit('Sub', [data_sub, log_expsum]) result = graph_builder.emit('Sub', [data_sub, log_expsum])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,34 +13,24 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for LogSoftmaxGrad""" """generate json desc for LogSoftmaxGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_logsoftmaxgrad(expand_info): @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmaxGrad(Expander):
"""LogSoftmaxGrad expander""" """LogSoftmaxGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
axis = expand_info['attr']['axis']
graph_builder = builder.GraphBuilder()
if isinstance(axis, int): def _expand(self, graph_builder):
axis = (axis,) input_logits, input_dy = self.inputs
# generate a graph. axis = self.attrs['axis']
with graph_builder.graph_scope('main') as graph_scope: if isinstance(axis, int):
# create tensor input. axis = (axis,)
input_logits = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
graph_scope.set_input(input_logits, input_dy)
# cal logsoftmaxgrad.
softmax = graph_builder.emit('Exp', [input_logits]) softmax = graph_builder.emit('Exp', [input_logits])
dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True}) dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
mul_result = graph_builder.emit('Mul', [softmax, dy_sum]) mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
result = graph_builder.emit('Sub', [input_dy, mul_result]) result = graph_builder.emit('Sub', [input_dy, mul_result])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,40 +13,26 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for maximum_grad""" """generate json desc for maximum_grad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_maximumgrad(expand_info): @VLD.check_all_formats_same
class MaximumGrad(Expander):
"""MaximumGrad expander""" """MaximumGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
graph_builder = builder.GraphBuilder() def _check(self):
with graph_builder.graph_scope('main') as graph_scope: if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
# create tensor input. raise GKException("both grad_x and grad_y are False.")
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) return super()._check()
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) def _expand(self, graph_builder):
graph_scope.set_input(input_x, input_y, input_dout) input_x, input_y, input_dout = self.inputs
x_dtype = input_x.dtype
# cal result
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y]) ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype}) ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
dx = graph_builder.emit('Mul', [ge_result, input_dout]) dx = graph_builder.emit('Mul', [ge_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx]) dy = graph_builder.emit('Sub', [input_dout, dx])
# set graph output according to grad_x and grad_y # output two results, regardless of grad_x and grad_y
if grad_x and grad_y: return dx, dy
graph_scope.set_output(dx, dy)
if grad_x and not grad_y:
graph_scope.set_output(dx)
if grad_y and not grad_x:
graph_scope.set_output(dy)
graph = graph_builder.get()[0]
return graph

View File

@ -13,41 +13,26 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for minimum_grad""" """generate json desc for minimum_grad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_minimumgrad(expand_info): @VLD.check_all_formats_same
class MinimumGrad(Expander):
"""MinimumGrad expander""" """MinimumGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
grad_x = attrs['grad_x'] if 'grad_x' in attrs else True
grad_y = attrs['grad_y'] if 'grad_y' in attrs else True
graph_builder = builder.GraphBuilder() def _check(self):
with graph_builder.graph_scope('main') as graph_scope: if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
# create tensor input. raise GKException("both grad_x and grad_y are False.")
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format']) return super()._check()
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format']) def _expand(self, graph_builder):
graph_scope.set_input(input_x, input_y, input_dout) input_x, input_y, input_dout = self.inputs
x_dtype = input_x.dtype
# cal result
le_result = graph_builder.emit('LessEqual', [input_x, input_y]) le_result = graph_builder.emit('LessEqual', [input_x, input_y])
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype}) le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
dx = graph_builder.emit('Mul', [le_result, input_dout]) dx = graph_builder.emit('Mul', [le_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx]) dy = graph_builder.emit('Sub', [input_dout, dx])
# set graph output according to grad_x and grad_y # output two results, regardless of grad_x and grad_y
if grad_x and grad_y: return dx, dy
graph_scope.set_output(dx, dy)
if grad_x and not grad_y:
graph_scope.set_output(dx)
if grad_y and not grad_x:
graph_scope.set_output(dy)
graph = graph_builder.get()[0]
return graph

View File

@ -13,45 +13,30 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for reduce_mean""" """generate json desc for reduce_mean"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_reducemean(expand_info): @VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis', 'keep_dims')
class ReduceMean(Expander):
"""ReduceMean expander""" """ReduceMean expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
attrs = expand_info['attr']
axis = attrs['axis']
keep_dims = attrs['keep_dims']
graph_builder = builder.GraphBuilder() def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: x = self.inputs[0]
# create tensor input. axis = self.attrs['axis']
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format']) keep_dims = self.attrs['keep_dims']
x_shape = input_x.shape
graph_scope.set_input(input_x)
# cal reduce_mean, when axis = None, reduce axis are all # cal reduce_mean, when axis is None, reduce all axes.
all_shape = 1.0
real_axis = []
if not axis: if not axis:
for i, shape in enumerate(x_shape): axis = list(range(len(x.shape)))
real_axis.append(i) reduce_size = 1.0
all_shape *= shape for idx in axis:
else: reduce_size *= x.shape[idx]
for idx in axis:
all_shape *= x_shape[idx]
all_shape_value = graph_builder.value(input_x.dtype, all_shape) reduce_size_value = graph_builder.value(x.dtype, reduce_size)
if not axis: sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': real_axis, 'keep_dims': keep_dims}) result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
else:
sum_x = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
result = graph_builder.emit('RealDiv', [sum_x, all_shape_value])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,26 +13,23 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for softmax""" """generate json desc for softmax"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_softmax(expand_info): @VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class Softmax(Expander):
"""Softmax expander""" """Softmax expander"""
input_desc = expand_info['input_desc'][0]
axis = expand_info['attr']['axis']
graph_builder = builder.GraphBuilder() def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: input_x = self.inputs[0]
# create tensor input. axis = self.attrs['axis']
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# cal softmax.
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x]) data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub]) data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0] return result
return graph

View File

@ -13,33 +13,17 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for sqrtgrad""" """generate json desc for sqrtgrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_sqrtgrad(expand_info): @VLD.check_all_formats_same
class SqrtGrad(Expander):
"""SqrtGrad expander""" """SqrtGrad expander"""
# cal formula are:
# sqrt_grad(x, dout) is dout / (2 * x)
# get op info. def _expand(self, graph_builder):
input_desc_0 = expand_info['input_desc'][0] # sqrt_grad(x, dout) = dout / (2 * x)
input_desc_1 = expand_info['input_desc'][1] x, dout = self.inputs
graph_builder = builder.GraphBuilder() const_two = graph_builder.value(x.dtype, 2)
dividend = graph_builder.emit('Mul', [x, const_two])
# generate a graph. result = graph_builder.emit('RealDiv', [dout, dividend])
with graph_builder.graph_scope('main') as graph_scope: return result
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_dout = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
graph_scope.set_input(input_x, input_dout)
# cal result
const_two = graph_builder.value(input_x.dtype, 2)
dividend = graph_builder.emit('Mul', [input_x, const_two])
result = graph_builder.emit('RealDiv', [input_dout, dividend])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,24 +13,13 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for square""" """generate json desc for square"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander
def expand_square(expand_info): class Square(Expander):
"""Square expander""" """Square expander"""
# get op info. def _expand(self, graph_builder):
input_desc = expand_info['input_desc'][0] x = self.inputs[0]
graph_builder = builder.GraphBuilder() result = graph_builder.emit('Mul', [x, x])
return result
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# create op.
result = graph_builder.emit('Mul', [input_x, input_x])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -13,34 +13,19 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for tanh_grad""" """generate json desc for tanh_grad"""
from mindspore._extends.graph_kernel.model import model_builder as builder from ._utils import Expander, ExpanderInfoValidator as VLD
ONE = 1.0
def expand_tanhgrad(expand_info): @VLD.check_all_formats_same
class TanhGrad(Expander):
"""TanhGrad expander""" """TanhGrad expander"""
# get op info. def _expand(self, graph_builder):
input_desc_0 = expand_info['input_desc'][0] input_y, input_dy = self.inputs
input_desc_1 = expand_info['input_desc'][1]
graph_builder = builder.GraphBuilder()
# generate a graph. const_one = graph_builder.value(input_y.dtype, 1)
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_y = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_dy = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
const_one = graph_builder.value(input_y.dtype, ONE)
graph_scope.set_input(input_y, input_dy)
# cal result
double_y = graph_builder.emit('Mul', [input_y, input_y]) double_y = graph_builder.emit('Mul', [input_y, input_y])
one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -14,25 +14,22 @@
# =========================================================================== # ===========================================================================
"""generate json desc for Tile""" """generate json desc for Tile"""
from mindspore._extends.graph_kernel.model import model_builder as builder from mindspore._extends.graph_kernel.model import model_builder as builder
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
def expand_tile(expand_info): @VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('multiples')
class Tile(Expander):
"""Tile expander""" """Tile expander"""
input_desc = expand_info['input_desc'][0]
multiples = expand_info['attr']['multiples']
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
graph_builder = builder.GraphBuilder() def _expand(self, graph_builder):
with graph_builder.graph_scope('main') as graph_scope: input_x = self.inputs[0]
# create tensor input. multiples = self.attrs['multiples']
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# create op. output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
if shape_compatible: if shape_compatible:
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
else: else:
result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
# set graph output. return result
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -55,6 +55,25 @@ class DataFormat:
NDHWC = "NDHWC" NDHWC = "NDHWC"
class DataType:
"""Data Type"""
FLOAT = "float"
FLOAT16 = "float16"
FLOAT32 = "float32"
FLOAT64 = "float64"
INT = "int"
INT8 = "int8"
INT16 = "int16"
INT32 = "int32"
INT64 = "int64"
UINT = "uint"
UINT8 = "uint8"
UINT16 = "uint16"
UINT32 = "uint32"
UINT64 = "uint64"
BOOL = "bool"
class Config: class Config:
R0 = 8.0 R0 = 8.0
UB_SIZE = 256 * 1024 UB_SIZE = 256 * 1024
@ -508,3 +527,9 @@ class AddControlBuddy(GraphVisitor):
for owner in self.buddies: for owner in self.buddies:
for op in self.buddies[owner]: for op in self.buddies[owner]:
owner.add_buddy(op.output) owner.add_buddy(op.output)
class GraphKernelUnsupportedException(Exception):
def __init__(self, message):
super().__init__()
self.message = message

View File

@ -90,7 +90,7 @@ FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
} }
std::string kernel_desc_str = py::cast<std::string>(ret); std::string kernel_desc_str = py::cast<std::string>(ret);
if (kernel_desc_str.empty()) { if (kernel_desc_str.empty()) {
MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope(); MS_LOG(INFO) << "Jump expand node: " << node->fullname_with_scope();
return nullptr; return nullptr;
} }
// decode json to func_graph. // decode json to func_graph.
@ -131,11 +131,8 @@ AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs); kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs); auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node)); SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
std::string graph_kernel_flag; MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) { << " with: " << graph_kernel_node->fullname_with_scope();
static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_"));
});
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag;
return graph_kernel_node; return graph_kernel_node;
} }
@ -152,14 +149,12 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
continue; continue;
} }
MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope(); MS_LOG(INFO) << "Expanding node: " << node->fullname_with_scope();
auto new_func_graph = CreateExpandFuncGraph(node); auto new_func_graph = CreateExpandFuncGraph(node);
if (new_func_graph == nullptr) { if (new_func_graph == nullptr) {
MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope();
continue; continue;
} }
mng->AddFuncGraph(new_func_graph); mng->AddFuncGraph(new_func_graph);
MS_LOG(DEBUG) << "decode fused nodes success.";
auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node); auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node))); new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));