forked from mindspore-Ecosystem/mindspore
!5783 GraphKernel supports GPU
Merge pull request !5783 from DeshiChen/graph_kernel_1.0
This commit is contained in:
commit
7152fe04be
2
akg
2
akg
|
@ -1 +1 @@
|
||||||
Subproject commit 3bb6264188d0b1d6ff776a35a571bc7190df0800
|
Subproject commit d237aa7d8e9d3fb709bda9f30205b02129bc2b59
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""init"""
|
||||||
|
from .splitter import split_with_json
|
||||||
|
from .expander import get_op_expander
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""generate json desc for graph kernel ops"""
|
||||||
|
import json
|
||||||
|
import json.decoder as jd
|
||||||
|
import traceback
|
||||||
|
from mindspore import log as logger
|
||||||
|
import mindspore._extends.graph_kernel.expanders as expanders
|
||||||
|
|
||||||
|
|
||||||
|
def get_op_expander(json_str: str):
|
||||||
|
"""get op expander by json info"""
|
||||||
|
try:
|
||||||
|
kernel_info = json.loads(json_str)
|
||||||
|
expand_info = kernel_info['expand_info']
|
||||||
|
|
||||||
|
if 'name' not in expand_info:
|
||||||
|
logger.error("expand info have no op name")
|
||||||
|
return None
|
||||||
|
if 'process' not in expand_info:
|
||||||
|
logger.error("expand info have no processor info")
|
||||||
|
return None
|
||||||
|
|
||||||
|
processor = expand_info['process']
|
||||||
|
op_name = str(expand_info['name']).lower()
|
||||||
|
expand_op_func_name = 'expand_' + op_name
|
||||||
|
if not hasattr(expanders, expand_op_func_name):
|
||||||
|
logger.error("Generator do not support op: {}".format(op_name))
|
||||||
|
return None
|
||||||
|
expand_op_func = getattr(expanders, expand_op_func_name)
|
||||||
|
# generate graph desc.
|
||||||
|
graph = expand_op_func(expand_info)
|
||||||
|
if graph is None:
|
||||||
|
logger.error("Failed to generate graph of: {}".format(op_name))
|
||||||
|
return None
|
||||||
|
|
||||||
|
graph.set_processor(processor)
|
||||||
|
|
||||||
|
# dump graph to json desc.
|
||||||
|
desc = graph.dump()
|
||||||
|
return json.dumps(desc)
|
||||||
|
|
||||||
|
except jd.JSONDecodeError:
|
||||||
|
logger.error("Failed to generate graph kernel op")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return None
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""expanders init"""
|
||||||
|
|
||||||
|
from .gelu import expand_gelu
|
||||||
|
from .layernorm import expand_layernorm
|
||||||
|
from .softmax import expand_softmax
|
||||||
|
from .square import expand_square
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for gelu"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
CSVALUE = 0.044715
|
||||||
|
CSVALUE_A = 1.5957691 # 2*np.sqrt(2/np.pi)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_gelu(expand_info):
|
||||||
|
"""Gelu expander"""
|
||||||
|
|
||||||
|
# get op info.
|
||||||
|
input_desc = expand_info['input_desc'][0]
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
|
||||||
|
# generate a graph.
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create tensor input.
|
||||||
|
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||||
|
dtype = input_x.dtype
|
||||||
|
if dtype == 'float16':
|
||||||
|
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
|
||||||
|
|
||||||
|
# cal tanh.
|
||||||
|
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
|
||||||
|
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
|
||||||
|
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
|
||||||
|
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
|
||||||
|
tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
|
||||||
|
|
||||||
|
const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
|
||||||
|
mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
|
||||||
|
|
||||||
|
const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
|
||||||
|
mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
|
||||||
|
right_mul = graph_builder.emit('Exp', [mul_0_min])
|
||||||
|
|
||||||
|
mul_0_abs = graph_builder.emit('Abs', [mul_0])
|
||||||
|
const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
|
||||||
|
mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
|
||||||
|
|
||||||
|
mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
|
||||||
|
|
||||||
|
const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
|
||||||
|
mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
|
||||||
|
left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
|
||||||
|
|
||||||
|
result = graph_builder.emit('Mul', [left_mul, right_mul])
|
||||||
|
if dtype == 'float16':
|
||||||
|
result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(result)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
|
@ -0,0 +1,87 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for LayerNorm"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
|
||||||
|
def expand_layernorm(expand_info):
|
||||||
|
"""LayerNorm expander"""
|
||||||
|
|
||||||
|
# get op info.
|
||||||
|
input_desc_0 = expand_info['input_desc'][0]
|
||||||
|
input_desc_1 = expand_info['input_desc'][1]
|
||||||
|
input_desc_2 = expand_info['input_desc'][2]
|
||||||
|
attrs = expand_info['attr']
|
||||||
|
begin_norm_axis = None
|
||||||
|
epsilon = None
|
||||||
|
for item in attrs:
|
||||||
|
if 'begin_norm_axis' in item:
|
||||||
|
begin_norm_axis = item['begin_norm_axis']
|
||||||
|
if 'epsilon' in item:
|
||||||
|
epsilon = item['epsilon']
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
|
||||||
|
# generate a graph.
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create tensor input.
|
||||||
|
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||||
|
input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||||
|
input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||||
|
|
||||||
|
# Calculate the scaling ratio of the average
|
||||||
|
shape_x = input_desc_0['shape']
|
||||||
|
if begin_norm_axis < 0:
|
||||||
|
begin_norm_axis += len(shape_x)
|
||||||
|
reduce_axis = ()
|
||||||
|
for i, _ in enumerate(shape_x):
|
||||||
|
if i > begin_norm_axis or i == begin_norm_axis:
|
||||||
|
reduce_axis = reduce_axis + (i,)
|
||||||
|
|
||||||
|
reduce_elts = 1.0
|
||||||
|
for i in reduce_axis:
|
||||||
|
reduce_elts *= shape_x[i]
|
||||||
|
mean_cof = 1.0 / reduce_elts
|
||||||
|
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format)
|
||||||
|
|
||||||
|
# Calculate mean
|
||||||
|
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
|
||||||
|
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
|
||||||
|
|
||||||
|
# Calculate variance
|
||||||
|
variance_sub = graph_builder.emit('Sub', [input_x, mean])
|
||||||
|
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub])
|
||||||
|
variance_red = graph_builder.emit('ReduceSum', [variance_mul],
|
||||||
|
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
|
||||||
|
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
|
||||||
|
|
||||||
|
# Calculate normalize
|
||||||
|
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
|
||||||
|
epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
|
||||||
|
normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
|
||||||
|
normalize_log = graph_builder.emit('Log', [normalize_add])
|
||||||
|
input_y = graph_builder.value(input_x.dtype, -0.5, input_x.data_format)
|
||||||
|
normalize_log_mul = graph_builder.emit('Mul', [normalize_log, input_y])
|
||||||
|
normalize_exp = graph_builder.emit('Exp', [normalize_log_mul])
|
||||||
|
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_exp])
|
||||||
|
|
||||||
|
# Calculate scale and translate
|
||||||
|
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
|
||||||
|
res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
|
||||||
|
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(res, mean, variance)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for softmax"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
|
||||||
|
def expand_softmax(expand_info):
|
||||||
|
"""Softmax expander"""
|
||||||
|
|
||||||
|
# get op info.
|
||||||
|
input_desc = expand_info['input_desc'][0]
|
||||||
|
attrs = expand_info['attr']
|
||||||
|
axis = None
|
||||||
|
for item in attrs:
|
||||||
|
if 'axis' in item:
|
||||||
|
axis = item['axis']
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
|
||||||
|
# generate a graph.
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create tensor input.
|
||||||
|
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||||
|
# cal softmax.
|
||||||
|
|
||||||
|
if input_x.dtype == 'float32':
|
||||||
|
input_x_cast = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
|
||||||
|
max_x = graph_builder.emit('ReduceMax', [input_x_cast], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||||
|
max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': 'float32'})
|
||||||
|
else:
|
||||||
|
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||||
|
data_sub = graph_builder.emit('Sub', [input_x, max_x])
|
||||||
|
data_exp = graph_builder.emit('Exp', [data_sub])
|
||||||
|
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
|
||||||
|
result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(result)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for square"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
|
||||||
|
def expand_square(expand_info):
|
||||||
|
"""Square expander"""
|
||||||
|
|
||||||
|
# get op info.
|
||||||
|
input_desc = expand_info['input_desc'][0]
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
|
||||||
|
# generate a graph.
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create tensor input.
|
||||||
|
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
|
||||||
|
# create op.
|
||||||
|
result = graph_builder.emit('Mul', [input_x, input_x])
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(result)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""GraphKernel cost model init"""
|
||||||
|
|
||||||
|
from .graph_split import split
|
||||||
|
from .model_builder import GraphBuilder, load_composite
|
|
@ -0,0 +1,153 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""Cost model splitter"""
|
||||||
|
|
||||||
|
from .model import PrimLib, Graph
|
||||||
|
|
||||||
|
|
||||||
|
class GraphSplitByPattern:
|
||||||
|
"""Graph split by pattern"""
|
||||||
|
|
||||||
|
def __init__(self, graph):
|
||||||
|
self.graph = graph
|
||||||
|
self.groups = []
|
||||||
|
self.op_group = {}
|
||||||
|
for op in self.graph.ops:
|
||||||
|
g = [op]
|
||||||
|
self.groups.append(g)
|
||||||
|
self.op_group[op] = g
|
||||||
|
self.ids = {}
|
||||||
|
for i, op in enumerate(graph.ops):
|
||||||
|
self.ids[op] = i
|
||||||
|
self.doms = self.post_dom(graph.ops)
|
||||||
|
_, outputs = graph.deduce_parameters()
|
||||||
|
self.outputs = set(outputs)
|
||||||
|
|
||||||
|
def post_dom(self, ops):
|
||||||
|
"""Post dom"""
|
||||||
|
doms, i_doms = {}, {}
|
||||||
|
for i in range(len(ops) - 1, -1, -1):
|
||||||
|
op = ops[i]
|
||||||
|
doms[op] = {op}
|
||||||
|
i_dom = None
|
||||||
|
if op.output.to_ops:
|
||||||
|
suc_dom = set(doms[op.output.to_ops[0]])
|
||||||
|
for to in op.output.to_ops[1:]:
|
||||||
|
suc_dom.intersection_update(doms[to])
|
||||||
|
doms[op].update(suc_dom)
|
||||||
|
for dom in suc_dom:
|
||||||
|
if i_dom is None or self.ids[dom] < self.ids[i_dom]:
|
||||||
|
i_dom = dom
|
||||||
|
i_doms[op] = i_dom
|
||||||
|
return i_doms
|
||||||
|
|
||||||
|
def get_pattern(self, op, i):
|
||||||
|
"""Get pattern"""
|
||||||
|
pattern = PrimLib.UNKNOWN
|
||||||
|
_, elem_relation = PrimLib.input_relation(op, i)
|
||||||
|
for pat in elem_relation:
|
||||||
|
if pat and pat > pattern:
|
||||||
|
pattern = pat
|
||||||
|
return pattern
|
||||||
|
|
||||||
|
def fuse(self, check_fun):
|
||||||
|
"""Fuse ops"""
|
||||||
|
def _get_path(op, dom):
|
||||||
|
path_ops, visited = [], set()
|
||||||
|
|
||||||
|
def _get_path_depth(p):
|
||||||
|
visited.add(p)
|
||||||
|
if self.op_group[p][0] == p:
|
||||||
|
path_ops.append(p)
|
||||||
|
for to in p.output.to_ops:
|
||||||
|
if to != dom and to not in visited:
|
||||||
|
_get_path_depth(to)
|
||||||
|
_get_path_depth(op)
|
||||||
|
return path_ops
|
||||||
|
changed = True
|
||||||
|
while changed:
|
||||||
|
for group in self.groups:
|
||||||
|
op = group[0]
|
||||||
|
dom = self.doms[op]
|
||||||
|
if dom is None or op.output in self.outputs:
|
||||||
|
continue
|
||||||
|
ops = _get_path(op, dom)
|
||||||
|
if check_fun(op, dom, ops):
|
||||||
|
dom_group = self.op_group[dom]
|
||||||
|
fused = []
|
||||||
|
for fop in ops:
|
||||||
|
f_group = self.op_group[fop]
|
||||||
|
for p in f_group:
|
||||||
|
self.op_group[p] = dom_group
|
||||||
|
fused.append(f_group)
|
||||||
|
dom_group += f_group
|
||||||
|
for g in fused:
|
||||||
|
self.groups.remove(g)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
def to_subgraphs(self):
|
||||||
|
"""Transform op groups to subgraphs"""
|
||||||
|
subgraphs = []
|
||||||
|
for i, group in enumerate(self.groups):
|
||||||
|
group.sort(key=lambda op: self.ids[op])
|
||||||
|
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
|
||||||
|
return subgraphs
|
||||||
|
|
||||||
|
def split(self):
|
||||||
|
"""Split graph"""
|
||||||
|
def _buddy(op, dom, path_ops):
|
||||||
|
"""Fuse buddy together"""
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
group = self.op_group[op]
|
||||||
|
for p in group:
|
||||||
|
# p is buddy
|
||||||
|
if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
|
||||||
|
return True
|
||||||
|
# p's output is buddy
|
||||||
|
for to in p.output.to_ops:
|
||||||
|
if to.output.buddy is not None and to not in group:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _injective(pattern, limit):
|
||||||
|
def _checker(op, dom, path_ops):
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
for p in op.output.to_ops:
|
||||||
|
if p not in self.op_group[dom]:
|
||||||
|
return False
|
||||||
|
if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
|
||||||
|
for i, t in enumerate(dom.inputs):
|
||||||
|
if t == op.output:
|
||||||
|
return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
|
||||||
|
return False
|
||||||
|
return _checker
|
||||||
|
|
||||||
|
def _diamond(op, dom, path_ops):
|
||||||
|
if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
|
||||||
|
PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
|
||||||
|
return False
|
||||||
|
return len(path_ops) == 1 and op.output not in dom.inputs
|
||||||
|
self.fuse(_buddy)
|
||||||
|
self.fuse(_injective(PrimLib.ELEMWISE, 100))
|
||||||
|
self.fuse(_injective(PrimLib.BROADCAST, 6))
|
||||||
|
self.fuse(_injective(PrimLib.REDUCE, 6))
|
||||||
|
self.fuse(_diamond)
|
||||||
|
return self.to_subgraphs()
|
||||||
|
|
||||||
|
|
||||||
|
def split(graph):
|
||||||
|
return GraphSplitByPattern(graph).split()
|
|
@ -0,0 +1,473 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""GraphKernel cost model"""
|
||||||
|
|
||||||
|
|
||||||
|
class Utils:
|
||||||
|
"""Model utils"""
|
||||||
|
@staticmethod
|
||||||
|
def get_attr_type(attr):
|
||||||
|
"""Get attr type"""
|
||||||
|
if isinstance(attr, bool):
|
||||||
|
return 'bool'
|
||||||
|
if isinstance(attr, str):
|
||||||
|
return 'str'
|
||||||
|
if isinstance(attr, int):
|
||||||
|
return 'int'
|
||||||
|
if isinstance(attr, float):
|
||||||
|
return 'bool'
|
||||||
|
if isinstance(attr, (list, tuple)):
|
||||||
|
if not attr:
|
||||||
|
raise ValueError("Length of attr is 0")
|
||||||
|
if isinstance(attr[0], int):
|
||||||
|
return 'listInt'
|
||||||
|
if isinstance(attr[0], str):
|
||||||
|
return 'listStr'
|
||||||
|
raise ValueError("Unknown type of attr: {}".format(attr))
|
||||||
|
|
||||||
|
|
||||||
|
class DataFormat:
|
||||||
|
"""DataFormat"""
|
||||||
|
DEFAULT = "DefaultFormat"
|
||||||
|
NC1KHKWHWC0 = "NC1KHKWHWC0"
|
||||||
|
ND = "ND"
|
||||||
|
NCHW = "NCHW"
|
||||||
|
NHWC = "NHWC"
|
||||||
|
HWCN = "HWCN"
|
||||||
|
NC1HWC0 = "NC1HWC0"
|
||||||
|
FRAC_Z = "FracZ"
|
||||||
|
FRAC_NZ = "FRACTAL_NZ"
|
||||||
|
C1HWNCOC0 = "C1HWNCoC0"
|
||||||
|
NC1HWC0_C04 = "NC1HWC0_C04"
|
||||||
|
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
|
||||||
|
NDHWC = "NDHWC"
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
R0 = 8.0
|
||||||
|
UB_SIZE = 256 * 1024
|
||||||
|
MAX_BLOCK = 32
|
||||||
|
|
||||||
|
|
||||||
|
class PrimLib:
|
||||||
|
"""Prim lib"""
|
||||||
|
|
||||||
|
UNKNOWN = 0
|
||||||
|
ELEMWISE = 1
|
||||||
|
BROADCAST = 2
|
||||||
|
REDUCE = 3
|
||||||
|
TRANSFORM = 4
|
||||||
|
CONTROL = 5
|
||||||
|
|
||||||
|
class Prim:
|
||||||
|
"""Prim"""
|
||||||
|
|
||||||
|
def __init__(self, iter_type, calibrate=1, relation_func=None):
|
||||||
|
self.iter_type = iter_type
|
||||||
|
self.calibrate = calibrate
|
||||||
|
self.relation_func = relation_func
|
||||||
|
if relation_func is None:
|
||||||
|
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
|
||||||
|
|
||||||
|
def default_elemwise_broadcast_relation(self, op, input_idx):
|
||||||
|
"""Process elemwise and broadcast relation"""
|
||||||
|
out_shape = op.output.shape
|
||||||
|
in_shape = op.inputs[input_idx].shape
|
||||||
|
assert len(out_shape) >= len(in_shape)
|
||||||
|
axis_relation, elem_relation = [], []
|
||||||
|
delta = len(out_shape) - len(in_shape)
|
||||||
|
if delta > 0:
|
||||||
|
for i in range(0, delta):
|
||||||
|
axis_relation.append(None)
|
||||||
|
elem_relation.append(None)
|
||||||
|
for i, _ in enumerate(in_shape):
|
||||||
|
axis_relation.append(i)
|
||||||
|
elem_relation.append(
|
||||||
|
PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
|
||||||
|
return axis_relation, elem_relation
|
||||||
|
|
||||||
|
def default_reduce_relation(self, op, input_idx):
|
||||||
|
"""Process reduce relation"""
|
||||||
|
axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
|
||||||
|
for i in op.attrs['reduce_axis']:
|
||||||
|
elem_relation[i] = PrimLib.REDUCE
|
||||||
|
return axis_relation, elem_relation
|
||||||
|
|
||||||
|
def unknown_relation(self, op, input_idx):
|
||||||
|
"""Process unknown relation"""
|
||||||
|
out_shape = op.output.shape
|
||||||
|
in_shape = op.inputs[input_idx].shape
|
||||||
|
all_relation = list(range(len(in_shape)))
|
||||||
|
axis_relation = [all_relation for i in range(0, len(out_shape))]
|
||||||
|
elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
|
||||||
|
return axis_relation, elem_relation
|
||||||
|
|
||||||
|
default_relation_func = [
|
||||||
|
unknown_relation,
|
||||||
|
default_elemwise_broadcast_relation,
|
||||||
|
default_elemwise_broadcast_relation,
|
||||||
|
default_reduce_relation,
|
||||||
|
unknown_relation,
|
||||||
|
unknown_relation,
|
||||||
|
]
|
||||||
|
|
||||||
|
primtives = {
|
||||||
|
'TensorAdd': Prim(ELEMWISE),
|
||||||
|
'Abs': Prim(ELEMWISE),
|
||||||
|
'Neg': Prim(ELEMWISE),
|
||||||
|
'Mul': Prim(ELEMWISE),
|
||||||
|
'Sub': Prim(ELEMWISE),
|
||||||
|
'Log': Prim(ELEMWISE),
|
||||||
|
'Exp': Prim(ELEMWISE),
|
||||||
|
'Rsqrt': Prim(ELEMWISE),
|
||||||
|
'Sqrt': Prim(ELEMWISE),
|
||||||
|
'RealDiv': Prim(ELEMWISE),
|
||||||
|
'Cast': Prim(ELEMWISE),
|
||||||
|
'Pow': Prim(ELEMWISE),
|
||||||
|
'Minimum': Prim(ELEMWISE),
|
||||||
|
'Maximum': Prim(ELEMWISE),
|
||||||
|
'Reciprocal': Prim(ELEMWISE),
|
||||||
|
'Equal': Prim(ELEMWISE),
|
||||||
|
'Greater': Prim(ELEMWISE),
|
||||||
|
'GreaterEqual': Prim(ELEMWISE),
|
||||||
|
'Less': Prim(ELEMWISE),
|
||||||
|
'LessEqual': Prim(ELEMWISE),
|
||||||
|
'Square': Prim(ELEMWISE),
|
||||||
|
'AddN': Prim(ELEMWISE),
|
||||||
|
'Select': Prim(ELEMWISE, 8),
|
||||||
|
'ReduceSum': Prim(REDUCE),
|
||||||
|
'ReduceMax': Prim(REDUCE),
|
||||||
|
'ReduceMin': Prim(REDUCE),
|
||||||
|
'make_tuple': Prim(CONTROL),
|
||||||
|
'ControlDepend': Prim(CONTROL),
|
||||||
|
'@ReduceInit': Prim(ELEMWISE),
|
||||||
|
}
|
||||||
|
|
||||||
|
default_primtive = Prim(UNKNOWN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_prim(cls, op):
|
||||||
|
prim = cls.primtives.get(op.prim, None)
|
||||||
|
if prim is None:
|
||||||
|
print('[WARN] primtive is not registered: ' + op.prim)
|
||||||
|
prim = cls.default_primtive
|
||||||
|
return prim
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def input_relation(cls, op, input_idx):
|
||||||
|
return cls.get_prim(op).relation_func(op, input_idx)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def iter_type(cls, op):
|
||||||
|
return cls.get_prim(op).iter_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_reduce(cls, op):
|
||||||
|
return cls.get_prim(op).iter_type == cls.REDUCE
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def calibrate_iter_size(cls, op, iter_size):
|
||||||
|
return cls.get_prim(op).calibrate * iter_size
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dtype_bytes(cls, dtype):
|
||||||
|
bits, unit = 1, 1
|
||||||
|
for i in range(len(dtype) - 1, 0, -1):
|
||||||
|
if dtype[i].isdecimal():
|
||||||
|
bits += int(dtype[i]) * unit
|
||||||
|
unit *= 10
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return bits // 8
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def inplace_reuse(cls, op, input_idx, start_axis=0):
|
||||||
|
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
|
||||||
|
return False
|
||||||
|
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
|
||||||
|
for i in range(start_axis, len(elem_relation)):
|
||||||
|
if elem_relation[i] != cls.ELEMWISE:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class Tensor:
|
||||||
|
"""Tensor"""
|
||||||
|
|
||||||
|
PARA_NONE = 0
|
||||||
|
PARA_INPUT = 1
|
||||||
|
PARA_OUTPUT = 2
|
||||||
|
|
||||||
|
class Buddy:
|
||||||
|
def __init__(self, leader):
|
||||||
|
self.members = [leader]
|
||||||
|
|
||||||
|
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
|
||||||
|
self.name = name
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
self.data_format = data_format
|
||||||
|
self.para_type = para_type
|
||||||
|
self.op = None
|
||||||
|
self.to_ops = []
|
||||||
|
self.buddy = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name + str(list(self.shape))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
|
||||||
|
|
||||||
|
def get_size(self):
|
||||||
|
"""Get size"""
|
||||||
|
size = PrimLib.dtype_bytes(self.dtype)
|
||||||
|
for i in self.shape:
|
||||||
|
size *= i
|
||||||
|
return size
|
||||||
|
|
||||||
|
def add_buddy(self, tensor):
|
||||||
|
"""Add buddy"""
|
||||||
|
if self.buddy is None:
|
||||||
|
self.buddy = self.Buddy(self)
|
||||||
|
self.buddy.members.append(tensor)
|
||||||
|
tensor.buddy = self.buddy
|
||||||
|
|
||||||
|
|
||||||
|
class Value:
|
||||||
|
"""Value"""
|
||||||
|
|
||||||
|
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
|
||||||
|
self.name = name
|
||||||
|
self.shape = [1]
|
||||||
|
self.dtype = dtype
|
||||||
|
self.value = value
|
||||||
|
self.data_format = data_format
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name + str(list(self.shape)) + str(self.value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s.%s%s%s" % (self.name, self.dtype, str(list(self.shape)), str(self.value))
|
||||||
|
|
||||||
|
def get_size(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
class Operator:
|
||||||
|
"""Operator"""
|
||||||
|
|
||||||
|
def __init__(self, primtive, inputs, output, attrs):
|
||||||
|
self.prim = primtive
|
||||||
|
self.inputs = inputs
|
||||||
|
self.output = output
|
||||||
|
self.attrs = attrs
|
||||||
|
for t in inputs:
|
||||||
|
t.to_ops.append(self)
|
||||||
|
if output.op is None:
|
||||||
|
output.op = self
|
||||||
|
self.all_inputs = [] # include Tensor inputs and Value inputs.
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
args = ', '.join([str(t) for t in self.all_inputs])
|
||||||
|
expr = "%s = %s.%s(%s)" % (
|
||||||
|
str(self.output), self.prim, self.output.dtype, args)
|
||||||
|
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
class Graph:
|
||||||
|
"""Graph"""
|
||||||
|
|
||||||
|
def __init__(self, name, ops):
|
||||||
|
self.name = name
|
||||||
|
self.ops = ops # in topo order, can not use set
|
||||||
|
self.outputs = []
|
||||||
|
|
||||||
|
def set_processor(self, processor):
|
||||||
|
"""Set processor"""
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
def add(self, ops):
|
||||||
|
"""Add ops"""
|
||||||
|
if isinstance(ops, Operator):
|
||||||
|
self.ops.append(ops)
|
||||||
|
else:
|
||||||
|
self.ops.extend(ops)
|
||||||
|
|
||||||
|
def extract_subgraph(self, graph_name, tensor_names, difference=False):
|
||||||
|
"""Extract subgraph from this graph"""
|
||||||
|
graph = Graph(graph_name, [])
|
||||||
|
outputs = set(tensor_names)
|
||||||
|
if difference:
|
||||||
|
for op in self.ops:
|
||||||
|
if op.output.name not in outputs:
|
||||||
|
graph.add(op)
|
||||||
|
else:
|
||||||
|
for op in self.ops:
|
||||||
|
if op.output.name in outputs:
|
||||||
|
graph.add(op)
|
||||||
|
outputs.remove(op.output.name)
|
||||||
|
for name in outputs:
|
||||||
|
raise ValueError("invalid input tensor : " + name)
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def deduce_parameters(self):
|
||||||
|
"""Deduce parameters"""
|
||||||
|
inputs, outputs = [], []
|
||||||
|
for op in self.ops:
|
||||||
|
for t in op.inputs:
|
||||||
|
if t not in inputs and t.op not in self.ops:
|
||||||
|
inputs.append(t)
|
||||||
|
if op.output not in outputs:
|
||||||
|
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
|
||||||
|
outputs.append(op.output)
|
||||||
|
else:
|
||||||
|
for d in op.output.to_ops:
|
||||||
|
if d not in self.ops:
|
||||||
|
outputs.append(op.output)
|
||||||
|
break
|
||||||
|
if self.outputs:
|
||||||
|
outputs = self.outputs
|
||||||
|
return inputs, outputs
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
inputs, outputs = self.deduce_parameters()
|
||||||
|
para_str = ', '.join([repr(t) for t in inputs])
|
||||||
|
out_str = ', '.join([repr(t) for t in outputs])
|
||||||
|
lines = []
|
||||||
|
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
|
||||||
|
for op in self.ops:
|
||||||
|
lines.append(' ' + str(op))
|
||||||
|
lines.append('}')
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
def dump(self):
|
||||||
|
"""Dump Graph to json"""
|
||||||
|
attr_name = {'reduce_axis': 'axis'}
|
||||||
|
inputs, outputs = self.deduce_parameters()
|
||||||
|
input_desc, output_desc, op_desc = [], [], []
|
||||||
|
for t in inputs:
|
||||||
|
input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
|
||||||
|
'tensor_name': t.name, 'format': t.data_format}])
|
||||||
|
for t in outputs:
|
||||||
|
output_desc.append({'data_type': t.dtype, 'shape': t.shape,
|
||||||
|
'tensor_name': t.name, 'format': t.data_format})
|
||||||
|
for op in self.ops:
|
||||||
|
attrs, in_desc = [], []
|
||||||
|
for a in op.attrs:
|
||||||
|
name = attr_name.get(a, a)
|
||||||
|
attrs.append(
|
||||||
|
{'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
|
||||||
|
for t in op.all_inputs:
|
||||||
|
if isinstance(t, Tensor):
|
||||||
|
in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
|
||||||
|
'tensor_name': t.name, 'format': t.data_format}])
|
||||||
|
else:
|
||||||
|
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
|
||||||
|
'tensor_name': t.name, 'format': t.data_format}])
|
||||||
|
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
|
||||||
|
'tensor_name': op.output.name, 'format': t.data_format}]
|
||||||
|
op_desc.append({'attr': attrs, 'impl_path': '',
|
||||||
|
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
|
||||||
|
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
|
||||||
|
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
|
||||||
|
'platform': 'AKG', 'process': self.processor}
|
||||||
|
return graph_desc
|
||||||
|
|
||||||
|
|
||||||
|
class GraphVisitor:
|
||||||
|
"""Graph visitor"""
|
||||||
|
|
||||||
|
def __init__(self, forward=True, once_mode=True):
|
||||||
|
self.forward = forward
|
||||||
|
self.once_mode = once_mode
|
||||||
|
if self.once_mode:
|
||||||
|
self.visited = set()
|
||||||
|
|
||||||
|
def visit_graph(self, graph):
|
||||||
|
"""Visit graph"""
|
||||||
|
inputs, outputs = graph.deduce_parameters()
|
||||||
|
if self.forward:
|
||||||
|
for tensor in inputs:
|
||||||
|
for op in tensor.to_ops:
|
||||||
|
self.visit(op)
|
||||||
|
else:
|
||||||
|
for tensor in outputs:
|
||||||
|
if not tensor.to_ops:
|
||||||
|
self.visit(tensor.op)
|
||||||
|
|
||||||
|
def visit(self, op):
|
||||||
|
"""Visit op"""
|
||||||
|
next_ops = op.output.to_ops if self.forward else [
|
||||||
|
t.op for t in op.inputs if t.op is not None]
|
||||||
|
if self.once_mode:
|
||||||
|
self.visited.add(op)
|
||||||
|
for n in next_ops:
|
||||||
|
if n not in self.visited:
|
||||||
|
self.visit(n)
|
||||||
|
else:
|
||||||
|
for n in next_ops:
|
||||||
|
self.visit(n)
|
||||||
|
|
||||||
|
|
||||||
|
class AlignShape(GraphVisitor):
|
||||||
|
"""Align shape"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(once_mode=False)
|
||||||
|
|
||||||
|
def visit(self, op):
|
||||||
|
prim = PrimLib.get_prim(op)
|
||||||
|
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
|
||||||
|
out_dim = len(op.output.shape)
|
||||||
|
align_dim = out_dim
|
||||||
|
for t in op.inputs:
|
||||||
|
if len(t.shape) > align_dim:
|
||||||
|
align_dim = len(t.shape)
|
||||||
|
if align_dim > out_dim:
|
||||||
|
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
|
||||||
|
super().visit(op)
|
||||||
|
|
||||||
|
|
||||||
|
class AddControlBuddy(GraphVisitor):
|
||||||
|
"""Add control buddy"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.buddies = {} # {op : [ctrl_op]}
|
||||||
|
|
||||||
|
def visit(self, op):
|
||||||
|
if PrimLib.iter_type(op) == PrimLib.CONTROL:
|
||||||
|
assert len(op.output.to_ops) == 1
|
||||||
|
owner = op.output.to_ops[0]
|
||||||
|
if owner in self.buddies:
|
||||||
|
self.buddies[owner].append(op)
|
||||||
|
else:
|
||||||
|
self.buddies[owner] = [op]
|
||||||
|
if op in self.buddies:
|
||||||
|
ops = self.buddies.pop(op)
|
||||||
|
self.buddies[owner].extend(ops)
|
||||||
|
super().visit(op)
|
||||||
|
|
||||||
|
def visit_graph(self, graph):
|
||||||
|
super().visit_graph(graph)
|
||||||
|
for owner in self.buddies:
|
||||||
|
for op in self.buddies[owner]:
|
||||||
|
owner.add_buddy(op.output)
|
|
@ -0,0 +1,292 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""GraphKernel model builder"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
|
||||||
|
|
||||||
|
|
||||||
|
class OpInfer:
|
||||||
|
"""Op infer"""
|
||||||
|
@staticmethod
|
||||||
|
def default_reduce_infer(inputs, attrs):
|
||||||
|
shape = copy.deepcopy(inputs[0].shape)
|
||||||
|
for i in attrs['reduce_axis']:
|
||||||
|
shape[i] = 1
|
||||||
|
return shape
|
||||||
|
|
||||||
|
default_infer_shape_func = [
|
||||||
|
None,
|
||||||
|
lambda inputs, attrs: max([t.shape for t in inputs]),
|
||||||
|
lambda inputs, attrs: max([t.shape for t in inputs]),
|
||||||
|
default_reduce_infer.__func__,
|
||||||
|
None,
|
||||||
|
lambda inputs, attrs: [1], # control op
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_infer_dtype_func(inputs, attrs):
|
||||||
|
"""Infer dtype"""
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
return inputs[0].dtype
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_infer_format_func(inputs, attrs):
|
||||||
|
"""Infer format"""
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
return inputs[0].data_format
|
||||||
|
|
||||||
|
infer_shape_func = {
|
||||||
|
# add special infer func here
|
||||||
|
}
|
||||||
|
infer_dtype_func = {
|
||||||
|
# add special infer func here
|
||||||
|
'Cast': lambda inputs, attrs: attrs['dst_type'],
|
||||||
|
}
|
||||||
|
infer_format_func = {
|
||||||
|
# add special infer func here
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def infer(cls, prim_name, inputs, attrs):
|
||||||
|
prim = PrimLib.primtives[prim_name]
|
||||||
|
infer_shape = cls.infer_shape_func.get(
|
||||||
|
prim_name, cls.default_infer_shape_func[prim.iter_type])
|
||||||
|
infer_dtype = cls.infer_dtype_func.get(
|
||||||
|
prim_name, cls.default_infer_dtype_func)
|
||||||
|
infer_format = cls.infer_format_func.get(
|
||||||
|
prim_name, cls.default_infer_format_func)
|
||||||
|
return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphBuilder:
|
||||||
|
"""Graph builder"""
|
||||||
|
class GraphWrapper:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.graph = Graph(name, [])
|
||||||
|
|
||||||
|
def set_output(self, *para):
|
||||||
|
for t in para:
|
||||||
|
t.para_type = Tensor.PARA_OUTPUT
|
||||||
|
self.graph.outputs.append(t)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.graphs = []
|
||||||
|
self.current = None
|
||||||
|
self.name_id = 0
|
||||||
|
|
||||||
|
def _alloc_tensor_name(self):
|
||||||
|
tid = self.name_id
|
||||||
|
self.name_id += 1
|
||||||
|
return "t%d" % (tid)
|
||||||
|
|
||||||
|
def graph_scope(self, name):
|
||||||
|
"""The graph scope to be processed"""
|
||||||
|
class GraphScope:
|
||||||
|
def __init__(self, gb):
|
||||||
|
self.gb = gb
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.gb.current
|
||||||
|
|
||||||
|
def __exit__(self, ptype, value, trace):
|
||||||
|
self.gb.graphs.append(self.gb.current.graph)
|
||||||
|
self.gb.current = None
|
||||||
|
|
||||||
|
assert self.current is None
|
||||||
|
self.current = self.GraphWrapper(name)
|
||||||
|
return GraphScope(self)
|
||||||
|
|
||||||
|
def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
|
||||||
|
"""Create a new Tensor"""
|
||||||
|
if name in (None, ''):
|
||||||
|
name = self._alloc_tensor_name()
|
||||||
|
if not shape:
|
||||||
|
shape = [1]
|
||||||
|
return Tensor(name, shape, dtype, data_format, para_type=para_type)
|
||||||
|
|
||||||
|
def value(self, dtype, value, data_format, name=None):
|
||||||
|
"""Create a new Value"""
|
||||||
|
if name in (None, ''):
|
||||||
|
name = self._alloc_tensor_name()
|
||||||
|
return Value(name, dtype, value, data_format)
|
||||||
|
|
||||||
|
def op(self, prim, output, inputs, attrs=None):
|
||||||
|
"""Insert an operator into graph"""
|
||||||
|
if attrs is None:
|
||||||
|
attrs = {}
|
||||||
|
if isinstance(inputs, Tensor):
|
||||||
|
inputs = [inputs]
|
||||||
|
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
|
||||||
|
node = Operator(prim, tensor_inputs, output, attrs)
|
||||||
|
node.all_inputs = inputs
|
||||||
|
self.current.graph.add(node)
|
||||||
|
|
||||||
|
def emit(self, prim, inputs, name=None, attrs=None):
|
||||||
|
"""Emit a new operation"""
|
||||||
|
if attrs is None:
|
||||||
|
attrs = {}
|
||||||
|
if isinstance(inputs, Tensor):
|
||||||
|
inputs = [inputs]
|
||||||
|
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
|
||||||
|
out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
|
||||||
|
output = self.tensor(out_shape, out_dtype, out_format, name)
|
||||||
|
self.op(prim, output, inputs, attrs)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.graphs
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeGraph:
|
||||||
|
"""Composite Graph"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.graph = None
|
||||||
|
self.desc = None
|
||||||
|
self.tensors = {} # name : Tensor
|
||||||
|
|
||||||
|
def refine(self):
|
||||||
|
"""Refine Graph"""
|
||||||
|
AlignShape().visit_graph(self.graph)
|
||||||
|
AddControlBuddy().visit_graph(self.graph)
|
||||||
|
|
||||||
|
def load(self, desc):
|
||||||
|
"""Load Graph from json"""
|
||||||
|
def _attr_of(op, inputs, output):
|
||||||
|
attr = {}
|
||||||
|
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
|
||||||
|
return attr
|
||||||
|
for a in op['attr']:
|
||||||
|
if a['name'] == 'axis':
|
||||||
|
red_axis, dim_size = [], len(inputs[0].shape)
|
||||||
|
if not a['value']:
|
||||||
|
assert len(output.shape) == len(inputs[0].shape)
|
||||||
|
for i in range(len(output.shape)):
|
||||||
|
if output.shape[i] == 1 and inputs[0].shape[i] > 1:
|
||||||
|
red_axis.append(i)
|
||||||
|
else:
|
||||||
|
for i in a['value']:
|
||||||
|
red_axis.append(i if i >= 0 else dim_size + i)
|
||||||
|
attr['reduce_axis'] = red_axis
|
||||||
|
break
|
||||||
|
return attr
|
||||||
|
|
||||||
|
builder = GraphBuilder()
|
||||||
|
with builder.graph_scope(desc['op']):
|
||||||
|
for in_desc in desc['input_desc']:
|
||||||
|
name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
|
||||||
|
0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
|
||||||
|
self.tensors[name] = builder.tensor(
|
||||||
|
shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
|
||||||
|
for out_desc in desc['output_desc']:
|
||||||
|
name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
|
||||||
|
'shape'], out_desc['data_type'], out_desc['format']
|
||||||
|
self.tensors[name] = builder.tensor(
|
||||||
|
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
|
||||||
|
cur_fusion = None
|
||||||
|
for op in desc['op_desc']:
|
||||||
|
inputs = [self.tensors[d[0]['tensor_name']]
|
||||||
|
for d in op['input_desc'] if 'value' not in d[0]]
|
||||||
|
out_desc = op['output_desc']
|
||||||
|
name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
|
||||||
|
0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
|
||||||
|
if op['name'] == 'InplaceAssign':
|
||||||
|
inputs[0].add_buddy(inputs[1])
|
||||||
|
inputs[1].para_type = Tensor.PARA_OUTPUT
|
||||||
|
output = inputs[2]
|
||||||
|
self.tensors[name] = output
|
||||||
|
else:
|
||||||
|
output = self.tensors.get(name, None)
|
||||||
|
if not output:
|
||||||
|
output = builder.tensor(
|
||||||
|
shape, dtype, data_format, name=name)
|
||||||
|
self.tensors[name] = output
|
||||||
|
builder.op(op['name'], output, inputs,
|
||||||
|
attrs=_attr_of(op, inputs, output))
|
||||||
|
if 'fusion' in op:
|
||||||
|
if cur_fusion is None:
|
||||||
|
cur_fusion = output
|
||||||
|
else:
|
||||||
|
cur_fusion.add_buddy(output)
|
||||||
|
if op['fusion'].endswith('_end'):
|
||||||
|
cur_fusion = None
|
||||||
|
self.graph = builder.get()[0]
|
||||||
|
self.desc = desc
|
||||||
|
|
||||||
|
def dump(self, subgraph):
|
||||||
|
"""Dump Graph to json"""
|
||||||
|
desc = {}
|
||||||
|
inputs, outputs = subgraph.deduce_parameters()
|
||||||
|
graph_ops = set(subgraph.ops)
|
||||||
|
inplace_assign = {} # y_name, output_name
|
||||||
|
inplace_assign_z = None
|
||||||
|
for op in self.desc['op_desc']:
|
||||||
|
if op['name'] == 'InplaceAssign':
|
||||||
|
inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
|
||||||
|
if inplace_assign:
|
||||||
|
for t in outputs:
|
||||||
|
if t.name not in inplace_assign:
|
||||||
|
inplace_assign_z = t
|
||||||
|
for key in self.desc:
|
||||||
|
if key == 'input_desc':
|
||||||
|
desc[key] = [
|
||||||
|
[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
|
||||||
|
elif key == 'output_desc':
|
||||||
|
out_desc = []
|
||||||
|
for t in outputs:
|
||||||
|
if t.name in inplace_assign:
|
||||||
|
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
|
||||||
|
out_desc.append(
|
||||||
|
{'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
|
||||||
|
else:
|
||||||
|
out_desc.append(
|
||||||
|
{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
|
||||||
|
desc[key] = out_desc
|
||||||
|
elif key == 'op_desc':
|
||||||
|
op_desc = []
|
||||||
|
for d in self.desc[key]:
|
||||||
|
if d['name'] == 'InplaceAssign':
|
||||||
|
y = d['input_desc'][1][0]['tensor_name']
|
||||||
|
if self.tensors[y].op in graph_ops:
|
||||||
|
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
|
||||||
|
self.tensors[y], True)
|
||||||
|
inplace_desc = copy.deepcopy(d)
|
||||||
|
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
|
||||||
|
z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0]
|
||||||
|
z_desc['shape'] = z.shape
|
||||||
|
z_desc['data_type'] = z.dtype
|
||||||
|
z_desc['tensor_name'] = z.name
|
||||||
|
out_desc['shape'] = z.shape
|
||||||
|
out_desc['data_type'] = z.dtype
|
||||||
|
op_desc.append(inplace_desc)
|
||||||
|
else:
|
||||||
|
op = self.tensors[d['output_desc'][0]['tensor_name']].op
|
||||||
|
if op in graph_ops:
|
||||||
|
op_desc.append(d)
|
||||||
|
desc[key] = op_desc
|
||||||
|
elif key == 'op':
|
||||||
|
desc[key] = subgraph.name
|
||||||
|
else:
|
||||||
|
desc[key] = self.desc[key]
|
||||||
|
return desc
|
||||||
|
|
||||||
|
|
||||||
|
def load_composite(desc):
|
||||||
|
"""Load composite kernel"""
|
||||||
|
composite = CompositeGraph()
|
||||||
|
composite.load(desc)
|
||||||
|
composite.refine()
|
||||||
|
return composite
|
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""GraphKernel splitter"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import json.decoder as jd
|
||||||
|
import traceback
|
||||||
|
from mindspore import log as logger
|
||||||
|
from . import model
|
||||||
|
|
||||||
|
|
||||||
|
def split_with_json(json_str: str):
|
||||||
|
"""Call costmodel to split GraphKernel"""
|
||||||
|
try:
|
||||||
|
graph_desc = json.loads(json_str)
|
||||||
|
comp = model.load_composite(graph_desc)
|
||||||
|
graph_split = model.split(comp.graph)
|
||||||
|
is_multi_graph = len(graph_split) > 1
|
||||||
|
graph_list = list(map(comp.dump, graph_split))
|
||||||
|
result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
|
||||||
|
return json.dumps(result)
|
||||||
|
except jd.JSONDecodeError:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return None
|
|
@ -0,0 +1,17 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
|
||||||
|
export PYTHONPATH
|
|
@ -0,0 +1,142 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""graph kernel split"""
|
||||||
|
import json
|
||||||
|
import getopt
|
||||||
|
import sys
|
||||||
|
import model
|
||||||
|
|
||||||
|
|
||||||
|
def print_usage():
|
||||||
|
print('Usage: graph_kernel_split.py [OPTION] <JSON_FILE>')
|
||||||
|
print('Options:')
|
||||||
|
print(' -s <config/auto>\tsplit graph with config')
|
||||||
|
print(' -e \t\testimate graph')
|
||||||
|
print(' -i \t\tnaive estimate')
|
||||||
|
print(' -o <prefix>\toutput split graphs')
|
||||||
|
print(' -v \t\tverbose mode')
|
||||||
|
print(' -h \t\tprint this help')
|
||||||
|
print('Report bugs to xiong.gao@huawei.com')
|
||||||
|
|
||||||
|
|
||||||
|
class Option:
|
||||||
|
"""Options"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.split = None
|
||||||
|
self.estimate = False
|
||||||
|
self.estimate_naive = False
|
||||||
|
self.output = None
|
||||||
|
self.verbose = False
|
||||||
|
self.help = False
|
||||||
|
|
||||||
|
def parse(self, options):
|
||||||
|
"""parse options"""
|
||||||
|
for name, val in options:
|
||||||
|
if name == '-h':
|
||||||
|
self.help = True
|
||||||
|
elif name == '-v':
|
||||||
|
self.verbose = True
|
||||||
|
elif name == '-o':
|
||||||
|
self.output = val
|
||||||
|
elif name == '-e':
|
||||||
|
self.estimate = True
|
||||||
|
elif name == '-s':
|
||||||
|
self.split = val
|
||||||
|
elif name == '-i':
|
||||||
|
self.estimate_naive = True
|
||||||
|
|
||||||
|
|
||||||
|
opt = Option()
|
||||||
|
|
||||||
|
|
||||||
|
def estimate(graph_in, parts_in, naive):
|
||||||
|
"""estimate graphs costs"""
|
||||||
|
def _print_cost(name, c):
|
||||||
|
print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
|
||||||
|
(name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
|
||||||
|
main_cost, _ = model.estimate(graph_in, naive)
|
||||||
|
split_cost, sub_costs = model.estimate(parts_in, naive) if parts_in else (None, None)
|
||||||
|
_print_cost("MainGraph:", main_cost)
|
||||||
|
if parts_in:
|
||||||
|
_print_cost("Subgraphs:", split_cost)
|
||||||
|
if opt.verbose:
|
||||||
|
for i, sub_cost in enumerate(sub_costs):
|
||||||
|
_print_cost(" |_%d:\t" % (i), sub_cost)
|
||||||
|
|
||||||
|
|
||||||
|
def split_graph(graph_in, config):
|
||||||
|
"""split graph"""
|
||||||
|
if config == 'auto':
|
||||||
|
return model.split(graph_in)
|
||||||
|
subgraphs = []
|
||||||
|
all_tensors = []
|
||||||
|
subgraph_idx = 0
|
||||||
|
config_parts = config.split('|')
|
||||||
|
for part in config_parts:
|
||||||
|
tensor_names = part.split(',')
|
||||||
|
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
|
||||||
|
g = graph_in.extract_subgraph(graph_name, tensor_names)
|
||||||
|
assert len(g.ops) == len(tensor_names)
|
||||||
|
subgraphs.append(g)
|
||||||
|
all_tensors += tensor_names
|
||||||
|
subgraph_idx += 1
|
||||||
|
if len(all_tensors) < len(graph_in.ops):
|
||||||
|
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
|
||||||
|
g = graph_in.extract_subgraph(graph_name, all_tensors, True)
|
||||||
|
subgraphs.append(g)
|
||||||
|
return subgraphs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
|
||||||
|
opt.parse(opts)
|
||||||
|
if len(args) != 1 or opt.help:
|
||||||
|
print_usage()
|
||||||
|
sys.exit(0)
|
||||||
|
in_file = args[0]
|
||||||
|
with open(in_file, 'r') as f:
|
||||||
|
desc = json.loads(f.read())
|
||||||
|
comp = model.load_composite(desc)
|
||||||
|
graph = comp.graph
|
||||||
|
parts = []
|
||||||
|
# 1. split sub-graphs
|
||||||
|
if opt.split is not None:
|
||||||
|
parts = split_graph(graph, opt.split)
|
||||||
|
if opt.verbose:
|
||||||
|
print('----------- main graph --------------')
|
||||||
|
print(graph)
|
||||||
|
for i, _ in enumerate(parts):
|
||||||
|
print('---------------- sub graph %d ---------------' % (i))
|
||||||
|
print(parts[i])
|
||||||
|
# 2. estimate cost
|
||||||
|
if opt.estimate:
|
||||||
|
print('------------- cost --------------')
|
||||||
|
estimate(graph, parts, False)
|
||||||
|
if opt.estimate_naive:
|
||||||
|
print('------------- naive cost --------------')
|
||||||
|
estimate(graph, parts, True)
|
||||||
|
# 3. output parts
|
||||||
|
if opt.output is not None:
|
||||||
|
for graph_part in parts:
|
||||||
|
desc = comp.dump(graph_part)
|
||||||
|
s_desc = json.dumps(desc)
|
||||||
|
fname = "%s_%s.json" % (opt.output, graph_part.name)
|
||||||
|
with open(fname, 'w', encoding='utf-8') as of:
|
||||||
|
of.write(s_desc)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""test split"""
|
||||||
|
import model
|
||||||
|
|
||||||
|
|
||||||
|
def graph_1():
|
||||||
|
gb = model.GraphBuilder()
|
||||||
|
with gb.graph_scope("main"):
|
||||||
|
a = gb.tensor([1024, 16], "float32", name="a")
|
||||||
|
b = gb.emit("Abs", a, 'b')
|
||||||
|
c = gb.emit("Abs", b, 'c')
|
||||||
|
d = gb.emit("Abs", c, 'd')
|
||||||
|
gb.emit("TensorAdd", [b, d], "e")
|
||||||
|
return gb.get()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def graph_2():
|
||||||
|
gb = model.GraphBuilder()
|
||||||
|
with gb.graph_scope("main"):
|
||||||
|
a = gb.tensor([1024, 16], "float32", name="a")
|
||||||
|
b = gb.emit("Abs", a, 'b')
|
||||||
|
c = gb.emit("Abs", b, 'c')
|
||||||
|
d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
|
||||||
|
gb.emit("Sqrt", d, 'e')
|
||||||
|
return gb.get()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_by_pattern():
|
||||||
|
def _test(graph):
|
||||||
|
print("***************** main graph ***************")
|
||||||
|
print(graph)
|
||||||
|
subgraphs = model.split(graph)
|
||||||
|
for i, g in enumerate(subgraphs):
|
||||||
|
print('------------- subgraph {} --------------'.format(i))
|
||||||
|
print(g)
|
||||||
|
_test(graph_2())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_split_by_pattern()
|
|
@ -71,7 +71,8 @@ if(ENABLE_GPU)
|
||||||
"runtime/device/gpu/*.cu"
|
"runtime/device/gpu/*.cu"
|
||||||
"backend/kernel_compiler/gpu/*.cu"
|
"backend/kernel_compiler/gpu/*.cu"
|
||||||
"backend/kernel_compiler/akg/gpu/*.cc"
|
"backend/kernel_compiler/akg/gpu/*.cc"
|
||||||
"backend/kernel_compiler/akg/akg_kernel_build.cc"
|
"backend/kernel_compiler/akg/akg_kernel_json_generator.cc"
|
||||||
|
"backend/kernel_compiler/akg/akg_kernel_json_decoder.cc"
|
||||||
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
|
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,8 @@ if (ENABLE_D)
|
||||||
"kernel_query.cc"
|
"kernel_query.cc"
|
||||||
"kernel_fusion.cc"
|
"kernel_fusion.cc"
|
||||||
"akg/ascend/*.cc"
|
"akg/ascend/*.cc"
|
||||||
"akg/akg_kernel_build.cc"
|
"akg/akg_kernel_json_generator.cc"
|
||||||
|
"akg/akg_kernel_json_decoder.cc"
|
||||||
"akg/akg_kernel_attrs_process.cc"
|
"akg/akg_kernel_attrs_process.cc"
|
||||||
"akg/akg_kernel_metadata.cc"
|
"akg/akg_kernel_metadata.cc"
|
||||||
"tbe/*.cc"
|
"tbe/*.cc"
|
||||||
|
@ -49,7 +50,8 @@ if (ENABLE_GPU)
|
||||||
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"gpu/*.cu"
|
"gpu/*.cu"
|
||||||
"akg/gpu/*.cc"
|
"akg/gpu/*.cc"
|
||||||
"akg/akg_kernel_build.cc"
|
"akg/akg_kernel_json_generator.cc"
|
||||||
|
"akg/akg_kernel_json_decoder.cc"
|
||||||
"akg/akg_kernel_attrs_process.cc"
|
"akg/akg_kernel_attrs_process.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,6 @@
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include "runtime/device/kernel_runtime.h"
|
#include "runtime/device/kernel_runtime.h"
|
||||||
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
|
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
|
||||||
#include "proto/tensor.pb.h"
|
#include "proto/tensor.pb.h"
|
||||||
#include "proto/tensor_shape.pb.h"
|
#include "proto/tensor_shape.pb.h"
|
||||||
#include "proto/attr.pb.h"
|
#include "proto/attr.pb.h"
|
||||||
|
@ -33,6 +32,7 @@
|
||||||
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
|
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
|
|
@ -15,13 +15,20 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) {
|
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
// The x and output are akg op input and output param.
|
// The x and output are akg op input and output param.
|
||||||
|
@ -169,5 +176,29 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
|
||||||
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node);
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node);
|
||||||
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
|
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
|
||||||
|
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
|
||||||
|
{kFive2FourOpName, SetAkgAttrsForFive2Four},
|
||||||
|
{kCastOpName, SetAkgAttrsForCast},
|
||||||
|
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
|
||||||
|
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
|
||||||
|
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
|
||||||
|
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
|
||||||
|
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
|
||||||
|
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
|
||||||
|
{kConvBN1OpName, SetAkgAttrsForConvBN1},
|
||||||
|
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
|
||||||
|
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
|
||||||
|
auto it = kAkgKernelAttrsProcessMap.find(AnfAlgo::GetCNodeName(anf_node));
|
||||||
|
if (it != kAkgKernelAttrsProcessMap.end()) {
|
||||||
|
it->second(anf_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -16,43 +16,13 @@
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "utils/utils.h"
|
|
||||||
#include "base/core_ops.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForCast(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node);
|
|
||||||
void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node);
|
|
||||||
|
|
||||||
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
|
void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
|
||||||
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
|
|
||||||
{kFive2FourOpName, SetAkgAttrsForFive2Four},
|
|
||||||
{"Cast", SetAkgAttrsForCast},
|
|
||||||
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
|
|
||||||
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
|
|
||||||
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
|
|
||||||
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
|
|
||||||
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
|
|
||||||
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
|
|
||||||
{kConvBN1OpName, SetAkgAttrsForConvBN1},
|
|
||||||
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
|
|
||||||
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
|
|
||||||
};
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
|
||||||
|
|
|
@ -1,573 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
|
||||||
#include <unistd.h>
|
|
||||||
#include <dirent.h>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
#include <utility>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <functional>
|
|
||||||
#include <iterator>
|
|
||||||
#include <numeric>
|
|
||||||
#include <unordered_set>
|
|
||||||
#include "utils/convert_utils.h"
|
|
||||||
#include "utils/any.h"
|
|
||||||
#include "utils/utils.h"
|
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
|
||||||
#include "backend/session/kernel_build_client.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
// json key
|
|
||||||
constexpr auto kOpDesc = "op_desc";
|
|
||||||
constexpr auto kInputDesc = "input_desc";
|
|
||||||
constexpr auto kShape = "shape";
|
|
||||||
constexpr auto kDataType = "data_type";
|
|
||||||
constexpr auto kOutputDesc = "output_desc";
|
|
||||||
constexpr auto kName = "name";
|
|
||||||
constexpr auto kTensorName = "tensor_name";
|
|
||||||
constexpr auto kValue = "value";
|
|
||||||
constexpr auto KDynInputSizes = "dyn_input_sizes";
|
|
||||||
constexpr auto KInputNames = "input_names";
|
|
||||||
constexpr auto KInput = "input";
|
|
||||||
constexpr auto KDtype = "dtype";
|
|
||||||
namespace {
|
|
||||||
template <typename T>
|
|
||||||
std::string Vector2Str(const std::vector<T> &inputs) {
|
|
||||||
if (!inputs.empty()) {
|
|
||||||
std::ostringstream oss;
|
|
||||||
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
|
|
||||||
oss << inputs.back();
|
|
||||||
return oss.str();
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
|
||||||
const std::pair<size_t, size_t> &position) {
|
|
||||||
if (node_json.count(tag) == 0) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
auto const &tag_desc = node_json[tag];
|
|
||||||
nlohmann::json first_index;
|
|
||||||
if (tag == kOutputDesc) {
|
|
||||||
first_index = tag_desc;
|
|
||||||
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
|
|
||||||
return "";
|
|
||||||
} else {
|
|
||||||
first_index = tag_desc[position.first];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!first_index.is_array() || first_index.size() <= position.second) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
auto const &second_index = first_index[position.second];
|
|
||||||
if (second_index.count(kTensorName) == 0) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "].";
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
return second_index[kTensorName];
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
|
||||||
nlohmann::json *const node_json) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node_json);
|
|
||||||
if (node_json->count(tag) == 0) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json *tag_desc = &((*node_json)[tag]);
|
|
||||||
nlohmann::json *first_index;
|
|
||||||
if (tag == kOutputDesc) {
|
|
||||||
first_index = tag_desc;
|
|
||||||
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
first_index = &((*tag_desc)[position.first]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!first_index->is_array() || first_index->size() <= position.second) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
nlohmann::json *second_index = &((*first_index)[position.second]);
|
|
||||||
if (second_index->count(kTensorName) == 0) {
|
|
||||||
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "].";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
(*second_index)[kTensorName] = new_name;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int AkgKernelBuild::op_cnt_ = 0;
|
|
||||||
std::mutex AkgKernelBuild::op_cnt_mtx_;
|
|
||||||
|
|
||||||
std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
std::string device;
|
|
||||||
switch (AnfAlgo::GetProcessor(anf_node)) {
|
|
||||||
case Processor::AICORE:
|
|
||||||
device = kProcessorAiCore;
|
|
||||||
break;
|
|
||||||
|
|
||||||
case Processor::AICPU:
|
|
||||||
device = kProcessorAiCpu;
|
|
||||||
break;
|
|
||||||
|
|
||||||
case Processor::CUDA:
|
|
||||||
device = kProcessorCuda;
|
|
||||||
break;
|
|
||||||
|
|
||||||
default:
|
|
||||||
MS_LOG(ERROR) << "Unknown processor type.";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return device;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
|
|
||||||
std::vector<size_t> *const output_size) {
|
|
||||||
if (input_size == nullptr || output_size == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "input size or output size is nullptr";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
input_size->clear();
|
|
||||||
output_size->clear();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < node_json[kInputDesc].size(); i++) {
|
|
||||||
for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) {
|
|
||||||
std::string dtype = node_json[kInputDesc][i][m][kDataType];
|
|
||||||
size_t nbyte = GetDtypeNbyte(dtype);
|
|
||||||
size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(),
|
|
||||||
node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies<size_t>());
|
|
||||||
input_size->push_back(size_i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) {
|
|
||||||
std::string dtype = node_json[kOutputDesc][i][kDataType];
|
|
||||||
size_t nbyte = GetDtypeNbyte(dtype);
|
|
||||||
size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(),
|
|
||||||
nbyte, std::multiplies<size_t>());
|
|
||||||
output_size->push_back(size_i);
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int AkgKernelBuild::GetOpCntInc() {
|
|
||||||
op_cnt_mtx_.lock();
|
|
||||||
int cnt = op_cnt_++;
|
|
||||||
op_cnt_mtx_.unlock();
|
|
||||||
return cnt;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(inputs_json);
|
|
||||||
|
|
||||||
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
|
|
||||||
if (op_info == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
|
|
||||||
if (inputs_ptr.empty()) {
|
|
||||||
MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
auto op_info_input_num = inputs_ptr.size();
|
|
||||||
|
|
||||||
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
|
|
||||||
std::vector<int> dyn_input_sizes;
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
|
|
||||||
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
|
|
||||||
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t real_input_index = 0;
|
|
||||||
std::vector<nlohmann::json> input_list;
|
|
||||||
for (size_t i = 0; i < op_info_input_num; i++) {
|
|
||||||
size_t input_tensor_num;
|
|
||||||
std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
|
|
||||||
std::string op_input_name;
|
|
||||||
if (input_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] regist input[" << i << "] is nullptr";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
op_input_name = input_ptr->name();
|
|
||||||
if (dyn_input_sizes.empty()) {
|
|
||||||
input_tensor_num = 1;
|
|
||||||
} else {
|
|
||||||
input_tensor_num = IntToSize(dyn_input_sizes[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
input_list.clear();
|
|
||||||
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
|
|
||||||
// dtype : float16
|
|
||||||
auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
|
|
||||||
std::string dtype = TypeId2String(type_id);
|
|
||||||
if (dtype.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. ";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
nlohmann::json input_desc_json;
|
|
||||||
input_desc_json[kDataType] = dtype;
|
|
||||||
input_desc_json[kName] = op_input_name;
|
|
||||||
input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
|
|
||||||
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
|
|
||||||
if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
|
|
||||||
GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
|
|
||||||
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
|
|
||||||
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
|
|
||||||
<< "], value: " << input_desc_json[kValue];
|
|
||||||
|
|
||||||
input_shape.clear();
|
|
||||||
}
|
|
||||||
if (input_shape.empty()) {
|
|
||||||
input_shape.push_back(1);
|
|
||||||
}
|
|
||||||
input_desc_json[kShape] = input_shape;
|
|
||||||
input_list.emplace_back(input_desc_json);
|
|
||||||
real_input_index++;
|
|
||||||
}
|
|
||||||
inputs_json->emplace_back(input_list);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(outputs_json);
|
|
||||||
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
|
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
|
|
||||||
auto outputs = op_info_ptr->outputs_ptr();
|
|
||||||
for (size_t i = 0; i < output_tensor_num; i++) {
|
|
||||||
nlohmann::json output_json;
|
|
||||||
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
|
|
||||||
std::string dtype = TypeId2String(type_id);
|
|
||||||
if (dtype.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. ";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string output_name = outputs[i]->name();
|
|
||||||
output_json[kDataType] = dtype;
|
|
||||||
output_json[kName] = output_name;
|
|
||||||
output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
|
|
||||||
output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
|
|
||||||
outputs_json->push_back(output_json);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
|
|
||||||
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_attr);
|
|
||||||
MS_EXCEPTION_IF_NULL(attr_json);
|
|
||||||
std::string type = op_attr->type();
|
|
||||||
if (type == "int") {
|
|
||||||
(*attr_json)[kValue] = GetValue<int>(attr_value);
|
|
||||||
} else if (type == "str") {
|
|
||||||
(*attr_json)[kValue] = GetValue<std::string>(attr_value);
|
|
||||||
} else if (type == "bool") {
|
|
||||||
(*attr_json)[kValue] = GetValue<bool>(attr_value);
|
|
||||||
} else if (type == "float") {
|
|
||||||
(*attr_json)[kValue] = GetValue<float>(attr_value);
|
|
||||||
} else if (type == "listInt") {
|
|
||||||
(*attr_json)[kValue] = GetValue<std::vector<int>>(attr_value);
|
|
||||||
} else if (type == "listStr") {
|
|
||||||
std::vector<std::string> data_format;
|
|
||||||
if (op_attr->name() == kArgDataformat) {
|
|
||||||
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
|
|
||||||
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
|
|
||||||
auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i);
|
|
||||||
data_format.push_back(input_format);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
data_format = GetValue<std::vector<std::string>>(attr_value);
|
|
||||||
}
|
|
||||||
(*attr_json)[kValue] = data_format;
|
|
||||||
} else {
|
|
||||||
MS_LOG(WARNING) << "attr type:" << type;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
|
||||||
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(attrs_json);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
|
||||||
std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
|
|
||||||
if (attrs.empty()) {
|
|
||||||
MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
|
|
||||||
|
|
||||||
std::vector<int> dyn_input_sizes;
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
|
|
||||||
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inputs.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// create input name list for atch "x_shape" in att with "x" in primitive.
|
|
||||||
std::map<size_t, std::string> op_info_shape_name;
|
|
||||||
for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
|
|
||||||
std::string input_name = inputs[op_info_input_i]->name();
|
|
||||||
std::string x_shape_name = input_name + "_shape";
|
|
||||||
(void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto &op_attr : attrs) {
|
|
||||||
nlohmann::json attr_json;
|
|
||||||
ValuePtr attr_value = primitive->GetAttr(op_attr->name());
|
|
||||||
if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
|
|
||||||
if (op_attr->param_type() == "required") {
|
|
||||||
// match "x_shape" in att with "x" in primitive.
|
|
||||||
std::string attr_name = op_attr->name();
|
|
||||||
auto find_item = std::find_if(
|
|
||||||
op_info_shape_name.begin(), op_info_shape_name.end(),
|
|
||||||
[attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
|
|
||||||
if (find_item != op_info_shape_name.end()) {
|
|
||||||
if (!dyn_input_sizes.empty()) {
|
|
||||||
if (find_item->first >= dyn_input_sizes.size() - 1) {
|
|
||||||
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
|
|
||||||
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
|
|
||||||
for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
|
|
||||||
attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
|
|
||||||
attr_json[kName] = op_attr->name();
|
|
||||||
attrs_json->push_back(attr_json);
|
|
||||||
tensor_idx++;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
|
|
||||||
attr_json[kName] = op_attr->name();
|
|
||||||
attrs_json->push_back(attr_json);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
|
|
||||||
|
|
||||||
attr_json[kName] = op_attr->name();
|
|
||||||
attrs_json->push_back(attr_json);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
|
||||||
nlohmann::json *const node_json) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(node_json);
|
|
||||||
int op_cnt = GetOpCntInc();
|
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
|
|
||||||
MS_EXCEPTION_IF_NULL(op_info_ptr);
|
|
||||||
|
|
||||||
// get basic params from currentNodeOpDesc
|
|
||||||
(*node_json)[kName] = op_name;
|
|
||||||
(*node_json)["impl_path"] = op_info_ptr->impl_path();
|
|
||||||
(*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node);
|
|
||||||
(*node_json)["composite"] = false;
|
|
||||||
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
ValuePtr input_names_v = primitive->GetAttr(KInputNames);
|
|
||||||
if (input_names_v == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::vector<std::string> prim_input_names = GetValue<const std::vector<std::string>>(input_names_v);
|
|
||||||
std::string inputs_name;
|
|
||||||
for (const auto &prim_input_name : prim_input_names) {
|
|
||||||
(void)inputs_name.append("_input_").append(prim_input_name).append("_");
|
|
||||||
}
|
|
||||||
|
|
||||||
// input desc
|
|
||||||
nlohmann::json inputs_json;
|
|
||||||
if (!CreateInputDescJson(anf_node, &inputs_json)) {
|
|
||||||
MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
(*node_json)[kInputDesc] = inputs_json;
|
|
||||||
MS_LOG(INFO) << "Akg create input desc json success.";
|
|
||||||
std::string inputs_shape = "inputs_shape_";
|
|
||||||
for (auto &i : inputs_json) {
|
|
||||||
for (auto &m : i) {
|
|
||||||
std::string data_type = m[kDataType];
|
|
||||||
(void)inputs_shape.append("_").append(data_type).append("_");
|
|
||||||
for (auto &j : m[kShape]) {
|
|
||||||
size_t n = j;
|
|
||||||
(void)inputs_shape.append(std::to_string(n)).append("_");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// output desc
|
|
||||||
nlohmann::json outputs_json;
|
|
||||||
if (!CreateOutputDescJson(anf_node, &outputs_json)) {
|
|
||||||
MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
(*node_json)[kOutputDesc] = outputs_json;
|
|
||||||
MS_LOG(INFO) << "Akg create output desc json success.";
|
|
||||||
std::string outputs_shape = "outputs_shape_";
|
|
||||||
for (auto &i : outputs_json) {
|
|
||||||
std::string data_type = i[kDataType];
|
|
||||||
(void)outputs_shape.append("_").append(data_type).append("_");
|
|
||||||
for (auto &j : i[kShape]) {
|
|
||||||
size_t m = j;
|
|
||||||
(void)outputs_shape.append(std::to_string(m)).append("_");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// attribute desc
|
|
||||||
nlohmann::json attrs_json;
|
|
||||||
if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) {
|
|
||||||
MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
(*node_json)["attr"] = attrs_json;
|
|
||||||
std::string json_str = node_json->dump();
|
|
||||||
size_t hash_id = std::hash<std::string>()(json_str);
|
|
||||||
json_name_ = op_name + "_";
|
|
||||||
(void)json_name_.append(std::to_string(hash_id));
|
|
||||||
MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_;
|
|
||||||
json_info_ = json_str;
|
|
||||||
(*node_json)["id"] = op_cnt;
|
|
||||||
(*node_json)["op"] = json_name_;
|
|
||||||
MS_LOG(INFO) << "Akg create node desc json success.";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
auto processor = AkgKernelBuild::GetProcessor(anf_node);
|
|
||||||
auto cached_kernel_pack = SearchCache(json_name_, processor);
|
|
||||||
if (cached_kernel_pack != nullptr) {
|
|
||||||
MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope["
|
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
|
||||||
return cached_kernel_pack;
|
|
||||||
}
|
|
||||||
|
|
||||||
(void)alarm(AUTODIFF_COMPILE_OVERTIME);
|
|
||||||
auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(node_json);
|
|
||||||
(void)alarm(0);
|
|
||||||
if (!res) {
|
|
||||||
MS_LOG(ERROR) << "Akg compile failed, json: " << node_json;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto new_kernel_pack = InsertCache(json_name_, processor);
|
|
||||||
kernel::SaveJsonInfo(json_name_, json_info_);
|
|
||||||
if (new_kernel_pack == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope["
|
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return new_kernel_pack;
|
|
||||||
}
|
|
||||||
|
|
||||||
KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
|
|
||||||
std::vector<size_t> *const output_size) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
auto it = kAkgKernelAttrsProcessMap.find(op_name);
|
|
||||||
if (it != kAkgKernelAttrsProcessMap.end()) {
|
|
||||||
it->second(anf_node);
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
|
|
||||||
nlohmann::json node_json;
|
|
||||||
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
|
|
||||||
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string json_str = node_json.dump();
|
|
||||||
auto kernel_pack = OpBuild(json_str, anf_node);
|
|
||||||
if (kernel_pack == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!GetIOSize(node_json, input_size, output_size)) {
|
|
||||||
MS_LOG(ERROR) << "Cal mem size failed.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node)
|
|
||||||
<< "]";
|
|
||||||
return kernel_pack;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
|
||||||
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
|
||||||
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input_node = cnode->input(input_idx + 1);
|
|
||||||
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
|
|
||||||
size_t index = input_tensor_idx_.size();
|
|
||||||
input_tensor_idx_[input_node] = index;
|
|
||||||
}
|
|
||||||
|
|
||||||
return input_tensor_idx_[input_node];
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t AkgKernelBuild::GetOutputTensorIdxInc() {
|
|
||||||
size_t idx = output_tensor_idx_++;
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,76 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
|
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
|
||||||
#include <utility>
|
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
|
||||||
#include "ir/dtype.h"
|
|
||||||
#include "ir/primitive.h"
|
|
||||||
#include <nlohmann/json.hpp>
|
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
|
||||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace kernel {
|
|
||||||
class AkgKernelBuild {
|
|
||||||
public:
|
|
||||||
AkgKernelBuild() {
|
|
||||||
input_tensor_idx_ = {};
|
|
||||||
output_tensor_idx_ = 0;
|
|
||||||
}
|
|
||||||
~AkgKernelBuild() = default;
|
|
||||||
|
|
||||||
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
|
|
||||||
std::vector<size_t> *const output_size);
|
|
||||||
static std::string GetProcessor(const AnfNodePtr &anf_node);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
|
|
||||||
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
|
|
||||||
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
|
||||||
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
|
|
||||||
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
|
|
||||||
int GetOpCntInc();
|
|
||||||
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
|
|
||||||
size_t GetOutputTensorIdxInc();
|
|
||||||
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
|
|
||||||
nlohmann::json *const node_json);
|
|
||||||
|
|
||||||
static int op_cnt_;
|
|
||||||
// lock for variable fusionOpCnt in singleton mode
|
|
||||||
static std::mutex op_cnt_mtx_;
|
|
||||||
std::string json_name_;
|
|
||||||
std::string json_info_;
|
|
||||||
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
|
|
||||||
size_t output_tensor_idx_;
|
|
||||||
};
|
|
||||||
|
|
||||||
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
|
|
||||||
std::vector<size_t> *const output_size);
|
|
||||||
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
|
||||||
nlohmann::json *const node_json);
|
|
||||||
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
|
||||||
const std::pair<size_t, size_t> &position);
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
|
|
|
@ -0,0 +1,415 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "ir/meta_tensor.h"
|
||||||
|
#include "ir/manager.h"
|
||||||
|
#include "ir/dtype.h"
|
||||||
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "utils/convert_utils.h"
|
||||||
|
#include "utils/convert_utils_py.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "ir/graph_utils.h"
|
||||||
|
#include "runtime/device/kernel_info.h"
|
||||||
|
#include "pipeline/jit/parse/data_converter.h"
|
||||||
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
|
||||||
|
if (type == "str") {
|
||||||
|
std::string value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else if (type == "int") {
|
||||||
|
int value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else if (type == "bool") {
|
||||||
|
bool value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else if (type == "float") {
|
||||||
|
float value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else if (type == "listInt") {
|
||||||
|
std::vector<int> value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else if (type == "listStr") {
|
||||||
|
std::vector<std::string> value = attr_json[kJsonKeyValue];
|
||||||
|
return MakeValue(value);
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DecodeAttrs(const nlohmann::json &attrs_json, std::map<std::string, ValuePtr> *attrs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(attrs);
|
||||||
|
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
|
||||||
|
// decode attrs.
|
||||||
|
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
|
||||||
|
// attrs maybe empty
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
|
||||||
|
for (const auto &attr_desc : attr_descs) {
|
||||||
|
std::string name = attr_desc[kJsonKeyName];
|
||||||
|
std::string type = attr_desc[kJsonKeyDataType];
|
||||||
|
auto value = ParseValue(attr_desc, type);
|
||||||
|
if (value == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*attrs)[name] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// python utils.
|
||||||
|
constexpr auto kGetPythonOpFunc = "_get_python_op";
|
||||||
|
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
|
||||||
|
// almost all ops are defined in this path.
|
||||||
|
constexpr auto kOperationsModule = "mindspore.ops.operations";
|
||||||
|
|
||||||
|
const std::map<std::string, std::vector<std::string>> op_attrs_map = {
|
||||||
|
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||||
|
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||||
|
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
|
||||||
|
};
|
||||||
|
|
||||||
|
ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
|
||||||
|
py::module mod = py::module::import(kOperationsModule);
|
||||||
|
if (!py::hasattr(mod, op_name.c_str())) {
|
||||||
|
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<py::object> arg_list;
|
||||||
|
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
|
||||||
|
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
|
||||||
|
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
|
||||||
|
op_name, arg_list);
|
||||||
|
ValuePtr op_instance = nullptr;
|
||||||
|
bool succ = parse::ConvertData(obj, &op_instance);
|
||||||
|
if (!succ) {
|
||||||
|
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op_instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
PrimitivePtr GetPrimitive(const std::string &op_name, const std::map<std::string, ValuePtr> &attrs_val) {
|
||||||
|
PrimitivePtr primitive{nullptr};
|
||||||
|
if (op_attrs_map.count(op_name) == 0) {
|
||||||
|
// no attrs for op instance.
|
||||||
|
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
|
||||||
|
} else {
|
||||||
|
// make attrs for op instance.
|
||||||
|
std::vector<ValuePtr> op_attrs;
|
||||||
|
const auto &attr_names = op_attrs_map.at(op_name);
|
||||||
|
for (const auto &attr_name : attr_names) {
|
||||||
|
if (attrs_val.count(attr_name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
op_attrs.push_back(attrs_val.at(attr_name));
|
||||||
|
}
|
||||||
|
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (primitive != nullptr) {
|
||||||
|
for (const auto &attr : attrs_val) {
|
||||||
|
primitive->AddAttr(attr.first, attr.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return primitive;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
|
||||||
|
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
|
||||||
|
|
||||||
|
ScalarPtr AkgKernelJsonDecoder::DecodeScalar(const nlohmann::json &scalar_json) {
|
||||||
|
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
|
||||||
|
switch (type_id) {
|
||||||
|
case kNumberTypeFloat16:
|
||||||
|
case kNumberTypeFloat32:
|
||||||
|
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
|
||||||
|
case kNumberTypeInt32:
|
||||||
|
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueNodePtr AkgKernelJsonDecoder::DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
|
||||||
|
MS_LOG(DEBUG) << "start decode value node, " << value_json;
|
||||||
|
auto scalar = DecodeScalar(value_json);
|
||||||
|
auto tensor = ScalarToTensor(scalar);
|
||||||
|
|
||||||
|
auto value_node = std::make_shared<ValueNode>(tensor);
|
||||||
|
value_node->set_abstract(tensor->ToAbstract());
|
||||||
|
// create kernel_info fo new value node.
|
||||||
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
value_node->set_kernel_info(kernel_info);
|
||||||
|
// create kernel_build_info for new value node.
|
||||||
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||||
|
// layout info.
|
||||||
|
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
|
||||||
|
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
|
||||||
|
func_graph->AddValueNode(value_node);
|
||||||
|
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
|
||||||
|
return value_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json ¶meter_json,
|
||||||
|
const FuncGraphPtr &func_graph) {
|
||||||
|
MS_LOG(DEBUG) << "start decode parameter, " << parameter_json;
|
||||||
|
ParameterPtr new_parameter = func_graph->add_parameter();
|
||||||
|
std::string name = parameter_json[kJsonKeyTensorName];
|
||||||
|
new_parameter->set_name(name);
|
||||||
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
new_parameter->set_kernel_info(kernel_info);
|
||||||
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||||
|
builder->SetOutputsFormat(std::vector<std::string>{parameter_json[kJsonKeyFormat]});
|
||||||
|
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(parameter_json[kJsonKeyDataType])});
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_parameter.get());
|
||||||
|
nodes_map_[name] = new_parameter;
|
||||||
|
return new_parameter;
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph,
|
||||||
|
const std::string &processor) {
|
||||||
|
Processor p = kernel::GetProcessor(processor);
|
||||||
|
MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
|
||||||
|
// decode attrs.
|
||||||
|
std::map<std::string, ValuePtr> cnode_attrs;
|
||||||
|
if (!DecodeAttrs(cnode_json, &cnode_attrs)) {
|
||||||
|
MS_LOG(ERROR) << "Error decode attrs.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::string op_name = cnode_json[kJsonKeyName];
|
||||||
|
// new primitive.
|
||||||
|
auto primitive = GetPrimitive(op_name, cnode_attrs);
|
||||||
|
if (primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create primitive failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// data layout info.
|
||||||
|
std::vector<std::string> input_formats;
|
||||||
|
std::vector<TypeId> input_types;
|
||||||
|
std::vector<std::string> output_formats;
|
||||||
|
std::vector<TypeId> output_types;
|
||||||
|
|
||||||
|
// collect inputs.
|
||||||
|
auto primitive_v = NewValueNode(primitive);
|
||||||
|
func_graph->AddValueNode(primitive_v);
|
||||||
|
std::vector<AnfNodePtr> inputs{primitive_v};
|
||||||
|
std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
|
||||||
|
for (size_t i = 0; i < input_descs.size(); ++i) {
|
||||||
|
nlohmann::json input_desc = input_descs[i][0];
|
||||||
|
std::string name = input_desc[kJsonKeyTensorName];
|
||||||
|
if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
|
||||||
|
inputs.push_back(DecodeValueNode(input_desc, func_graph));
|
||||||
|
} else if (nodes_map_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
|
||||||
|
return nullptr;
|
||||||
|
} else {
|
||||||
|
inputs.push_back(nodes_map_[name]);
|
||||||
|
}
|
||||||
|
input_formats.push_back(input_desc[kJsonKeyFormat]);
|
||||||
|
input_types.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode inputs success.";
|
||||||
|
|
||||||
|
// new cnode.
|
||||||
|
auto cnode = func_graph->NewCNode(inputs);
|
||||||
|
func_graph->AddNode(cnode);
|
||||||
|
|
||||||
|
// decode outputs.
|
||||||
|
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
|
||||||
|
AbstractBasePtr abstract(nullptr);
|
||||||
|
if (output_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "No outputs found.";
|
||||||
|
return nullptr;
|
||||||
|
} else if (output_descs.size() == 1) {
|
||||||
|
// single output.
|
||||||
|
nlohmann::json output_desc = output_descs[0];
|
||||||
|
output_formats.push_back(output_desc[kJsonKeyFormat]);
|
||||||
|
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
|
||||||
|
nodes_map_[output_desc[kJsonKeyTensorName]] = cnode;
|
||||||
|
} else {
|
||||||
|
// multi outputs.
|
||||||
|
for (size_t j = 0; j < output_descs.size(); ++j) {
|
||||||
|
nlohmann::json output_desc = output_descs[j];
|
||||||
|
output_formats.push_back(output_desc[kJsonKeyFormat]);
|
||||||
|
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
|
||||||
|
auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, NewValueNode(SizeToInt(j))});
|
||||||
|
func_graph->AddNode(get_item);
|
||||||
|
nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode outputs success.";
|
||||||
|
|
||||||
|
// create kernel_info.
|
||||||
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
std::vector<size_t> feature_map_input_indexs;
|
||||||
|
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||||
|
// then the node's output is a feature map output
|
||||||
|
for (size_t index = 1; index < inputs.size(); ++index) {
|
||||||
|
auto node = AnfAlgo::VisitKernel(inputs[index], 0);
|
||||||
|
if (AnfAlgo::IsFeatureMapOutput(node.first)) {
|
||||||
|
feature_map_input_indexs.push_back(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||||
|
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||||
|
}
|
||||||
|
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
|
||||||
|
kernel_info->SetFeatureMapFlag(true);
|
||||||
|
}
|
||||||
|
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
|
||||||
|
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
|
||||||
|
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
|
||||||
|
}
|
||||||
|
cnode->set_kernel_info(kernel_info);
|
||||||
|
// create kernel_build_info.
|
||||||
|
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||||
|
builder->SetInputsFormat(input_formats);
|
||||||
|
builder->SetInputsDeviceType(input_types);
|
||||||
|
builder->SetOutputsFormat(output_formats);
|
||||||
|
builder->SetOutputsDeviceType(output_types);
|
||||||
|
builder->SetProcessor(p);
|
||||||
|
builder->SetKernelType(KernelType::AKG_KERNEL);
|
||||||
|
builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
|
||||||
|
return cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel_json) {
|
||||||
|
MS_LOG(DEBUG) << "start decode, " << kernel_json;
|
||||||
|
// clear cache.
|
||||||
|
nodes_map_.clear();
|
||||||
|
// create a graph.
|
||||||
|
auto graph = std::make_shared<FuncGraph>();
|
||||||
|
|
||||||
|
// decode parameters.
|
||||||
|
std::vector<nlohmann::json> input_descs = kernel_json[kJsonKeyInputDesc];
|
||||||
|
if (input_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Error decode parameter, no inputs for graph.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < input_descs.size(); ++i) {
|
||||||
|
std::vector<nlohmann::json> input_desc = input_descs[i];
|
||||||
|
auto parameter = DecodeParameter(input_desc[0], graph);
|
||||||
|
if (parameter == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Error decode parameter.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode parameters success.";
|
||||||
|
|
||||||
|
// decode cnodes in graph.
|
||||||
|
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
|
||||||
|
if (op_node_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Error decode cnodes, no cnodes for graph.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
for (const auto &op_desc : op_node_descs) {
|
||||||
|
auto op_node = DecodeCNode(op_desc, graph, kernel_json[kJsonKeyProcess]);
|
||||||
|
if (op_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Error decode cnode.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode cnodes success.";
|
||||||
|
|
||||||
|
// decode outputs of graph.
|
||||||
|
std::vector<nlohmann::json> output_descs = kernel_json[kJsonKeyOutputDesc];
|
||||||
|
if (output_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Error decode outputs, no outputs for graph.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
for (const auto &output_desc : output_descs) {
|
||||||
|
std::string name = output_desc[kJsonKeyTensorName];
|
||||||
|
if (nodes_map_.count(name) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
outputs.push_back(nodes_map_[name]);
|
||||||
|
}
|
||||||
|
if (outputs.size() == 2) {
|
||||||
|
graph->set_output(outputs[1]);
|
||||||
|
} else {
|
||||||
|
auto output = graph->NewCNode(outputs);
|
||||||
|
graph->AddNode(output);
|
||||||
|
graph->set_output(output);
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode success, " << kernel_json;
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_json_str) {
|
||||||
|
auto kernel_json = nlohmann::json::parse(kernel_json_str);
|
||||||
|
return DecodeFusedNodes(kernel_json);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
|
||||||
|
const std::map<std::string, AnfNodePtr> &address_node_map,
|
||||||
|
AnfNodePtrList *res_graphs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(res_graphs);
|
||||||
|
MS_LOG(DEBUG) << "start decode, " << kernel_json;
|
||||||
|
// decode cnodes in graph.
|
||||||
|
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
|
||||||
|
if (op_node_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (const auto &op_desc : op_node_descs) {
|
||||||
|
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
|
||||||
|
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ptr_address = op_desc[kJsonKeyPtrAddress];
|
||||||
|
if (address_node_map.count(ptr_address) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
res_graphs->push_back(address_node_map.at(ptr_address));
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include "ir/scalar.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class AkgKernelJsonDecoder {
|
||||||
|
public:
|
||||||
|
AkgKernelJsonDecoder() { nodes_map_.clear(); }
|
||||||
|
~AkgKernelJsonDecoder() = default;
|
||||||
|
|
||||||
|
FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
|
||||||
|
FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
|
||||||
|
bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
|
||||||
|
AnfNodePtrList *res_graphs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json);
|
||||||
|
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph);
|
||||||
|
ParameterPtr DecodeParameter(const nlohmann::json ¶meter_json, const FuncGraphPtr &func_graph);
|
||||||
|
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
|
||||||
|
std::map<std::string, AnfNodePtr> nodes_map_{};
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
|
|
@ -0,0 +1,630 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <sstream>
|
||||||
|
#include <tuple>
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
std::vector<int> GetDynInputSize(const AnfNodePtr &anf_node) {
|
||||||
|
std::vector<int> dyn_input_sizes;
|
||||||
|
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
if (primitive->HasAttr(kAttrDynInputSizes)) {
|
||||||
|
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
||||||
|
}
|
||||||
|
return dyn_input_sizes;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int AkgKernelJsonGenerator::op_cnt_ = 0;
|
||||||
|
std::mutex AkgKernelJsonGenerator::op_cnt_mtx_;
|
||||||
|
|
||||||
|
int AkgKernelJsonGenerator::GetOpCntInc() {
|
||||||
|
op_cnt_mtx_.lock();
|
||||||
|
int cnt = op_cnt_++;
|
||||||
|
op_cnt_mtx_.unlock();
|
||||||
|
return cnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TypeId AkgKernelJsonGenerator::GetInputDataType(const AnfNodePtr &anf_node, size_t real_index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index)
|
||||||
|
: AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<size_t> AkgKernelJsonGenerator::GetInputShape(const AnfNodePtr &anf_node, size_t real_index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index)
|
||||||
|
: AnfAlgo::GetInputDeviceShape(anf_node, real_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string AkgKernelJsonGenerator::GetInputFormat(const AnfNodePtr &anf_node, size_t real_index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetInputFormat(anf_node, real_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TypeId AkgKernelJsonGenerator::GetOutputDataType(const AnfNodePtr &anf_node, size_t index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferDataType(anf_node, index)
|
||||||
|
: AnfAlgo::GetOutputDeviceDataType(anf_node, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<size_t> AkgKernelJsonGenerator::GetOutputShape(const AnfNodePtr &anf_node, size_t index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferShape(anf_node, index)
|
||||||
|
: AnfAlgo::GetOutputDeviceShape(anf_node, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf_node, size_t index) {
|
||||||
|
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(anf_node, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const inputs_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
|
MS_EXCEPTION_IF_NULL(inputs_json);
|
||||||
|
|
||||||
|
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
|
||||||
|
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
|
||||||
|
if (inputs_ptr.empty()) {
|
||||||
|
MS_LOG(DEBUG) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
|
||||||
|
auto dyn_input_sizes = GetDynInputSize(anf_node);
|
||||||
|
|
||||||
|
size_t real_input_index = 0;
|
||||||
|
std::vector<nlohmann::json> input_list;
|
||||||
|
for (size_t i = 0; i < inputs_ptr.size(); i++) {
|
||||||
|
std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
|
||||||
|
if (input_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist input[" << i << "] is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto op_input_name = input_ptr->name();
|
||||||
|
size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]);
|
||||||
|
|
||||||
|
input_list.clear();
|
||||||
|
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
|
||||||
|
auto type_id = this->GetInputDataType(anf_node, real_input_index);
|
||||||
|
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
|
||||||
|
if (dtype.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] input [" << real_input_index
|
||||||
|
<< "] data type is null. ";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
nlohmann::json input_desc_json;
|
||||||
|
input_desc_json[kJsonKeyDataType] = dtype;
|
||||||
|
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index);
|
||||||
|
input_desc_json[kJsonKeyName] = op_input_name;
|
||||||
|
input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
|
||||||
|
auto input_shape = this->GetInputShape(anf_node, real_input_index);
|
||||||
|
if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
|
||||||
|
MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
|
||||||
|
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
|
||||||
|
<< "], value: " << input_desc_json[kJsonKeyValue];
|
||||||
|
input_shape.clear();
|
||||||
|
}
|
||||||
|
if (input_shape.empty()) {
|
||||||
|
input_shape.push_back(1);
|
||||||
|
}
|
||||||
|
input_desc_json[kJsonKeyShape] = input_shape;
|
||||||
|
input_list.emplace_back(input_desc_json);
|
||||||
|
real_input_index++;
|
||||||
|
}
|
||||||
|
inputs_json->emplace_back(input_list);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const outputs_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
|
MS_EXCEPTION_IF_NULL(outputs_json);
|
||||||
|
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
||||||
|
|
||||||
|
auto outputs = op_info->outputs_ptr();
|
||||||
|
for (size_t i = 0; i < output_tensor_num; i++) {
|
||||||
|
nlohmann::json output_json;
|
||||||
|
auto type_id = this->GetOutputDataType(anf_node, i);
|
||||||
|
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
|
||||||
|
if (dtype.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] output [" << i << "] data type is null. ";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string output_name = outputs[i]->name();
|
||||||
|
output_json[kJsonKeyDataType] = dtype;
|
||||||
|
output_json[kJsonKeyFormat] = this->GetOutputFormat(anf_node, i);
|
||||||
|
output_json[kJsonKeyName] = output_name;
|
||||||
|
output_json[kJsonKeyTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
|
||||||
|
output_json[kJsonKeyShape] = this->GetOutputShape(anf_node, i);
|
||||||
|
outputs_json->push_back(output_json);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AkgKernelJsonGenerator::GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
|
||||||
|
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json,
|
||||||
|
const ValuePtr &attr_value) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_attr);
|
||||||
|
MS_EXCEPTION_IF_NULL(attr_json);
|
||||||
|
std::string type = op_attr->type();
|
||||||
|
(*attr_json)[kJsonKeyDataType] = type;
|
||||||
|
if (type == "int") {
|
||||||
|
(*attr_json)[kJsonKeyValue] = GetValue<int>(attr_value);
|
||||||
|
} else if (type == "str") {
|
||||||
|
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
|
||||||
|
} else if (type == "bool") {
|
||||||
|
(*attr_json)[kJsonKeyValue] = GetValue<bool>(attr_value);
|
||||||
|
} else if (type == "float") {
|
||||||
|
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
|
||||||
|
} else if (type == "listInt") {
|
||||||
|
(*attr_json)[kJsonKeyValue] = GetValue<std::vector<int>>(attr_value);
|
||||||
|
} else if (type == "listStr") {
|
||||||
|
std::vector<std::string> data_format;
|
||||||
|
if (op_attr->name() == kArgDataformat) {
|
||||||
|
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
|
||||||
|
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
|
||||||
|
auto input_format = this->GetInputFormat(anf_node, format_i);
|
||||||
|
data_format.push_back(input_format);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
data_format = GetValue<std::vector<std::string>>(attr_value);
|
||||||
|
}
|
||||||
|
(*attr_json)[kJsonKeyValue] = data_format;
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "No valid json value for attr type: " << type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const attrs_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
|
MS_EXCEPTION_IF_NULL(attrs_json);
|
||||||
|
std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
|
||||||
|
if (attrs.empty()) {
|
||||||
|
MS_LOG(INFO) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
|
||||||
|
|
||||||
|
std::vector<int> dyn_input_sizes;
|
||||||
|
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
|
||||||
|
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info inputs is empty";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// create input name list for "x_shape" in attr with "x" in primitive.
|
||||||
|
std::map<size_t, std::string> op_info_shape_name;
|
||||||
|
for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
|
||||||
|
std::string input_name = inputs[op_info_input_i]->name();
|
||||||
|
std::string x_shape_name = input_name + "_shape";
|
||||||
|
static_cast<void>(op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &op_attr : attrs) {
|
||||||
|
nlohmann::json attr_json;
|
||||||
|
ValuePtr attr_value = primitive->GetAttr(op_attr->name());
|
||||||
|
if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
|
||||||
|
if (op_attr->param_type() == "required") {
|
||||||
|
// match "x_shape" in att with "x" in primitive.
|
||||||
|
std::string attr_name = op_attr->name();
|
||||||
|
auto find_item = std::find_if(
|
||||||
|
op_info_shape_name.begin(), op_info_shape_name.end(),
|
||||||
|
[attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
|
||||||
|
if (find_item != op_info_shape_name.end()) {
|
||||||
|
if (!dyn_input_sizes.empty()) {
|
||||||
|
if (find_item->first >= dyn_input_sizes.size() - 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
|
||||||
|
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
|
||||||
|
for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
|
||||||
|
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
|
||||||
|
attr_json[kJsonKeyName] = op_attr->name();
|
||||||
|
attrs_json->push_back(attr_json);
|
||||||
|
tensor_idx++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
|
||||||
|
attr_json[kJsonKeyName] = op_attr->name();
|
||||||
|
attrs_json->push_back(attr_json);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
|
||||||
|
|
||||||
|
attr_json[kJsonKeyName] = op_attr->name();
|
||||||
|
attrs_json->push_back(attr_json);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t AkgKernelJsonGenerator::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||||
|
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
|
||||||
|
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_node = cnode->input(input_idx + 1);
|
||||||
|
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
|
||||||
|
size_t index = input_tensor_idx_.size();
|
||||||
|
input_tensor_idx_[input_node] = index;
|
||||||
|
}
|
||||||
|
|
||||||
|
return input_tensor_idx_[input_node];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t AkgKernelJsonGenerator::GetOutputTensorIdxInc() {
|
||||||
|
size_t idx = output_tensor_idx_++;
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AkgKernelJsonGenerator::GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
||||||
|
const std::pair<size_t, size_t> &position) {
|
||||||
|
if (node_json.count(tag) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const &tag_desc = node_json[tag];
|
||||||
|
nlohmann::json first_index;
|
||||||
|
if (tag == kJsonKeyOutputDesc) {
|
||||||
|
first_index = tag_desc;
|
||||||
|
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
|
||||||
|
return "";
|
||||||
|
} else {
|
||||||
|
first_index = tag_desc[position.first];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!first_index.is_array() || first_index.size() <= position.second) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
auto const &second_index = first_index[position.second];
|
||||||
|
if (second_index.count(kJsonKeyTensorName) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kJsonKeyTensorName << "].";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
return second_index[kJsonKeyTensorName];
|
||||||
|
}
|
||||||
|
|
||||||
|
void AkgKernelJsonGenerator::SetTensorName(const std::string &tag, const std::string &new_name,
|
||||||
|
const std::pair<size_t, size_t> &position, nlohmann::json *const node_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node_json);
|
||||||
|
if (node_json->count(tag) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json *tag_desc = &((*node_json)[tag]);
|
||||||
|
nlohmann::json *first_index;
|
||||||
|
if (tag == kJsonKeyOutputDesc) {
|
||||||
|
first_index = tag_desc;
|
||||||
|
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
first_index = &((*tag_desc)[position.first]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!first_index->is_array() || first_index->size() <= position.second) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nlohmann::json *second_index = &((*first_index)[position.second]);
|
||||||
|
if (second_index->count(kJsonKeyTensorName) == 0) {
|
||||||
|
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kJsonKeyTensorName << "].";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
(*second_index)[kJsonKeyTensorName] = new_name;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *const node_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(node_json);
|
||||||
|
auto op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||||
|
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
|
||||||
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
|
|
||||||
|
// get basic params from currentNodeOpDesc
|
||||||
|
(*node_json)[kJsonKeyName] = op_name;
|
||||||
|
(*node_json)[kJsonKeyImplPath] = op_info->impl_path();
|
||||||
|
if (dump_option_.save_ptr_address) {
|
||||||
|
std::ostringstream get_the_address;
|
||||||
|
get_the_address << anf_node.get();
|
||||||
|
auto address = get_the_address.str();
|
||||||
|
(*node_json)[kJsonKeyPtrAddress] = address;
|
||||||
|
address_node_map_[address] = anf_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
// input desc
|
||||||
|
nlohmann::json inputs_json;
|
||||||
|
if (!CreateInputDescJson(anf_node, op_info, &inputs_json)) {
|
||||||
|
MS_LOG(ERROR) << "Create input desc json failed, op[" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*node_json)[kJsonKeyInputDesc] = inputs_json;
|
||||||
|
MS_LOG(DEBUG) << "Akg create input desc json success.";
|
||||||
|
|
||||||
|
// output desc
|
||||||
|
nlohmann::json outputs_json;
|
||||||
|
if (!CreateOutputDescJson(anf_node, op_info, &outputs_json)) {
|
||||||
|
MS_LOG(ERROR) << "Create output desc json failed, op[" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*node_json)[kJsonKeyOutputDesc] = outputs_json;
|
||||||
|
MS_LOG(DEBUG) << "Akg create output desc json success.";
|
||||||
|
|
||||||
|
// attribute desc
|
||||||
|
nlohmann::json attrs_json;
|
||||||
|
if (!CreateAttrDescJson(anf_node, op_info, &attrs_json)) {
|
||||||
|
MS_LOG(ERROR) << "Create attr desc json failed, op[" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*node_json)[kJsonKeyAttr] = attrs_json;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
|
||||||
|
std::vector<size_t> *const output_size) {
|
||||||
|
if (input_size == nullptr || output_size == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "input size or output size is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
input_size->clear();
|
||||||
|
output_size->clear();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < node_json[kJsonKeyInputDesc].size(); i++) {
|
||||||
|
for (size_t m = 0; m < node_json[kJsonKeyInputDesc][i].size(); m++) {
|
||||||
|
std::string dtype = node_json[kJsonKeyInputDesc][i][m][kJsonKeyDataType];
|
||||||
|
size_t nbyte = GetDtypeNbyte(dtype);
|
||||||
|
size_t size_i =
|
||||||
|
std::accumulate(node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].begin(),
|
||||||
|
node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
|
||||||
|
input_size->push_back(size_i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < node_json[kJsonKeyOutputDesc].size(); i++) {
|
||||||
|
std::string dtype = node_json[kJsonKeyOutputDesc][i][kJsonKeyDataType];
|
||||||
|
size_t nbyte = GetDtypeNbyte(dtype);
|
||||||
|
size_t size_i =
|
||||||
|
std::accumulate(node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].begin(),
|
||||||
|
node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
|
||||||
|
output_size->push_back(size_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::json *const kernel_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_json);
|
||||||
|
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
||||||
|
MS_LOG(INFO) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
|
||||||
|
SetAkgKernelAttrs(anf_node);
|
||||||
|
if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
|
||||||
|
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
|
||||||
|
kernel_name_ = op_name + "_";
|
||||||
|
(void)kernel_name_.append(std::to_string(hash_id));
|
||||||
|
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
|
||||||
|
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||||
|
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||||
|
(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_node);
|
||||||
|
(*kernel_json)[kJsonKeyComposite] = false;
|
||||||
|
|
||||||
|
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
|
||||||
|
MS_LOG(ERROR) << "Cal mem size failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
|
||||||
|
<< ", json info name is : " << kernel_name_;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
|
||||||
|
const std::vector<AnfNodePtr> &input_list,
|
||||||
|
const std::vector<AnfNodePtr> &output_list,
|
||||||
|
nlohmann::json *const kernel_json) {
|
||||||
|
if (anf_nodes.empty() || input_list.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
|
||||||
|
<< "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
|
||||||
|
<< "], output_list: [" << input_list.size() << "].";
|
||||||
|
std::map<AnfNodePtr, nlohmann::json> node_json_map;
|
||||||
|
|
||||||
|
for (auto const &anf_node : anf_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
if (!AnfAlgo::IsRealKernel(anf_node)) {
|
||||||
|
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
SetAkgKernelAttrs(anf_node);
|
||||||
|
|
||||||
|
nlohmann::json node_json;
|
||||||
|
if (!GenerateSingleKernelJson(anf_node, &node_json)) {
|
||||||
|
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
|
||||||
|
if (primitive->GetAttr("fusion") != nullptr) {
|
||||||
|
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
node_json_map[anf_node] = node_json;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto const &anf_node : anf_nodes) {
|
||||||
|
auto dyn_input_sizes = GetDynInputSize(anf_node);
|
||||||
|
bool is_dynamic_input = !dyn_input_sizes.empty();
|
||||||
|
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
|
||||||
|
size_t real_input_index = 0;
|
||||||
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
|
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
|
||||||
|
for (size_t j = 0; j < input_tensor_num; ++j) {
|
||||||
|
auto tmp_input = GetKernelInput(anf_node, real_input_index);
|
||||||
|
std::string tensor_name = GetTensorName(node_json_map[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
|
||||||
|
if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
|
||||||
|
std::string new_tensor_name =
|
||||||
|
GetTensorName(node_json_map[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
|
||||||
|
SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
|
||||||
|
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
|
||||||
|
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
|
||||||
|
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
|
||||||
|
<< anf_node->fullname_with_scope() << "] is out input.";
|
||||||
|
}
|
||||||
|
real_input_index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<nlohmann::json> node_json_desc;
|
||||||
|
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
|
||||||
|
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
|
||||||
|
(*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
|
||||||
|
|
||||||
|
nlohmann::json inputs_json;
|
||||||
|
auto input_index = GetInputIndex(anf_nodes, input_list);
|
||||||
|
for (size_t i = 0; i < input_index.size(); ++i) {
|
||||||
|
auto tmp_input = input_index[i];
|
||||||
|
auto type_id = this->GetInputDataType(tmp_input.first, tmp_input.second.first);
|
||||||
|
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
|
||||||
|
nlohmann::json input_desc_json;
|
||||||
|
input_desc_json[kJsonKeyTensorName] =
|
||||||
|
GetTensorName(node_json_map[tmp_input.first], kJsonKeyInputDesc, tmp_input.second);
|
||||||
|
input_desc_json[kJsonKeyDataType] = dtype;
|
||||||
|
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
|
||||||
|
input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first);
|
||||||
|
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
|
||||||
|
}
|
||||||
|
(*kernel_json)[kJsonKeyInputDesc] = inputs_json;
|
||||||
|
|
||||||
|
nlohmann::json outputs_json;
|
||||||
|
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
|
||||||
|
std::map<size_t, std::vector<std::string>> sub_graphs;
|
||||||
|
std::map<size_t, size_t> dim_infos;
|
||||||
|
for (size_t i = 0; i < output_index.size(); ++i) {
|
||||||
|
auto tmp_output = output_index[i];
|
||||||
|
bool found = false;
|
||||||
|
nlohmann::json output_desc_json;
|
||||||
|
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
|
||||||
|
if (tmp_output.first == input_list[input_i]) {
|
||||||
|
output_desc_json = inputs_json[input_i][0];
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found) {
|
||||||
|
auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
|
||||||
|
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
|
||||||
|
output_desc_json[kJsonKeyTensorName] =
|
||||||
|
GetTensorName(node_json_map[tmp_output.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
|
||||||
|
output_desc_json[kJsonKeyDataType] = dtype;
|
||||||
|
output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second);
|
||||||
|
auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second);
|
||||||
|
if (output_shape.empty()) {
|
||||||
|
output_shape.push_back(1);
|
||||||
|
}
|
||||||
|
output_desc_json[kJsonKeyShape] = output_shape;
|
||||||
|
}
|
||||||
|
outputs_json.emplace_back(output_desc_json);
|
||||||
|
}
|
||||||
|
(*kernel_json)[kJsonKeyOutputDesc] = outputs_json;
|
||||||
|
|
||||||
|
auto processor = GetProcessorStr(anf_nodes[0]);
|
||||||
|
|
||||||
|
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
|
||||||
|
kernel_name_ = "Fused_";
|
||||||
|
auto fg = anf_nodes[0]->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||||
|
if (attr_val != nullptr) {
|
||||||
|
auto fg_attr = GetValue<std::string>(attr_val);
|
||||||
|
(void)kernel_name_.append(fg_attr).append("_");
|
||||||
|
}
|
||||||
|
(void)kernel_name_.append(std::to_string(hash_id));
|
||||||
|
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
|
||||||
|
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||||
|
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||||
|
(*kernel_json)[kJsonKeyProcess] = processor;
|
||||||
|
(*kernel_json)[kJsonKeyComposite] = true;
|
||||||
|
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
|
||||||
|
|
||||||
|
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
|
||||||
|
MS_LOG(ERROR) << "Cal mem size failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) {
|
||||||
|
kernel_json_ = nlohmann::json();
|
||||||
|
return CollectJson(anf_node, &kernel_json_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
|
||||||
|
const std::vector<AnfNodePtr> &input_list,
|
||||||
|
const std::vector<AnfNodePtr> &output_list) {
|
||||||
|
kernel_json_ = nlohmann::json();
|
||||||
|
return CollectFusedJson(anf_nodes, input_list, output_list, &kernel_json_);
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,125 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
// json key
|
||||||
|
constexpr auto kJsonKeyOpDesc = "op_desc";
|
||||||
|
constexpr auto kJsonKeyAttr = "attr";
|
||||||
|
constexpr auto kJsonKeyInputDesc = "input_desc";
|
||||||
|
constexpr auto kJsonKeyFormat = "format";
|
||||||
|
constexpr auto kJsonKeyInferDataType = "infer_data_type";
|
||||||
|
constexpr auto kJsonKeyInferShape = "infer_shape";
|
||||||
|
constexpr auto kJsonKeyShape = "shape";
|
||||||
|
constexpr auto kJsonKeyDataType = "data_type";
|
||||||
|
constexpr auto kJsonKeyOutputDesc = "output_desc";
|
||||||
|
constexpr auto kJsonKeyName = "name";
|
||||||
|
constexpr auto kJsonKeyTensorName = "tensor_name";
|
||||||
|
constexpr auto kJsonKeyValue = "value";
|
||||||
|
constexpr auto kJsonKeyImplPath = "impl_path";
|
||||||
|
constexpr auto kJsonKeyProcess = "process";
|
||||||
|
constexpr auto kJsonKeyComposite = "composite";
|
||||||
|
constexpr auto kJsonKeyId = "id";
|
||||||
|
constexpr auto kJsonKeyOp = "op";
|
||||||
|
constexpr auto kJsonKeyPtrAddress = "ptr_address";
|
||||||
|
constexpr auto kJsonKeyCompositeGraph = "composite_graph";
|
||||||
|
constexpr auto kJsonKeyPlatform = "platform";
|
||||||
|
|
||||||
|
constexpr auto kAttrInputNames = "input_names";
|
||||||
|
|
||||||
|
// dump option
|
||||||
|
struct DumpOption {
|
||||||
|
bool is_before_select_kernel = false;
|
||||||
|
bool save_ptr_address = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AkgKernelJsonGenerator {
|
||||||
|
public:
|
||||||
|
AkgKernelJsonGenerator() { Clear(); }
|
||||||
|
explicit AkgKernelJsonGenerator(DumpOption dump_option) : dump_option_(dump_option) { Clear(); }
|
||||||
|
~AkgKernelJsonGenerator() = default;
|
||||||
|
|
||||||
|
bool CollectJson(const AnfNodePtr &anf_node, nlohmann::json *const kernel_json);
|
||||||
|
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
||||||
|
const std::vector<AnfNodePtr> &output_list, nlohmann::json *const kernel_json);
|
||||||
|
bool CollectJson(const AnfNodePtr &anf_node);
|
||||||
|
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
||||||
|
const std::vector<AnfNodePtr> &output_list);
|
||||||
|
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *const node_json);
|
||||||
|
std::string kernel_name() const { return kernel_name_; }
|
||||||
|
nlohmann::json kernel_json() const { return kernel_json_; }
|
||||||
|
std::string kernel_json_str() const { return kernel_json_.dump(); }
|
||||||
|
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
|
||||||
|
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
|
||||||
|
void Clear() {
|
||||||
|
input_tensor_idx_.clear();
|
||||||
|
address_node_map_.clear();
|
||||||
|
output_tensor_idx_ = 0;
|
||||||
|
}
|
||||||
|
void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
|
||||||
|
std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const inputs_json);
|
||||||
|
bool CreateOutputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const outputs_json);
|
||||||
|
void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
|
||||||
|
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value);
|
||||||
|
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
|
||||||
|
nlohmann::json *const attrs_json);
|
||||||
|
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
|
||||||
|
std::vector<size_t> *const output_size);
|
||||||
|
int GetOpCntInc();
|
||||||
|
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
|
||||||
|
size_t GetOutputTensorIdxInc();
|
||||||
|
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
|
||||||
|
nlohmann::json *const node_json);
|
||||||
|
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
|
||||||
|
const std::pair<size_t, size_t> &position);
|
||||||
|
TypeId GetInputDataType(const AnfNodePtr &anf_node, size_t real_index);
|
||||||
|
std::vector<size_t> GetInputShape(const AnfNodePtr &anf_node, size_t real_index);
|
||||||
|
std::string GetInputFormat(const AnfNodePtr &anf_node, size_t real_index);
|
||||||
|
TypeId GetOutputDataType(const AnfNodePtr &anf_node, size_t index);
|
||||||
|
std::vector<size_t> GetOutputShape(const AnfNodePtr &anf_node, size_t index);
|
||||||
|
std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
|
||||||
|
|
||||||
|
DumpOption dump_option_;
|
||||||
|
static int op_cnt_;
|
||||||
|
// lock for variable fusionOpCnt in singleton mode
|
||||||
|
static std::mutex op_cnt_mtx_;
|
||||||
|
std::string kernel_name_;
|
||||||
|
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
|
||||||
|
size_t output_tensor_idx_;
|
||||||
|
nlohmann::json kernel_json_;
|
||||||
|
std::vector<size_t> input_size_list_;
|
||||||
|
std::vector<size_t> output_size_list_;
|
||||||
|
std::map<std::string, AnfNodePtr> address_node_map_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_
|
|
@ -29,6 +29,7 @@
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/kernel_compiler/tbe/tbe_utils.h"
|
#include "backend/kernel_compiler/tbe/tbe_utils.h"
|
||||||
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
|
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "backend/session/kernel_build_client.h"
|
#include "backend/session/kernel_build_client.h"
|
||||||
|
@ -38,287 +39,37 @@ namespace kernel {
|
||||||
constexpr int32_t PROCESS_NUM = 16;
|
constexpr int32_t PROCESS_NUM = 16;
|
||||||
constexpr int32_t TIME_OUT = 300;
|
constexpr int32_t TIME_OUT = 300;
|
||||||
|
|
||||||
constexpr auto kOpDesc = "op_desc";
|
bool AkgAscendKernelBuilder::AkgOpParallelBuild(
|
||||||
constexpr auto kShape = "shape";
|
const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) {
|
||||||
constexpr auto kDataType = "data_type";
|
|
||||||
constexpr auto kInputDesc = "input_desc";
|
|
||||||
constexpr auto kOutputDesc = "output_desc";
|
|
||||||
constexpr auto kTensorName = "tensor_name";
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
void UpdateTensorNameInJson(const std::vector<AnfNodePtr> &anf_nodes,
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
|
|
||||||
for (auto const &anf_node : anf_nodes) {
|
|
||||||
std::vector<int> dyn_input_sizes;
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
|
|
||||||
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
|
|
||||||
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_dynamic_input = !dyn_input_sizes.empty();
|
|
||||||
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
|
|
||||||
size_t real_input_index = 0;
|
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
|
||||||
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
|
|
||||||
for (size_t j = 0; j < input_tensor_num; ++j) {
|
|
||||||
auto tmp_input = GetKernelInput(anf_node, real_input_index);
|
|
||||||
std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j));
|
|
||||||
if (node_json_map->find(tmp_input.first) != node_json_map->end()) {
|
|
||||||
std::string new_tensor_name =
|
|
||||||
GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second));
|
|
||||||
SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node]));
|
|
||||||
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
|
|
||||||
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
|
|
||||||
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
|
|
||||||
} else {
|
|
||||||
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
|
|
||||||
<< anf_node->fullname_with_scope() << "] is out input.";
|
|
||||||
}
|
|
||||||
real_input_index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json GetInputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
|
|
||||||
nlohmann::json inputs_json;
|
|
||||||
auto input_index = GetInputIndex(anf_nodes, input_list);
|
|
||||||
for (size_t i = 0; i < input_index.size(); ++i) {
|
|
||||||
auto tmp_input = input_index[i];
|
|
||||||
auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first);
|
|
||||||
std::string dtype = TypeId2String(type_id);
|
|
||||||
nlohmann::json input_desc_json;
|
|
||||||
input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second);
|
|
||||||
input_desc_json[kDataType] = dtype;
|
|
||||||
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first);
|
|
||||||
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
|
|
||||||
}
|
|
||||||
|
|
||||||
return inputs_json;
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json GetOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
|
||||||
const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
|
|
||||||
nlohmann::json outputs_json;
|
|
||||||
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
|
|
||||||
for (size_t i = 0; i < output_index.size(); ++i) {
|
|
||||||
auto tmp_output = output_index[i];
|
|
||||||
bool found = false;
|
|
||||||
nlohmann::json output_desc_json;
|
|
||||||
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
|
|
||||||
if (tmp_output.first == input_list[input_i]) {
|
|
||||||
output_desc_json = inputs_json[input_i][0];
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!found) {
|
|
||||||
auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second);
|
|
||||||
std::string dtype = TypeId2String(type_id);
|
|
||||||
output_desc_json[kTensorName] =
|
|
||||||
GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second));
|
|
||||||
output_desc_json[kDataType] = dtype;
|
|
||||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second);
|
|
||||||
if (output_shape.empty()) {
|
|
||||||
output_shape.push_back(1);
|
|
||||||
}
|
|
||||||
output_desc_json[kShape] = output_shape;
|
|
||||||
}
|
|
||||||
outputs_json.emplace_back(output_desc_json);
|
|
||||||
}
|
|
||||||
|
|
||||||
return outputs_json;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<std::string>, std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>>> PreProcessJsonForBuild(
|
|
||||||
const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
|
|
||||||
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
||||||
std::vector<std::string> jsons;
|
std::vector<std::string> jsons;
|
||||||
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes;
|
std::unordered_set<std::string> kernel_name_set;
|
||||||
std::unordered_set<std::string> json_name_set;
|
std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes;
|
||||||
for (const auto &[builder, anf_node] : build_args) {
|
for (const auto &[json_generator, anf_node] : build_args) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
auto json_name = builder.json_name();
|
auto kernel_name = json_generator.kernel_name();
|
||||||
MS_LOG(DEBUG) << "Akg start compile op: " << json_name;
|
MS_LOG(DEBUG) << "Akg start compile op: " << kernel_name;
|
||||||
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
|
||||||
if (cached_kernel_pack != nullptr) {
|
if (cached_kernel_pack != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope["
|
MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
||||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (json_name_set.count(json_name) != 0) {
|
if (kernel_name_set.count(kernel_name) != 0) {
|
||||||
repeat_nodes.push_back({builder, anf_node});
|
repeat_nodes.push_back({json_generator, anf_node});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
json_name_set.insert(json_name);
|
kernel_name_set.insert(kernel_name);
|
||||||
auto node_json = builder.kernel_json();
|
auto kernel_json = json_generator.kernel_json_str();
|
||||||
kernel::SaveJsonInfo(json_name, node_json);
|
kernel::SaveJsonInfo(kernel_name, kernel_json);
|
||||||
jsons.push_back(node_json);
|
jsons.push_back(kernel_json);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_pair(jsons, repeat_nodes);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool PostProcessAfterCompile(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args,
|
|
||||||
const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &repeat_nodes) {
|
|
||||||
for (const auto &[builder, anf_node] : build_args) {
|
|
||||||
auto json_name = builder.json_name();
|
|
||||||
auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
|
||||||
if (new_kernel_pack == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
|
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
|
|
||||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
|
||||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
|
||||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
||||||
MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto &[builder, anf_node] : repeat_nodes) {
|
|
||||||
auto node_json = builder.kernel_json();
|
|
||||||
auto json_name = builder.json_name();
|
|
||||||
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
|
|
||||||
if (cached_kernel_pack == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
|
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
|
||||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
|
||||||
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
|
|
||||||
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
|
|
||||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
|
|
||||||
auto it = kAkgKernelAttrsProcessMap.find(op_name);
|
|
||||||
if (it != kAkgKernelAttrsProcessMap.end()) {
|
|
||||||
it->second(anf_node);
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
|
|
||||||
nlohmann::json node_json;
|
|
||||||
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
|
|
||||||
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
|
|
||||||
}
|
|
||||||
|
|
||||||
kernel_json_ = node_json.dump();
|
|
||||||
|
|
||||||
if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
|
|
||||||
MS_LOG(ERROR) << "Cal mem size failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector<AnfNodePtr> &anf_nodes,
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
|
|
||||||
for (auto const &anf_node : anf_nodes) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
|
||||||
if (!AnfAlgo::IsRealKernel(anf_node)) {
|
|
||||||
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto it = kAkgKernelAttrsProcessMap.find(op_name);
|
|
||||||
if (it != kAkgKernelAttrsProcessMap.end()) {
|
|
||||||
it->second(anf_node);
|
|
||||||
}
|
|
||||||
|
|
||||||
nlohmann::json node_json;
|
|
||||||
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
|
|
||||||
MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// No need for composite op.
|
|
||||||
node_json.erase("id");
|
|
||||||
node_json.erase("op");
|
|
||||||
node_json.erase("composite");
|
|
||||||
|
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
|
|
||||||
if (primitive->GetAttr("fusion") != nullptr) {
|
|
||||||
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
|
|
||||||
}
|
|
||||||
|
|
||||||
(*node_json_map)[anf_node] = node_json;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
|
|
||||||
const std::vector<AnfNodePtr> &input_list,
|
|
||||||
const std::vector<AnfNodePtr> &output_list) {
|
|
||||||
if (anf_nodes.empty() || input_list.empty()) {
|
|
||||||
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
|
|
||||||
<< "].";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list ["
|
|
||||||
<< input_list.size() << "].";
|
|
||||||
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> node_json_map;
|
|
||||||
if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateTensorNameInJson(anf_nodes, &node_json_map);
|
|
||||||
|
|
||||||
nlohmann::json fused_node_json;
|
|
||||||
std::vector<nlohmann::json> node_json_desc;
|
|
||||||
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
|
|
||||||
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
|
|
||||||
fused_node_json[kOpDesc] = node_json_desc;
|
|
||||||
fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map);
|
|
||||||
fused_node_json[kOutputDesc] =
|
|
||||||
GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map);
|
|
||||||
|
|
||||||
size_t hash_id = std::hash<std::string>()(fused_node_json.dump());
|
|
||||||
json_name_ = "Fused_";
|
|
||||||
auto fg = anf_nodes[0]->func_graph();
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
|
||||||
if (attr_val != nullptr) {
|
|
||||||
auto fg_attr = GetValue<std::string>(attr_val);
|
|
||||||
(void)json_name_.append(fg_attr).append("_");
|
|
||||||
}
|
|
||||||
(void)json_name_.append(std::to_string(hash_id));
|
|
||||||
fused_node_json["composite_graph"] = fg->ToString();
|
|
||||||
fused_node_json["op"] = json_name_;
|
|
||||||
fused_node_json["platform"] = "AKG";
|
|
||||||
fused_node_json["process"] = "aicore";
|
|
||||||
fused_node_json["composite"] = true;
|
|
||||||
|
|
||||||
kernel_json_ = fused_node_json.dump();
|
|
||||||
|
|
||||||
if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) {
|
|
||||||
MS_LOG(ERROR) << "Cal mem size failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
|
|
||||||
auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args);
|
|
||||||
if (jsons.empty()) {
|
if (jsons.empty()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -337,18 +88,43 @@ bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfN
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!PostProcessAfterCompile(build_args, repeat_nodes)) {
|
// All unique done here, cache them and set kernel.
|
||||||
return false;
|
for (const auto &[json_generator, anf_node] : build_args) {
|
||||||
|
auto kernel_name = json_generator.kernel_name();
|
||||||
|
auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
|
||||||
|
if (new_kernel_pack == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
|
||||||
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||||
|
MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle repeated nodes.
|
||||||
|
for (const auto &[json_generator, anf_node] : repeat_nodes) {
|
||||||
|
auto kernel_name = json_generator.kernel_name();
|
||||||
|
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
|
||||||
|
if (cached_kernel_pack == nullptr) return false;
|
||||||
|
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
|
||||||
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||||
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node;
|
std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node;
|
||||||
for (const auto &anf_node : anf_nodes) {
|
for (const auto &anf_node : anf_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
AkgAscendKernelBuilder akg_cce_kernel_builder;
|
AkgKernelJsonGenerator akg_kernel_json_generator;
|
||||||
KernelPackPtr kernel_pack = nullptr;
|
KernelPackPtr kernel_pack = nullptr;
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
@ -363,18 +139,17 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||||
std::vector<AnfNodePtr> node_list;
|
std::vector<AnfNodePtr> node_list;
|
||||||
std::vector<AnfNodePtr> input_list;
|
std::vector<AnfNodePtr> input_list;
|
||||||
std::vector<AnfNodePtr> output_list;
|
std::vector<AnfNodePtr> output_list;
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
|
MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
|
||||||
MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]";
|
|
||||||
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
||||||
if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) {
|
if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
|
||||||
MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "].";
|
MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << anf_node->fullname_with_scope() << "].";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!akg_cce_kernel_builder.CollectJson(anf_node)) {
|
if (!akg_kernel_json_generator.CollectJson(anf_node)) {
|
||||||
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "].";
|
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
json_and_node.push_back({akg_cce_kernel_builder, anf_node});
|
json_and_node.push_back({akg_kernel_json_generator, anf_node});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (json_and_node.empty()) {
|
if (json_and_node.empty()) {
|
||||||
|
@ -382,7 +157,8 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return AkgOpParallelBuild(json_and_node);
|
AkgAscendKernelBuilder akg_ascend_kernel_builder;
|
||||||
|
return akg_ascend_kernel_builder.AkgOpParallelBuild(json_and_node);
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,35 +18,21 @@
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class AkgAscendKernelBuilder : public AkgKernelBuild {
|
class AkgAscendKernelBuilder {
|
||||||
public:
|
public:
|
||||||
AkgAscendKernelBuilder() = default;
|
AkgAscendKernelBuilder() = default;
|
||||||
~AkgAscendKernelBuilder() = default;
|
~AkgAscendKernelBuilder() = default;
|
||||||
|
|
||||||
bool CollectJson(const AnfNodePtr &anf_node);
|
bool AkgOpParallelBuild(const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args);
|
||||||
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
|
|
||||||
const std::vector<AnfNodePtr> &output_list);
|
|
||||||
std::string json_name() const { return json_name_; }
|
|
||||||
std::string kernel_json() const { return kernel_json_; }
|
|
||||||
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
|
|
||||||
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool GenJsonAndPreprocess4Fused(const std::vector<AnfNodePtr> &anf_nodes,
|
|
||||||
std::map<AnfNodePtr, nlohmann::json> *node_json_map);
|
|
||||||
|
|
||||||
std::string kernel_json_;
|
|
||||||
std::vector<size_t> input_size_list_;
|
|
||||||
std::vector<size_t> output_size_list_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
|
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
|
||||||
|
|
|
@ -15,29 +15,116 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
|
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
|
||||||
|
#include <Python.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
|
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/session/kernel_build_client.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) {
|
constexpr int32_t ARGS_SIZE = 1;
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
constexpr auto kCompileWithJsonFunc = "compilewithjson";
|
||||||
AkgKernelBuild akg_kernel_build;
|
|
||||||
|
|
||||||
std::vector<size_t> input_size_list;
|
KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
|
||||||
std::vector<size_t> output_size_list;
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list);
|
auto processor = GetProcessorStr(anf_node);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_pack);
|
auto kernel_name = json_generator.kernel_name();
|
||||||
|
auto cached_kernel_pack = SearchCache(kernel_name, processor);
|
||||||
|
if (cached_kernel_pack != nullptr) {
|
||||||
|
MS_LOG(INFO) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
|
return cached_kernel_pack;
|
||||||
|
}
|
||||||
|
|
||||||
|
(void)alarm(AUTODIFF_COMPILE_OVERTIME);
|
||||||
|
auto kernel_json = json_generator.kernel_json_str();
|
||||||
|
auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json);
|
||||||
|
(void)alarm(0);
|
||||||
|
if (!res) {
|
||||||
|
MS_LOG(ERROR) << "Akg compile failed, json: " << kernel_json;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_kernel_pack = InsertCache(kernel_name, processor);
|
||||||
|
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
|
||||||
|
if (new_kernel_pack == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return new_kernel_pack;
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelModPtr AkgGpuKernelBuilder::BuildByJson(const AnfNodePtr &anf_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_LOG(INFO) << "Akg start compile, op[" << anf_node->fullname_with_scope() << "]";
|
||||||
|
AkgKernelJsonGenerator json_generator;
|
||||||
|
if (!json_generator.CollectJson(anf_node)) {
|
||||||
|
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kernel_pack = OpBuild(json_generator, anf_node);
|
||||||
|
if (kernel_pack == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
|
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||||
kernel_mod_ptr->SetInputSizeList(input_size_list);
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
kernel_mod_ptr->SetOutputSizeList(output_size_list);
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
|
MS_LOG(INFO) << "Akg compile success, op[" << anf_node->fullname_with_scope() << "]";
|
||||||
return kernel_mod_ptr;
|
return kernel_mod_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KernelModPtr AkgGpuKernelBuilder::FuseByJson(const AnfNodePtr &anf_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_LOG(INFO) << "Akg start compile, graph_kernel[" << anf_node->fullname_with_scope() << "]";
|
||||||
|
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
auto mng = fg->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(fg, true);
|
||||||
|
fg->set_manager(mng);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtrList node_list;
|
||||||
|
AnfNodePtrList input_list;
|
||||||
|
AnfNodePtrList output_list;
|
||||||
|
GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
|
||||||
|
AkgKernelJsonGenerator json_generator;
|
||||||
|
if (!json_generator.CollectFusedJson(node_list, input_list, output_list)) {
|
||||||
|
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kernel_pack = OpBuild(json_generator, anf_node);
|
||||||
|
if (kernel_pack == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Akg build failed, graph_kernel[" << anf_node->fullname_with_scope() << "].";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||||
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
|
MS_LOG(INFO) << "Akg compile success, graph_kernel[" << anf_node->fullname_with_scope() << "]";
|
||||||
|
return kernel_mod_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
AkgGpuKernelBuilder akg_gpu_kernel_builder;
|
||||||
|
if (AnfAlgo::IsGraphKernel(anf_node)) {
|
||||||
|
return akg_gpu_kernel_builder.FuseByJson(anf_node);
|
||||||
|
}
|
||||||
|
return akg_gpu_kernel_builder.BuildByJson(anf_node);
|
||||||
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -16,11 +16,25 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
|
||||||
|
#include <string>
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
#include "base/base.h"
|
#include "base/base.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
class AkgGpuKernelBuilder {
|
||||||
|
public:
|
||||||
|
AkgGpuKernelBuilder() = default;
|
||||||
|
~AkgGpuKernelBuilder() = default;
|
||||||
|
|
||||||
|
KernelModPtr BuildByJson(const AnfNodePtr &anf_node);
|
||||||
|
KernelModPtr FuseByJson(const AnfNodePtr &anf_node);
|
||||||
|
|
||||||
|
private:
|
||||||
|
KernelPackPtr OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node);
|
||||||
|
};
|
||||||
|
|
||||||
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node);
|
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -205,10 +205,13 @@ TypeId DtypeToTypeId(const std::string &dtypes) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string TypeId2String(TypeId type_id) {
|
std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
|
||||||
auto iter = type_id_str_map.find(type_id);
|
auto iter = type_id_str_map.find(type_id);
|
||||||
if (iter == type_id_str_map.end()) {
|
if (iter == type_id_str_map.end()) {
|
||||||
return std::string(TypeIdLabel(type_id));
|
if (!unknown_as_default) {
|
||||||
|
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
|
||||||
|
}
|
||||||
|
return "float32";
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
@ -427,9 +430,9 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SaveJsonInfo(const std::string &json_name, const std::string &info) {
|
void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
|
||||||
char real_path[PATH_MAX] = {0};
|
char real_path[PATH_MAX] = {0};
|
||||||
std::string path = kCceKernelMeta + json_name + kInfoSuffix;
|
std::string path = base_path + json_name + kInfoSuffix;
|
||||||
if (path.size() > PATH_MAX) {
|
if (path.size() > PATH_MAX) {
|
||||||
MS_LOG(DEBUG) << "file path " << path << " is too long.";
|
MS_LOG(DEBUG) << "file path " << path << " is too long.";
|
||||||
return;
|
return;
|
||||||
|
@ -458,6 +461,14 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Processor GetProcessor(const string &processor) {
|
||||||
|
if (processor == kProcessorAiCore) return Processor::AICORE;
|
||||||
|
if (processor == kProcessorAiCpu) return Processor::AICPU;
|
||||||
|
if (processor == kProcessorCuda) return Processor::CUDA;
|
||||||
|
MS_LOG(DEBUG) << "Unknown processor type.";
|
||||||
|
return Processor::UNKNOWN;
|
||||||
|
}
|
||||||
|
|
||||||
std::string GetProcessor(const AnfNodePtr &anf_node) {
|
std::string GetProcessor(const AnfNodePtr &anf_node) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
std::string device;
|
std::string device;
|
||||||
|
@ -628,16 +639,21 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
|
||||||
|
|
||||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
||||||
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
|
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(node_list);
|
MS_EXCEPTION_IF_NULL(node_list);
|
||||||
MS_EXCEPTION_IF_NULL(input_list);
|
MS_EXCEPTION_IF_NULL(input_list);
|
||||||
MS_EXCEPTION_IF_NULL(output_list);
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
|
|
||||||
GetValidKernelNodes(func_graph, node_list);
|
GetValidKernelNodes(func_graph, node_list);
|
||||||
|
|
||||||
auto parameters = func_graph->parameters();
|
auto parameters = func_graph->parameters();
|
||||||
input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
|
input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
|
||||||
|
|
||||||
|
GetFuncGraphOutputNodes(func_graph, output_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(output_list);
|
||||||
auto func_output = func_graph->output();
|
auto func_output = func_graph->output();
|
||||||
MS_EXCEPTION_IF_NULL(func_output);
|
MS_EXCEPTION_IF_NULL(func_output);
|
||||||
if (func_output->isa<CNode>()) {
|
if (func_output->isa<CNode>()) {
|
||||||
|
@ -780,5 +796,36 @@ std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
|
||||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
|
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
|
||||||
return axis;
|
return axis;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GetProcessorStr(const AnfNodePtr &anf_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
std::string processor = kProcessorUnknown;
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
auto build_info = kernel_info->select_kernel_build_info();
|
||||||
|
// we may call this before kernel select.
|
||||||
|
if (build_info == nullptr) {
|
||||||
|
return processor;
|
||||||
|
}
|
||||||
|
switch (build_info->processor()) {
|
||||||
|
case Processor::AICORE:
|
||||||
|
processor = kProcessorAiCore;
|
||||||
|
break;
|
||||||
|
|
||||||
|
case Processor::AICPU:
|
||||||
|
processor = kProcessorAiCpu;
|
||||||
|
break;
|
||||||
|
|
||||||
|
case Processor::CUDA:
|
||||||
|
processor = kProcessorCuda;
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
MS_LOG(ERROR) << "Unknown processor type.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return processor;
|
||||||
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
|
@ -37,6 +38,7 @@ constexpr auto kGpuKernelMeta = "./cuda_meta";
|
||||||
constexpr auto kProcessorAiCore = "aicore";
|
constexpr auto kProcessorAiCore = "aicore";
|
||||||
constexpr auto kProcessorAiCpu = "aicpu";
|
constexpr auto kProcessorAiCpu = "aicpu";
|
||||||
constexpr auto kProcessorCuda = "cuda";
|
constexpr auto kProcessorCuda = "cuda";
|
||||||
|
constexpr auto kProcessorUnknown = "unknown";
|
||||||
constexpr auto kJsonSuffix = ".json";
|
constexpr auto kJsonSuffix = ".json";
|
||||||
constexpr auto kInfoSuffix = ".info";
|
constexpr auto kInfoSuffix = ".info";
|
||||||
constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
|
constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
|
||||||
|
@ -76,12 +78,13 @@ KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &pro
|
||||||
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
|
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
|
||||||
TypeId DtypeToTypeId(const std::string &dtypes);
|
TypeId DtypeToTypeId(const std::string &dtypes);
|
||||||
std::string Dtype2ShortType(const std::string &dtypes);
|
std::string Dtype2ShortType(const std::string &dtypes);
|
||||||
std::string TypeId2String(TypeId type_id);
|
std::string TypeId2String(TypeId type_id, bool unknown_as_default = false);
|
||||||
size_t GetDtypeNbyte(const std::string &dtypes);
|
size_t GetDtypeNbyte(const std::string &dtypes);
|
||||||
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
|
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
|
||||||
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
|
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
|
||||||
void SaveJsonInfo(const std::string &json_name, const std::string &info);
|
void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path = kCceKernelMeta);
|
||||||
std::string GetProcessor(const AnfNodePtr &anf_node);
|
std::string GetProcessor(const AnfNodePtr &anf_node);
|
||||||
|
Processor GetProcessor(const string &processor);
|
||||||
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
|
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
|
||||||
int Sign(float x);
|
int Sign(float x);
|
||||||
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
|
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
|
||||||
|
@ -90,13 +93,26 @@ std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(cons
|
||||||
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||||
const std::vector<AnfNodePtr> &input_list,
|
const std::vector<AnfNodePtr> &input_list,
|
||||||
const std::vector<AnfNodePtr> &output_list);
|
const std::vector<AnfNodePtr> &output_list);
|
||||||
|
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
|
||||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
|
||||||
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
|
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
|
||||||
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
|
void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list);
|
||||||
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
|
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
|
||||||
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
|
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
|
||||||
bool IsWeightBoundary(const AnfNodePtr &node);
|
bool IsWeightBoundary(const AnfNodePtr &node);
|
||||||
std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode);
|
std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode);
|
||||||
|
std::string GetProcessorStr(const AnfNodePtr &anf_node);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::string Vector2Str(const std::vector<T> &inputs) {
|
||||||
|
if (!inputs.empty()) {
|
||||||
|
std::ostringstream oss;
|
||||||
|
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
|
||||||
|
oss << inputs.back();
|
||||||
|
return oss.str();
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -16,14 +16,12 @@
|
||||||
|
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "utils/system/sha256.h"
|
#include "utils/system/sha256.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -49,6 +49,7 @@ enum OpPattern {
|
||||||
|
|
||||||
// Backend processor
|
// Backend processor
|
||||||
enum Processor {
|
enum Processor {
|
||||||
|
UNKNOWN = -1,
|
||||||
AICORE = 0,
|
AICORE = 0,
|
||||||
AICPU,
|
AICPU,
|
||||||
CUDA,
|
CUDA,
|
||||||
|
|
|
@ -5,13 +5,19 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
)
|
)
|
||||||
|
|
||||||
if (ENABLE_D)
|
if (ENABLE_D)
|
||||||
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc")
|
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
"ascend/*.cc"
|
||||||
|
"graph_kernel/*.cc"
|
||||||
|
)
|
||||||
|
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
if (ENABLE_GPU)
|
if (ENABLE_GPU)
|
||||||
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
|
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
|
"gpu/*.cc"
|
||||||
|
"graph_kernel/*.cc"
|
||||||
|
)
|
||||||
|
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
|
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <list>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
@ -68,8 +70,6 @@
|
||||||
#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
|
||||||
#include "backend/optimizer/pass/eliminate_redundant_op.h"
|
#include "backend/optimizer/pass/eliminate_redundant_op.h"
|
||||||
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
#include "backend/optimizer/pass/common_subexpression_elimination.h"
|
||||||
#include "backend/optimizer/pass/fuse_graph_kernel.h"
|
|
||||||
#include "backend/optimizer/pass/fuse_basic.h"
|
|
||||||
#include "backend/optimizer/pass/add_atomic_clean.h"
|
#include "backend/optimizer/pass/add_atomic_clean.h"
|
||||||
#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h"
|
#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h"
|
||||||
#include "backend/optimizer/ascend/format_type/check_consistency.h"
|
#include "backend/optimizer/ascend/format_type/check_consistency.h"
|
||||||
|
@ -106,6 +106,8 @@
|
||||||
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
|
||||||
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
|
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||||
#include "utils/config_manager.h"
|
#include "utils/config_manager.h"
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
#include "debug/dump_proto.h"
|
#include "debug/dump_proto.h"
|
||||||
|
@ -406,7 +408,7 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fuse graph kernels with basic ops
|
// Fuse graph kernels with basic ops
|
||||||
FuseGraphKernel(kernel_graph, is_before_kernel_select);
|
static_cast<void>(FuseCompositeOps(kernel_graph, is_before_kernel_select));
|
||||||
|
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
|
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
|
||||||
|
@ -429,17 +431,17 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
|
||||||
save_graphs_path = ".";
|
save_graphs_path = ".";
|
||||||
}
|
}
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" +
|
std::string file_path = save_graphs_path + "/" + "hwopt_fuse_basic_opt_before_graph_" +
|
||||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||||
".ir";
|
".ir";
|
||||||
DumpIR(file_path, kernel_graph, true);
|
DumpIR(file_path, kernel_graph, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fuse basic ops with basic ops
|
// Fuse basic ops with basic ops
|
||||||
FuseBasic(kernel_graph, is_before_kernel_select);
|
static_cast<void>(FuseBasicOps(kernel_graph, is_before_kernel_select));
|
||||||
|
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" +
|
std::string file_path = save_graphs_path + "/" + "hwopt_fuse_basic_opt_end_graph_" +
|
||||||
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
|
||||||
".ir";
|
".ir";
|
||||||
DumpIR(file_path, kernel_graph, true);
|
DumpIR(file_path, kernel_graph, true);
|
||||||
|
|
|
@ -601,6 +601,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
|
||||||
std::vector<std::string> new_input_names;
|
std::vector<std::string> new_input_names;
|
||||||
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
primitive = primitive->Clone();
|
||||||
auto input_names = primitive->GetAttr(kAttrInputNames);
|
auto input_names = primitive->GetAttr(kAttrInputNames);
|
||||||
if (input_names == nullptr) {
|
if (input_names == nullptr) {
|
||||||
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
|
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
|
||||||
|
@ -631,6 +632,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
|
||||||
}
|
}
|
||||||
if (need_update) {
|
if (need_update) {
|
||||||
// Update cnode's inputs
|
// Update cnode's inputs
|
||||||
|
new_inputs[0] = NewValueNode(primitive);
|
||||||
cnode->set_inputs(new_inputs);
|
cnode->set_inputs(new_inputs);
|
||||||
// Update cnode's input_names attr
|
// Update cnode's input_names attr
|
||||||
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
|
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
|
||||||
|
|
|
@ -73,7 +73,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
auto dump_file_path =
|
auto dump_file_path =
|
||||||
save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir";
|
save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir";
|
||||||
DumpIR(dump_file_path, func_graph);
|
DumpIR(dump_file_path, func_graph, true);
|
||||||
}
|
}
|
||||||
num++;
|
num++;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
|
@ -14,8 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/pass/fuse_basic.h"
|
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||||
#include "backend/optimizer/pass/fuse_graph_kernel.h"
|
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -31,17 +29,30 @@
|
||||||
#include "vm/segment_runner.h"
|
#include "vm/segment_runner.h"
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
#include "ir/func_graph_cloner.h"
|
#include "ir/func_graph_cloner.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) {
|
bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
||||||
|
#if ENABLE_D
|
||||||
std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
|
std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
|
||||||
prim::kPrimExpandDims};
|
prim::kPrimExpandDims};
|
||||||
if (!is_before_kernel_select) {
|
if (!is_before_kernel_select) {
|
||||||
fusable_basic_ops.push_back(prim::kPrimCast);
|
fusable_basic_ops.push_back(prim::kPrimCast);
|
||||||
}
|
}
|
||||||
return fusable_basic_ops;
|
#elif ENABLE_GPU
|
||||||
|
std::vector<PrimitivePtr> fusable_basic_ops = {
|
||||||
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||||
|
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||||
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
|
||||||
|
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData};
|
||||||
|
#else
|
||||||
|
std::vector<PrimitivePtr> fusable_basic_ops;
|
||||||
|
#endif
|
||||||
|
return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
}
|
}
|
||||||
|
|
||||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
|
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
|
||||||
|
@ -53,16 +64,14 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKe
|
||||||
return EXCLUDE;
|
return EXCLUDE;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select);
|
bool is_fusable = IsBasicOp(node, info.is_before_kernel_select);
|
||||||
bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
|
|
||||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
||||||
|
|
||||||
return is_fusable ? FOLLOW : EXCLUDE;
|
return is_fusable ? FOLLOW : EXCLUDE;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
||||||
GraphKernelInfo info;
|
GraphKernelInfo info;
|
||||||
info.is_before_kernel_select = is_before_kernel_select;
|
info.is_before_kernel_select = is_before_kernel_select;
|
||||||
|
|
||||||
// Search fusable nodes according input direction.
|
// Search fusable nodes according input direction.
|
||||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
|
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
|
||||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||||
|
@ -170,8 +179,9 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
|
||||||
fg->set_output(fg_new_output, true);
|
fg->set_output(fg_new_output, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::vector<AnfNodePtr> &todos,
|
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
|
||||||
std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
|
std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
|
||||||
|
bool changed = false;
|
||||||
auto mng = kernel_graph->manager();
|
auto mng = kernel_graph->manager();
|
||||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||||
auto node = (*iter)->cast<CNodePtr>();
|
auto node = (*iter)->cast<CNodePtr>();
|
||||||
|
@ -181,9 +191,7 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
|
||||||
if (fused_ops->count(node)) {
|
if (fused_ops->count(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select);
|
bool is_basic_op = IsBasicOp(node, is_before_kernel_select);
|
||||||
bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
|
|
||||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
||||||
if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
|
if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -193,12 +201,16 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
changed = true;
|
||||||
FuncGraphPtr fg;
|
FuncGraphPtr fg;
|
||||||
AnfNodePtrList inputs;
|
AnfNodePtrList inputs;
|
||||||
AnfNodePtrList outputs;
|
AnfNodePtrList outputs;
|
||||||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||||
RemoveControlDependOut(fg, &outputs, mng);
|
RemoveControlDependOut(fg, &outputs, mng);
|
||||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
|
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
|
||||||
|
if (!is_before_kernel_select) {
|
||||||
|
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||||
|
}
|
||||||
|
|
||||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||||
|
|
||||||
|
@ -210,10 +222,12 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
|
||||||
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
|
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
|
||||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||||
}
|
}
|
||||||
|
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
|
||||||
|
return changed;
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
|
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
auto mng = kernel_graph->manager();
|
auto mng = kernel_graph->manager();
|
||||||
if (mng == nullptr) {
|
if (mng == nullptr) {
|
||||||
|
@ -223,7 +237,9 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool i
|
||||||
std::unordered_set<AnfNodePtr> fused_ops;
|
std::unordered_set<AnfNodePtr> fused_ops;
|
||||||
auto todos = TopoSort(kernel_graph->get_return());
|
auto todos = TopoSort(kernel_graph->get_return());
|
||||||
std::reverse(todos.begin(), todos.end());
|
std::reverse(todos.begin(), todos.end());
|
||||||
FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select);
|
return FuseBasicOps(kernel_graph, todos, &fused_ops, is_before_kernel_select);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph, false); }
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
@ -23,7 +23,16 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select);
|
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select);
|
||||||
|
|
||||||
|
class BasicOpsFusion : public Pass {
|
||||||
|
public:
|
||||||
|
BasicOpsFusion() : Pass("basic_ops_fusion") {}
|
||||||
|
~BasicOpsFusion() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;
|
||||||
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
|
@ -0,0 +1,385 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "utils/ordered_set.h"
|
||||||
|
#include "utils/ordered_map.h"
|
||||||
|
#include "ir/graph_utils.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "vm/segment_runner.h"
|
||||||
|
#include "debug/draw.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
#include "ir/func_graph_cloner.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
|
||||||
|
#if ENABLE_D
|
||||||
|
std::vector<PrimitivePtr> basic_ops = {
|
||||||
|
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
|
||||||
|
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
|
||||||
|
prim::kPrimExpandDims, prim::kPrimReciprocal, prim::kPrimLessEqual};
|
||||||
|
if (!is_before_kernel_select) {
|
||||||
|
basic_ops.push_back(prim::kPrimCast);
|
||||||
|
}
|
||||||
|
#elif ENABLE_GPU
|
||||||
|
std::vector<PrimitivePtr> basic_ops = {
|
||||||
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||||
|
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||||
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
|
||||||
|
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData};
|
||||||
|
#else
|
||||||
|
std::vector<PrimitivePtr> basic_ops;
|
||||||
|
#endif
|
||||||
|
return std::any_of(basic_ops.begin(), basic_ops.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReduceOp(const AnfNodePtr &node) {
|
||||||
|
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
|
||||||
|
prim::kPrimReduceMax, prim::kPrimReduceAll};
|
||||||
|
return std::any_of(reduce_ops.begin(), reduce_ops.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetGraphKernelInfo(const FuncGraphPtr &fg, GraphKernelInfo *info) {
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
auto mng = fg->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(fg, false);
|
||||||
|
fg->set_manager(mng);
|
||||||
|
}
|
||||||
|
const auto &nodes = fg->nodes();
|
||||||
|
info->op_type = ELEWISE;
|
||||||
|
info->cal_step = -1;
|
||||||
|
info->reduce_op_num = 0;
|
||||||
|
for (auto node : nodes) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
info->cal_step++;
|
||||||
|
if (IsReduceOp(node)) {
|
||||||
|
info->op_type = REDUCE;
|
||||||
|
info->reduce_op_num++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto fg_flag = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||||
|
if (fg_flag != nullptr) {
|
||||||
|
auto fg_name = GetValue<std::string>(fg_flag);
|
||||||
|
info->origin_composite_name = fg_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsCompositeFuseBasic(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
||||||
|
#if ENABLE_D
|
||||||
|
std::vector<PrimitivePtr> fusable_with_reduce;
|
||||||
|
if (!info.is_before_kernel_select) {
|
||||||
|
fusable_with_reduce.push_back(prim::kPrimCast);
|
||||||
|
}
|
||||||
|
if (info.op_type == REDUCE &&
|
||||||
|
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
|
||||||
|
return std::any_of(fusable_with_reduce.begin(), fusable_with_reduce.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return IsBasicFuseOp(node, info.is_before_kernel_select);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
||||||
|
// composite fuse composite op
|
||||||
|
if (AnfAlgo::IsGraphKernel(node)) {
|
||||||
|
#if ENABLE_D
|
||||||
|
return false;
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return IsCompositeFuseBasic(info, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateGraphKernelInfo(GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||||
|
if (IsPrimitiveCNode(node)) {
|
||||||
|
info->cal_step++;
|
||||||
|
if (IsReduceOp(node)) {
|
||||||
|
info->op_type = REDUCE;
|
||||||
|
}
|
||||||
|
info->origin_composite_name += AnfAlgo::GetCNodePrimitive(node)->name() + "_";
|
||||||
|
} else if (AnfAlgo::IsGraphKernel(node)) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
auto composite_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||||
|
GraphKernelInfo fuse_info;
|
||||||
|
GetGraphKernelInfo(composite_g, &fuse_info);
|
||||||
|
info->cal_step += fuse_info.cal_step;
|
||||||
|
info->origin_composite_name += fuse_info.origin_composite_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||||
|
if (cur_node == node) {
|
||||||
|
return FOLLOW;
|
||||||
|
}
|
||||||
|
#if ENABLE_D
|
||||||
|
if (!IsPrimitiveCNode(node)) {
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
|
||||||
|
if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool is_fusable = IsFuse(*info, node);
|
||||||
|
if (is_fusable) {
|
||||||
|
UpdateGraphKernelInfo(info, node);
|
||||||
|
}
|
||||||
|
return is_fusable ? FOLLOW : EXCLUDE;
|
||||||
|
}
|
||||||
|
|
||||||
|
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
|
||||||
|
if (cur_node == node) {
|
||||||
|
return FOLLOW;
|
||||||
|
}
|
||||||
|
if (AnfAlgo::IsGraphKernel(node)) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
|
||||||
|
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||||
|
MS_EXCEPTION_IF_NULL(fg_attr_val);
|
||||||
|
auto fg_attr = GetValue<std::string>(fg_attr_val);
|
||||||
|
if (fg_attr == kApplyMomentumOpName) {
|
||||||
|
return FOLLOW;
|
||||||
|
}
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
if (!IsPrimitiveCNode(node)) {
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_fusable = IsFuse(*info, node);
|
||||||
|
if (is_fusable) {
|
||||||
|
UpdateGraphKernelInfo(info, node);
|
||||||
|
}
|
||||||
|
return is_fusable ? FOLLOW : EXCLUDE;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
||||||
|
std::set<AnfNodePtr> *cached_unconnected_set) {
|
||||||
|
if (!check_node->isa<CNode>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cnode = check_node->cast<CNodePtr>();
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
||||||
|
bool has_circle = false;
|
||||||
|
for (auto input : inputs) {
|
||||||
|
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
||||||
|
std::set<AnfNodePtr> done;
|
||||||
|
std::vector<AnfNodePtr> todos = {input};
|
||||||
|
while (!todos.empty()) {
|
||||||
|
auto node = todos.back();
|
||||||
|
todos.pop_back();
|
||||||
|
if (done.count(node) || cached_unconnected_set->count(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
done.insert(node);
|
||||||
|
if (fused_op_set.count(node)) {
|
||||||
|
has_circle = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
auto cnode_ptr = node->cast<CNodePtr>();
|
||||||
|
for (auto it : cnode_ptr->inputs()) {
|
||||||
|
if (it->isa<CNode>()) {
|
||||||
|
todos.push_back(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_circle) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
cached_unconnected_set->insert(done.begin(), done.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
|
||||||
|
std::set<AnfNodePtr> cached_unconnected_set;
|
||||||
|
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||||
|
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
||||||
|
if (fused_op_set.count(node)) {
|
||||||
|
return FOLLOW;
|
||||||
|
}
|
||||||
|
return EXCLUDE;
|
||||||
|
};
|
||||||
|
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
||||||
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
|
||||||
|
// delete the circle node and the node which depend on the circle node in fused op
|
||||||
|
if (has_circle) {
|
||||||
|
auto mng = (*iter)->func_graph()->manager();
|
||||||
|
std::vector<AnfNodePtr> erase_nodes;
|
||||||
|
if (is_backward) {
|
||||||
|
erase_nodes = DeepUsersSearch(*iter, include, mng);
|
||||||
|
} else {
|
||||||
|
erase_nodes = DeepLinkedGraphSearch(*iter, include);
|
||||||
|
}
|
||||||
|
for (auto erase_node : erase_nodes) {
|
||||||
|
fused_op_set.erase(erase_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> res;
|
||||||
|
for (auto node : fused_op) {
|
||||||
|
if (fused_op_set.count(node)) {
|
||||||
|
res.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
||||||
|
if (lst->size() < 2) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> res;
|
||||||
|
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
|
||||||
|
OrderedMap<AnfNodePtr, std::set<AnfNodePtr>> ins;
|
||||||
|
OrderedMap<AnfNodePtr, OrderedSet<AnfNodePtr>> outs;
|
||||||
|
std::queue<AnfNodePtr> q;
|
||||||
|
for (auto node : *lst) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
for (auto input : cnode->inputs()) {
|
||||||
|
if (!node_sets.count(input)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// out_degree
|
||||||
|
outs[input].insert(node);
|
||||||
|
// in_degree
|
||||||
|
ins[node].insert(input);
|
||||||
|
}
|
||||||
|
if (!ins.count(node)) {
|
||||||
|
ins[node] = {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto p : ins) {
|
||||||
|
if (p.second.size() == 0) {
|
||||||
|
q.push(p.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!q.empty()) {
|
||||||
|
auto node = q.front();
|
||||||
|
q.pop();
|
||||||
|
res.push_back(node);
|
||||||
|
if (!outs.count(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (auto out : outs[node]) {
|
||||||
|
if (!ins.count(out)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ins[out].erase(node);
|
||||||
|
if (ins[out].size() == 0) {
|
||||||
|
q.push(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lst->assign(res.begin(), res.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
||||||
|
auto func_graph = cnode->func_graph();
|
||||||
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
||||||
|
GraphKernelInfo info;
|
||||||
|
info.is_before_kernel_select = is_before_kernel_select;
|
||||||
|
GetGraphKernelInfo(graph_kernel_g, &info);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
// Search fusable nodes according input direction.
|
||||||
|
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, &info, std::placeholders::_1);
|
||||||
|
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||||
|
std::reverse(used_nodes.begin(), used_nodes.end());
|
||||||
|
// Search fusable nodes according output direction.
|
||||||
|
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, &info, std::placeholders::_1);
|
||||||
|
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
|
||||||
|
|
||||||
|
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
||||||
|
if (used_nodes.size() > 1) {
|
||||||
|
used_nodes = RemoveCircle(used_nodes);
|
||||||
|
}
|
||||||
|
TopoSortForNodeList(&used_nodes);
|
||||||
|
return used_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
bool changed = false;
|
||||||
|
auto &todos = kernel_graph->execution_order();
|
||||||
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||||
|
auto node = *iter;
|
||||||
|
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||||
|
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||||
|
if (fg_attr != nullptr) {
|
||||||
|
auto fg_name = GetValue<std::string>(fg_attr);
|
||||||
|
if (graph_kernel_black_list.count(fg_name) != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
|
||||||
|
if (fuse_nodes.size() <= 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
changed = true;
|
||||||
|
|
||||||
|
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "", is_before_kernel_select);
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph), false);
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
|
@ -14,13 +13,14 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <limits>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
|
|
||||||
|
@ -31,18 +31,20 @@ enum GraphKernelType {
|
||||||
REDUCE, // contain reduce ops
|
REDUCE, // contain reduce ops
|
||||||
CUBE, // contain cube ops
|
CUBE, // contain cube ops
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GraphKernelInfo {
|
struct GraphKernelInfo {
|
||||||
GraphKernelType op_type = ELEWISE;
|
GraphKernelType op_type = ELEWISE;
|
||||||
bool is_before_kernel_select = false;
|
bool is_before_kernel_select = false;
|
||||||
int reduce_op_num = 0;
|
int reduce_op_num = 0;
|
||||||
int cal_step = 0;
|
int cal_step = 0;
|
||||||
|
std::string origin_composite_name = "";
|
||||||
};
|
};
|
||||||
|
|
||||||
// when reduce graph kernel's cal step is greater than this number, not fuse
|
// when composite fuse composite the cal step is greate than this number, not fuse
|
||||||
|
#if ENABLE_D
|
||||||
const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
|
const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
|
||||||
// when reduce graph kernel contain reduce op num is greater than this number, not fuse
|
|
||||||
const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
|
const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
|
||||||
|
#endif
|
||||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||||
"LambNextMV", "LambUpdateWithLR"};
|
"LambNextMV", "LambUpdateWithLR"};
|
||||||
|
|
||||||
|
@ -50,14 +52,15 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
||||||
|
|
||||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||||
|
|
||||||
AnfNodePtr CreateNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const FuncGraphPtr &fg,
|
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
|
||||||
const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
|
|
||||||
bool is_before_kernel_select);
|
|
||||||
|
|
||||||
void ReplaceNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
class CompositeOpsFusion : public Pass {
|
||||||
const AnfNodePtrList &outputs);
|
public:
|
||||||
|
CompositeOpsFusion() : Pass("composite_ops_fusion") {}
|
||||||
void FuseGraphKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
|
~CompositeOpsFusion() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
using FuseGraphKernelPassPtr = std::shared_ptr<CompositeOpsFusion>;
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
|
@ -0,0 +1,206 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
#include "mindspore/core/ir/graph_utils.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include "vm/segment_runner.h"
|
||||||
|
#include "runtime/device/kernel_info.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr auto kJsonKeyExpandInfo = "expand_info";
|
||||||
|
|
||||||
|
#define GET_VALUE_FOR_JSON(JSON, VALUE, VALUE_ELEM, TYPE_NAME, TYPE) \
|
||||||
|
if (VALUE_ELEM->isa<TYPE_NAME>()) { \
|
||||||
|
JSON = GetValue<TYPE>(VALUE); \
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json ExpandAttrJsonInfo(const CNodePtr &cnode) {
|
||||||
|
nlohmann::json attrs_json;
|
||||||
|
if (auto prim = GetCNodePrimitive(cnode); prim != nullptr) {
|
||||||
|
auto attrs = prim->attrs();
|
||||||
|
for (const auto &[k, v] : attrs) {
|
||||||
|
nlohmann::json attr_json;
|
||||||
|
MS_LOG(DEBUG) << "attr key is : " << k << " and value type is : " << v->type_name();
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, Int32Imm, int);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, Int64Imm, int64_t);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt32Imm, uint32_t);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt64Imm, uint64_t);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, FP32Imm, float);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, FP64Imm, double);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, BoolImm, bool);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, v, StringImm, std::string);
|
||||||
|
|
||||||
|
if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
|
||||||
|
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
|
||||||
|
if (!vec.empty()) {
|
||||||
|
MS_LOG(DEBUG) << "value type is : " << vec[0]->type_name();
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int32Imm, std::vector<int>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int64Imm, std::vector<int64_t>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt32Imm, std::vector<uint32_t>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt64Imm, std::vector<uint64_t>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP32Imm, std::vector<float>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP64Imm, std::vector<double>);
|
||||||
|
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], StringImm, std::vector<std::string>);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!attr_json.empty()) {
|
||||||
|
attrs_json.push_back(attr_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return attrs_json;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExpandJsonInfo(const CNodePtr &cnode, nlohmann::json *kernel_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_json);
|
||||||
|
if (kernel_json->find(kJsonKeyExpandInfo) != kernel_json->end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json expand_info;
|
||||||
|
expand_info[kernel::kJsonKeyAttr] = ExpandAttrJsonInfo(cnode);
|
||||||
|
expand_info[kernel::kJsonKeyName] = AnfAlgo::GetCNodeName(cnode);
|
||||||
|
expand_info[kernel::kJsonKeyProcess] = kernel::GetProcessorStr(cnode);
|
||||||
|
std::vector<nlohmann::json> inputs_info;
|
||||||
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) {
|
||||||
|
nlohmann::json input_info;
|
||||||
|
input_info[kernel::kJsonKeyFormat] = AnfAlgo::GetInputFormat(cnode, i);
|
||||||
|
input_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
|
||||||
|
input_info[kernel::kJsonKeyShape] = AnfAlgo::GetInputDeviceShape(cnode, i);
|
||||||
|
input_info[kernel::kJsonKeyInferDataType] =
|
||||||
|
kernel::TypeId2String(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, i));
|
||||||
|
input_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetInputDeviceDataType(cnode, i));
|
||||||
|
inputs_info.push_back(input_info);
|
||||||
|
}
|
||||||
|
expand_info[kernel::kJsonKeyInputDesc] = inputs_info;
|
||||||
|
|
||||||
|
std::vector<nlohmann::json> outputs_info;
|
||||||
|
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(cnode); ++i) {
|
||||||
|
nlohmann::json output_info;
|
||||||
|
output_info[kernel::kJsonKeyFormat] = AnfAlgo::GetOutputFormat(cnode, i);
|
||||||
|
output_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetOutputInferShape(cnode, i);
|
||||||
|
output_info[kernel::kJsonKeyShape] = AnfAlgo::GetOutputDeviceShape(cnode, i);
|
||||||
|
output_info[kernel::kJsonKeyInferDataType] = kernel::TypeId2String(AnfAlgo::GetOutputInferDataType(cnode, i));
|
||||||
|
output_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetOutputDeviceDataType(cnode, i));
|
||||||
|
outputs_info.push_back(output_info);
|
||||||
|
}
|
||||||
|
expand_info[kernel::kJsonKeyOutputDesc] = outputs_info;
|
||||||
|
(*kernel_json)[kJsonKeyExpandInfo] = expand_info;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
|
||||||
|
nlohmann::json kernel_json;
|
||||||
|
if (!ExpandJsonInfo(node, &kernel_json)) {
|
||||||
|
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto node_desc_str = kernel_json.dump();
|
||||||
|
|
||||||
|
// call graph kernel ops generator.
|
||||||
|
MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
|
||||||
|
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
|
||||||
|
// parse result.
|
||||||
|
if (ret.is(py::none())) {
|
||||||
|
MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"
|
||||||
|
<< node_desc_str;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::string kernel_desc_str = py::cast<std::string>(ret);
|
||||||
|
if (kernel_desc_str.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// decode json to func_graph.
|
||||||
|
std::vector<AnfNodePtr> ori_inputs(node->inputs().begin() + 1, node->inputs().end());
|
||||||
|
return JsonDescToAnf(kernel_desc_str, ori_inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
|
||||||
|
const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
|
||||||
|
std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
|
||||||
|
AnfNodePtrList kernel_nodes;
|
||||||
|
AnfNodePtrList outputs;
|
||||||
|
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
|
||||||
|
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
|
||||||
|
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false);
|
||||||
|
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||||
|
std::string graph_kernel_flag;
|
||||||
|
std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) {
|
||||||
|
static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_"));
|
||||||
|
});
|
||||||
|
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag;
|
||||||
|
return graph_kernel_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
||||||
|
bool changed = false;
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
std::reverse(todos.begin(), todos.end());
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
for (const auto &n : todos) {
|
||||||
|
auto node = n->cast<CNodePtr>();
|
||||||
|
if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope();
|
||||||
|
auto new_func_graph = CreateExpandFuncGraph(node);
|
||||||
|
if (new_func_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
mng->AddFuncGraph(new_func_graph);
|
||||||
|
MS_LOG(DEBUG) << "decode fused nodes success.";
|
||||||
|
|
||||||
|
auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
|
||||||
|
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));
|
||||||
|
MS_LOG(INFO) << "create new cnode success.";
|
||||||
|
|
||||||
|
// replace origin node.
|
||||||
|
(void)mng->Replace(node, graph_kernel_node);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
expand_ops_ = GetExpandOps();
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph, true);
|
||||||
|
func_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
return DoExpand(func_graph);
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class GraphKernelExpander : public Pass {
|
||||||
|
public:
|
||||||
|
GraphKernelExpander() : Pass("graph_kernel_expander") {}
|
||||||
|
~GraphKernelExpander() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
|
private:
|
||||||
|
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
|
||||||
|
bool DoExpand(const FuncGraphPtr &func_graph);
|
||||||
|
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
|
||||||
|
const CNodePtr &node);
|
||||||
|
bool CanExpand(const CNodePtr &node) {
|
||||||
|
return std::any_of(expand_ops_.begin(), expand_ops_.end(),
|
||||||
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_set<PrimitivePtr> expand_ops_;
|
||||||
|
};
|
||||||
|
using GraphKernelExpanderPtr = std::shared_ptr<GraphKernelExpander>;
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
|
|
@ -0,0 +1,674 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
#include "pipeline/jit/action.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "vm/segment_runner.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||||
|
#include "ir/func_graph_cloner.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||||
|
#ifdef ENABLE_D
|
||||||
|
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
void DebugDump(const FuncGraphPtr &graph, std::stringstream *buf) {
|
||||||
|
(*buf) << "Parameters: \n";
|
||||||
|
const auto ¶meters = graph->parameters();
|
||||||
|
(*buf) << "size: " << parameters.size() << "\n";
|
||||||
|
for (const auto &p : parameters) {
|
||||||
|
(*buf) << "\t" << p->DebugString(2) << "\n";
|
||||||
|
}
|
||||||
|
(*buf) << "ValueNodes: \n";
|
||||||
|
const auto &value_nodes = graph->value_nodes();
|
||||||
|
(*buf) << "size: " << value_nodes.size() << "\n";
|
||||||
|
for (const auto &v : value_nodes) {
|
||||||
|
(*buf) << "\t" << v.first->DebugString(2) << "\n";
|
||||||
|
}
|
||||||
|
(*buf) << "CNodes: \n";
|
||||||
|
const auto &all_nodes = graph->nodes();
|
||||||
|
(*buf) << "size: " << all_nodes.size() << "\n";
|
||||||
|
for (const auto &n : all_nodes) {
|
||||||
|
(*buf) << "\t" << n->DebugString(2) << "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(real_outs);
|
||||||
|
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
||||||
|
auto &inputs = out->cast<CNodePtr>()->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
real_outs->push_back(inputs[i]);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
|
||||||
|
auto fg_out = fg->output();
|
||||||
|
if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
|
||||||
|
auto inputs = fg_out->cast<CNodePtr>()->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
real_outs->push_back(inputs[i]);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
|
||||||
|
auto out_spec = node->abstract();
|
||||||
|
if (out_spec->isa<abstract::AbstractTuple>()) {
|
||||||
|
return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
|
||||||
|
}
|
||||||
|
return out_spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueNodePtr ProcessAttrsForCast(const CNodePtr &cnode, const std::string &attr_name) {
|
||||||
|
auto dst_type = AnfAlgo::GetNodeAttr<std::string>(cnode, attr_name);
|
||||||
|
auto type = TypeIdToType(kernel::DtypeToTypeId(dst_type));
|
||||||
|
auto type_val_node = NewValueNode(type);
|
||||||
|
return type_val_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<std::string, std::function<ValueNodePtr(const CNodePtr &cnode, const std::string &attr_name)>>
|
||||||
|
attrs_process_map = {
|
||||||
|
{kCastOpName, ProcessAttrsForCast},
|
||||||
|
};
|
||||||
|
|
||||||
|
ValueNodePtr ProcessAttrValue(const CNodePtr &cnode, const std::string &attr_name) {
|
||||||
|
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||||
|
if (attrs_process_map.count(op_name) != 0) {
|
||||||
|
return attrs_process_map.at(op_name)(cnode, attr_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto attr_val = AnfAlgo::GetNodeAttr<ValuePtr>(cnode, attr_name);
|
||||||
|
auto attr_val_node = NewValueNode(attr_val);
|
||||||
|
return attr_val_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr ConstAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
|
const std::unordered_set<size_t> &input_attrs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2);
|
||||||
|
if (input_attrs.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
|
||||||
|
MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names);
|
||||||
|
std::vector<AnfNodePtr> new_inputs;
|
||||||
|
std::vector<std::string> new_input_names;
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||||
|
new_input_names.push_back(input_names[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
(void)new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end());
|
||||||
|
bool need_update = false;
|
||||||
|
for (size_t i = inputs.size() - 1; i < input_names.size(); ++i) {
|
||||||
|
auto attr_name = input_names[i];
|
||||||
|
if (input_attrs.find(i) == input_attrs.end()) {
|
||||||
|
MS_LOG(WARNING) << "Other type input between tensors and attrs, name: " << attr_name
|
||||||
|
<< ", node: " << cnode->DebugString(2);
|
||||||
|
new_input_names.push_back(attr_name);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Attr: " << attr_name << " not found in node: " << cnode->DebugString(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hardcode. It should convert attrs value according to format, like op ReduceSum.
|
||||||
|
auto attr_val_node = ProcessAttrValue(cnode, attr_name);
|
||||||
|
new_inputs.push_back(attr_val_node);
|
||||||
|
new_input_names.push_back(attr_name);
|
||||||
|
need_update = true;
|
||||||
|
MS_LOG(DEBUG) << "convert attr: " << attr_name << " to input, value: " << attr_val_node;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names);
|
||||||
|
|
||||||
|
if (!need_update) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_cnode = func_graph->NewCNode(new_inputs);
|
||||||
|
// we do not modify abstract and kernel info.
|
||||||
|
new_cnode->set_abstract(cnode->abstract());
|
||||||
|
new_cnode->set_kernel_info(cnode->kernel_info_ptr());
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
|
const std::unordered_set<size_t> &input_attrs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2);
|
||||||
|
if (input_attrs.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
|
||||||
|
MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names);
|
||||||
|
std::vector<AnfNodePtr> new_inputs;
|
||||||
|
std::vector<std::string> new_input_names;
|
||||||
|
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
new_inputs.push_back(inputs[0]);
|
||||||
|
bool need_update = false;
|
||||||
|
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||||
|
auto input_node = inputs[i + 1];
|
||||||
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
|
// The attrs counts from 0
|
||||||
|
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
|
||||||
|
auto value_node = input_node->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
MS_LOG(DEBUG) << "delete attr input: " << i << " of node: " << cnode->DebugString(2);
|
||||||
|
if (i >= input_names.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size: " << input_names.size();
|
||||||
|
}
|
||||||
|
need_update = true;
|
||||||
|
} else {
|
||||||
|
new_inputs.push_back(input_node);
|
||||||
|
if (i < input_names.size()) {
|
||||||
|
new_input_names.push_back(input_names[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names);
|
||||||
|
|
||||||
|
if (!need_update) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto new_cnode = func_graph->NewCNode(new_inputs);
|
||||||
|
// we do not modify abstract and kernel info.
|
||||||
|
new_cnode->set_abstract(cnode->abstract());
|
||||||
|
new_cnode->set_kernel_info(cnode->kernel_info_ptr());
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode);
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
|
||||||
|
AnfNodePtrList outs;
|
||||||
|
auto out_node = (*fg)->output();
|
||||||
|
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
|
||||||
|
std::vector<AnfNodePtr> output_args;
|
||||||
|
auto out_cnode = out_node->cast<CNodePtr>();
|
||||||
|
for (auto out : out_cnode->inputs()) {
|
||||||
|
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
||||||
|
auto inputs = out->cast<CNodePtr>()->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
output_args.push_back(inputs[i]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output_args.push_back(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (output_args.size() != out_cnode->inputs().size()) {
|
||||||
|
auto new_out = (*fg)->NewCNode(output_args);
|
||||||
|
(*mng)->Replace(out_node, new_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 1; i < output_args.size(); ++i) {
|
||||||
|
outs.push_back(output_args[i]);
|
||||||
|
}
|
||||||
|
return outs;
|
||||||
|
}
|
||||||
|
|
||||||
|
outs.push_back(out_node);
|
||||||
|
return outs;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||||
|
const AnfNodePtrList &outputs, kernel::Processor processor) {
|
||||||
|
std::vector<std::string> graph_input_format;
|
||||||
|
std::vector<TypeId> graph_input_type;
|
||||||
|
std::vector<std::string> graph_output_format;
|
||||||
|
std::vector<TypeId> graph_output_type;
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||||
|
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
graph_input_format.push_back(input_format);
|
||||||
|
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
graph_input_type.push_back(input_type);
|
||||||
|
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
fg->parameters()[i]->set_abstract(input_abs);
|
||||||
|
}
|
||||||
|
auto new_outputs = outputs;
|
||||||
|
if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
|
||||||
|
std::vector<AnfNodePtr> real_outs;
|
||||||
|
if (IsMakeTupleOut(outputs[0], &real_outs)) {
|
||||||
|
new_outputs = real_outs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < new_outputs.size(); ++i) {
|
||||||
|
auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
|
||||||
|
auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
graph_output_format.push_back(output_format);
|
||||||
|
graph_output_type.push_back(output_type);
|
||||||
|
}
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||||
|
graph_info_builder.SetInputsFormat(graph_input_format);
|
||||||
|
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
||||||
|
graph_info_builder.SetOutputsFormat(graph_output_format);
|
||||||
|
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||||
|
graph_info_builder.SetProcessor(processor);
|
||||||
|
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||||
|
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||||
|
auto graph_selected_info = graph_info_builder.Build();
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ConstAttrToInput(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
std::vector<AnfNodePtr> todos;
|
||||||
|
kernel::GetValidKernelNodes(func_graph, &todos);
|
||||||
|
for (const auto &node : todos) {
|
||||||
|
ConstInputToAttrInfoRegister reg;
|
||||||
|
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto new_node = ConstAttrToInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo());
|
||||||
|
if (new_node != nullptr && new_node != node) {
|
||||||
|
mng->Replace(node, new_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeleteAttrInInput(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
std::vector<AnfNodePtr> todos;
|
||||||
|
kernel::GetValidKernelNodes(func_graph, &todos);
|
||||||
|
for (const auto &node : todos) {
|
||||||
|
ConstInputToAttrInfoRegister reg;
|
||||||
|
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), ®)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto new_node = DeleteAttrInInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo());
|
||||||
|
if (new_node != nullptr && new_node != node) {
|
||||||
|
mng->Replace(node, new_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
|
||||||
|
AnfNodePtrList res;
|
||||||
|
if (outs.size() <= 1) {
|
||||||
|
return outs;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto out : outs) {
|
||||||
|
AnfNodePtrList real_outs;
|
||||||
|
if (IsMakeTupleOut(out, &real_outs)) {
|
||||||
|
res.insert(res.end(), real_outs.begin(), real_outs.end());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
res.push_back(out);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||||
|
const AnfNodePtrList &outputs, bool is_before_kernel_select) {
|
||||||
|
auto func_node = NewValueNode(fg);
|
||||||
|
std::vector<AnfNodePtr> fn_inputs;
|
||||||
|
fn_inputs.push_back(func_node);
|
||||||
|
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
|
||||||
|
auto fuse_cnode = func_graph->NewCNode(fn_inputs);
|
||||||
|
// Set output abstract
|
||||||
|
if (outputs.size() > 1) {
|
||||||
|
std::vector<AbstractBasePtr> out_specs;
|
||||||
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
out_specs.push_back(outputs[i]->abstract());
|
||||||
|
}
|
||||||
|
auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
|
||||||
|
fuse_cnode->set_abstract(out_spec);
|
||||||
|
} else {
|
||||||
|
fuse_cnode->set_abstract(outputs[0]->abstract());
|
||||||
|
}
|
||||||
|
// Set parameter abstract.
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||||
|
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
||||||
|
fg->parameters()[i]->set_abstract(input_abs);
|
||||||
|
if (is_before_kernel_select) {
|
||||||
|
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fuse_cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
|
||||||
|
const AnfNodePtrList &outputs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
// single out
|
||||||
|
if (outputs.size() == 1) {
|
||||||
|
mng->Replace(outputs[0], new_fuse_cnode);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> fn_inputs;
|
||||||
|
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
|
||||||
|
AnfNodePtrList real_outs;
|
||||||
|
// not make tuple out, replace
|
||||||
|
if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
|
||||||
|
fn_inputs.clear();
|
||||||
|
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||||
|
fn_inputs.push_back(new_fuse_cnode);
|
||||||
|
fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx))));
|
||||||
|
auto new_out = func_graph->NewCNode(fn_inputs);
|
||||||
|
new_out->set_abstract(outputs[out_idx]->abstract());
|
||||||
|
mng->Replace(outputs[out_idx], new_out);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the out is make tuple , modify the get_item node's value
|
||||||
|
auto users = mng->node_users()[outputs[out_idx]];
|
||||||
|
for (auto &user : users) {
|
||||||
|
auto use_node = user.first;
|
||||||
|
if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto get_item_cnode = use_node->cast<CNodePtr>();
|
||||||
|
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_input);
|
||||||
|
auto value_node = value_input->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
int item_idx = GetValue<int>(value_node->value());
|
||||||
|
int new_item_idx = SizeToInt(out_idx) + item_idx;
|
||||||
|
fn_inputs.clear();
|
||||||
|
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||||
|
fn_inputs.push_back(new_fuse_cnode);
|
||||||
|
fn_inputs.push_back(NewValueNode(new_item_idx));
|
||||||
|
auto new_out = func_graph->NewCNode(fn_inputs);
|
||||||
|
new_out->set_abstract(get_item_cnode->abstract());
|
||||||
|
mng->Replace(get_item_cnode, new_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||||
|
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
|
||||||
|
bool is_before_kernel_select) {
|
||||||
|
if (fuse_nodes.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto mng = kernel_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(kernel_graph, true);
|
||||||
|
kernel_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr fg;
|
||||||
|
AnfNodePtrList inputs;
|
||||||
|
AnfNodePtrList outputs;
|
||||||
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||||
|
|
||||||
|
// Remove nest make tuple in outs
|
||||||
|
auto expand_out = GetExpandOuts(outputs);
|
||||||
|
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select);
|
||||||
|
if (!is_before_kernel_select) {
|
||||||
|
SetNewKernelInfo(fuse_new_node, fg, inputs, expand_out, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||||
|
}
|
||||||
|
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||||
|
|
||||||
|
// Inline origin graphkernel
|
||||||
|
auto cnodes = fg->GetOrderedCnodes();
|
||||||
|
for (const auto &n : cnodes) {
|
||||||
|
if (!AnfAlgo::IsGraphKernel(n)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
|
||||||
|
AnfNodePtrList ins;
|
||||||
|
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
|
||||||
|
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
|
||||||
|
mng->Replace(n, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
EliminateMakeTuple(&fg, &mng);
|
||||||
|
// set graphKernel attr
|
||||||
|
std::string fuse_op_name = "";
|
||||||
|
for (auto &fuse_node : fuse_nodes) {
|
||||||
|
if (IsPrimitiveCNode(fuse_node)) {
|
||||||
|
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
|
||||||
|
} else if (AnfAlgo::IsGraphKernel(fuse_node)) {
|
||||||
|
auto fuse_cnode = fuse_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(fuse_cnode);
|
||||||
|
auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
|
||||||
|
auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||||
|
auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
|
||||||
|
fuse_op_name += fuse_fg_name + "_";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fuse_op_name += postfix;
|
||||||
|
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
|
||||||
|
std::map<std::string, AnfNodePtr> *address_node_map) {
|
||||||
|
MS_EXCEPTION_IF_NULL(op_desc);
|
||||||
|
if (nodes.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Input nodes is empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool has_graph_kernel =
|
||||||
|
std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) { return AnfAlgo::IsGraphKernel(node); });
|
||||||
|
bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
|
||||||
|
|
||||||
|
auto gen_json = [&dump_option, &op_desc, &address_node_map](const AnfNodePtrList &op_nodes,
|
||||||
|
const AnfNodePtrList &inputs,
|
||||||
|
const AnfNodePtrList &outputs) -> bool {
|
||||||
|
kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
|
||||||
|
if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
|
||||||
|
MS_LOG(ERROR) << "Collect json desc failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
*op_desc = akg_kernel_json_generator.kernel_json();
|
||||||
|
if (address_node_map != nullptr) {
|
||||||
|
*address_node_map = akg_kernel_json_generator.address_node_map();
|
||||||
|
}
|
||||||
|
std::string fused_name;
|
||||||
|
std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
|
||||||
|
(void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
|
||||||
|
});
|
||||||
|
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
FuncGraphPtr fg;
|
||||||
|
AnfNodePtrList op_nodes;
|
||||||
|
AnfNodePtrList inputs;
|
||||||
|
AnfNodePtrList outputs;
|
||||||
|
if (is_single_graph_kernel) {
|
||||||
|
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
|
||||||
|
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
|
||||||
|
return gen_json(op_nodes, inputs, outputs);
|
||||||
|
} else if (!has_graph_kernel) {
|
||||||
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
|
||||||
|
op_nodes = nodes;
|
||||||
|
return gen_json(op_nodes, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
|
||||||
|
auto mng = Manage(fg, false);
|
||||||
|
fg->set_manager(mng);
|
||||||
|
// Inline origin graph kernel
|
||||||
|
auto fg_nodes = fg->GetOrderedCnodes();
|
||||||
|
for (auto const &n : fg_nodes) {
|
||||||
|
if (!AnfAlgo::IsGraphKernel(n)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
|
||||||
|
AnfNodePtrList ins;
|
||||||
|
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
|
||||||
|
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
|
||||||
|
mng->Replace(n, out);
|
||||||
|
}
|
||||||
|
inputs.clear();
|
||||||
|
outputs.clear();
|
||||||
|
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
|
||||||
|
return gen_json(op_nodes, inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc) {
|
||||||
|
MS_EXCEPTION_IF_NULL(op_desc);
|
||||||
|
std::vector<nlohmann::json> graphs_desc;
|
||||||
|
for (auto const &graph_nodes : graphs) {
|
||||||
|
nlohmann::json desc;
|
||||||
|
if (!AnfToJsonDesc(graph_nodes, dump_option, &desc)) {
|
||||||
|
MS_LOG(ERROR) << "Collect json desc failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
graphs_desc.push_back(desc);
|
||||||
|
}
|
||||||
|
if (graphs_desc.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Collect zero json desc.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (graphs_desc.size() > 1) {
|
||||||
|
nlohmann::json op_json_desc;
|
||||||
|
op_json_desc[kJsonKeyMultiGraph] = true;
|
||||||
|
op_json_desc[kJsonKeyGraphDesc] = graphs_desc;
|
||||||
|
*op_desc = op_json_desc;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
*op_desc = graphs_desc[0];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
|
||||||
|
auto fg = akg_kernel_json_decoder.DecodeFusedNodes(json_desc);
|
||||||
|
if (fg == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Akg decode json to graph failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
|
||||||
|
auto mng = resource->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
mng->AddFuncGraph(fg);
|
||||||
|
ConstAttrToInput(fg);
|
||||||
|
std::stringstream buf;
|
||||||
|
buf << "===================== graph after ConstAttrToInput " << fg->ToString() << " =====================\n";
|
||||||
|
DebugDump(fg, &buf);
|
||||||
|
MS_LOG(DEBUG) << buf.str();
|
||||||
|
|
||||||
|
// Do infer and specialize.
|
||||||
|
AbstractBasePtrList args_spec_list;
|
||||||
|
std::for_each(inputs.begin(), inputs.end(),
|
||||||
|
[&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
|
||||||
|
auto infer_fg = pipeline::Renormalize(resource, fg, args_spec_list);
|
||||||
|
if (infer_fg == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Infer decoded graph failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
buf.str("");
|
||||||
|
buf << "===================== graph after Renormalize " << infer_fg->ToString() << " =====================\n";
|
||||||
|
DebugDump(infer_fg, &buf);
|
||||||
|
MS_LOG(DEBUG) << buf.str();
|
||||||
|
|
||||||
|
// delete no use inputs(attrs), like op ReduceSum(axis).
|
||||||
|
DeleteAttrInInput(infer_fg);
|
||||||
|
buf.str("");
|
||||||
|
buf << "===================== graph after DeleteAttrInInput " << infer_fg->ToString() << " =====================\n";
|
||||||
|
DebugDump(infer_fg, &buf);
|
||||||
|
MS_LOG(DEBUG) << buf.str();
|
||||||
|
|
||||||
|
// clone a new graph.
|
||||||
|
auto new_fg = TransformableClone(infer_fg, std::make_shared<TraceTransform>("akg_decode"));
|
||||||
|
return new_fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
|
||||||
|
std::vector<AnfNodePtrList> *res_graphs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(res_graphs);
|
||||||
|
auto kernel_json = nlohmann::json::parse(json_desc);
|
||||||
|
if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
|
||||||
|
// not multi graphs.
|
||||||
|
MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
|
||||||
|
std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
|
||||||
|
if (graph_descs.empty()) {
|
||||||
|
MS_LOG(ERROR) << "No sub graph found, " << json_desc;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < graph_descs.size(); ++i) {
|
||||||
|
const auto &graph_desc = graph_descs[i];
|
||||||
|
AnfNodePtrList res_graph;
|
||||||
|
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
|
||||||
|
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
res_graphs->push_back(res_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_set<PrimitivePtr> GetExpandOps() {
|
||||||
|
std::unordered_set<PrimitivePtr> expand_ops = {
|
||||||
|
prim::kPrimSquare,
|
||||||
|
prim::kPrimGelu,
|
||||||
|
prim::kPrimSoftmax,
|
||||||
|
prim::kPrimLayerNorm,
|
||||||
|
};
|
||||||
|
return expand_ops;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
|
||||||
|
std::stringstream name;
|
||||||
|
if (prefix != "") {
|
||||||
|
name << prefix << "_";
|
||||||
|
}
|
||||||
|
for (const auto &node : cnodes) {
|
||||||
|
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
||||||
|
name << AnfAlgo::GetCNodeName(node) << "_";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (postfix != "") {
|
||||||
|
name << postfix;
|
||||||
|
}
|
||||||
|
return name.str();
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "backend/session/kernel_graph.h"
|
||||||
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
using kernel::DumpOption;
|
||||||
|
|
||||||
|
constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
|
||||||
|
constexpr auto kGraphKernelSplitFunc = "split_with_json";
|
||||||
|
constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
|
||||||
|
constexpr auto kJsonKeyMultiGraph = "multi_graph";
|
||||||
|
constexpr auto kJsonKeyGraphDesc = "graph_desc";
|
||||||
|
|
||||||
|
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||||
|
const AnfNodePtrList &outputs, kernel::Processor processor);
|
||||||
|
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs);
|
||||||
|
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||||
|
const AnfNodePtrList &outputs, bool is_before_kernel_select);
|
||||||
|
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||||
|
const AnfNodePtrList &outputs);
|
||||||
|
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||||
|
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
|
||||||
|
bool is_before_kernel_select);
|
||||||
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
|
||||||
|
std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
|
||||||
|
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc);
|
||||||
|
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
|
||||||
|
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
|
||||||
|
std::vector<AnfNodePtrList> *res_graphs);
|
||||||
|
std::unordered_set<PrimitivePtr> GetExpandOps();
|
||||||
|
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
|
@ -0,0 +1,742 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
|
#include <queue>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "frontend/optimizer/irpass.h"
|
||||||
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNodePtr &)> callback) {
|
||||||
|
std::unordered_set<AnfNodePtr> visited;
|
||||||
|
std::queue<AnfNodePtr> que;
|
||||||
|
que.push(cnode);
|
||||||
|
visited.insert(cnode);
|
||||||
|
while (!que.empty()) {
|
||||||
|
auto ft_node = que.front();
|
||||||
|
que.pop();
|
||||||
|
callback(ft_node);
|
||||||
|
auto ft_cnode = ft_node->cast<CNodePtr>();
|
||||||
|
if (ft_cnode == nullptr) continue;
|
||||||
|
for (const auto &in_node : ft_cnode->inputs()) {
|
||||||
|
if (visited.count(in_node) == 0) {
|
||||||
|
que.push(in_node);
|
||||||
|
visited.insert(in_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visited each AnfNode once, use callback to do the job on AnfNode
|
||||||
|
inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNodePtr &)> callback) {
|
||||||
|
TraverseFuncGraphFromCNode(root->get_return(), callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
class AreaGraph;
|
||||||
|
class Splitter;
|
||||||
|
class Area {
|
||||||
|
public:
|
||||||
|
explicit Area(const AnfNodePtrList &anf_arr) {
|
||||||
|
nodes_.insert(anf_arr.begin(), anf_arr.end());
|
||||||
|
for (auto &node : anf_arr) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) continue;
|
||||||
|
const auto &inputs = cnode->inputs();
|
||||||
|
if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
|
||||||
|
spy_cnodes_.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the external inputs of spy as a Parameter.
|
||||||
|
void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) {
|
||||||
|
std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map;
|
||||||
|
for (auto node : this->spy_cnodes_) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||||
|
AnfNodePtr in_node = cnode->input(i);
|
||||||
|
if (!IsExternalCNode(in_node)) continue;
|
||||||
|
auto it = node_param_map.find(in_node);
|
||||||
|
if (it == node_param_map.end()) {
|
||||||
|
auto new_param = std::make_shared<Parameter>(func_graph);
|
||||||
|
new_param->set_abstract(in_node->abstract());
|
||||||
|
func_graph->add_parameter(new_param);
|
||||||
|
node_param_map.insert(std::make_pair(in_node, new_param));
|
||||||
|
cnode->set_input(i, new_param);
|
||||||
|
} else {
|
||||||
|
cnode->set_input(i, it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this->spy_cnodes_.clear(); // spy list is not useful anymore
|
||||||
|
for (auto &&elem : node_param_map) {
|
||||||
|
param_node_map->insert(std::make_pair(elem.second, elem.first));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a return node for traitor nodes.
|
||||||
|
void CreateReturnNode(const FuncGraphPtr &func_graph, std::unordered_map<AnfNodePtr, size_t> *tuple_node_index) {
|
||||||
|
// If there's no traitor in the area, it means that this area is the last part
|
||||||
|
// of the original FuncGraph, it already contains the original Return node.
|
||||||
|
if (traitor_nodes_.empty()) {
|
||||||
|
for (auto &node : nodes_) {
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
|
||||||
|
func_graph->set_return(node->cast<CNodePtr>());
|
||||||
|
node->set_func_graph(func_graph);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
|
||||||
|
if (traitor_nodes_.size() > 1) {
|
||||||
|
// The area has multiple output, it's necessary to make a tuple for them.
|
||||||
|
AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
AbstractBasePtrList abstracts;
|
||||||
|
size_t i = 0;
|
||||||
|
for (auto &traitor : traitor_nodes_) {
|
||||||
|
tuple_node_index->insert(std::make_pair(traitor, i++));
|
||||||
|
maketuple_inputs.push_back(traitor);
|
||||||
|
abstracts.push_back(traitor->abstract());
|
||||||
|
}
|
||||||
|
auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
|
||||||
|
maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
|
||||||
|
nodes_.insert(maketuple_node);
|
||||||
|
return_inputs.push_back(maketuple_node);
|
||||||
|
} else {
|
||||||
|
return_inputs.push_back(traitor_nodes_[0]);
|
||||||
|
}
|
||||||
|
auto return_node = func_graph->NewCNode(return_inputs);
|
||||||
|
return_node->set_abstract(return_inputs.back()->abstract());
|
||||||
|
func_graph->set_return(return_node);
|
||||||
|
nodes_.insert(return_node);
|
||||||
|
traitor_nodes_.clear(); // traitor list is not useful anymore
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddTraitor(const AnfNodePtr &node) {
|
||||||
|
if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
|
||||||
|
traitor_nodes_.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
friend AreaGraph;
|
||||||
|
friend Splitter;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// This is a CNode that does not belong to this area.
|
||||||
|
bool IsExternalCNode(const AnfNodePtr &node) { return node->isa<CNode>() && this->nodes_.count(node) == 0; }
|
||||||
|
|
||||||
|
// nodes in this area
|
||||||
|
std::unordered_set<AnfNodePtr> nodes_;
|
||||||
|
// if a node's output is used by other Area, it's a traitor
|
||||||
|
std::vector<AnfNodePtr> traitor_nodes_;
|
||||||
|
// if a node use other Area's output, it's a spy
|
||||||
|
std::vector<AnfNodePtr> spy_cnodes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AreaGraph {
|
||||||
|
public:
|
||||||
|
using AreaGraphPtr = std::shared_ptr<AreaGraph>;
|
||||||
|
|
||||||
|
// Build an area graph to maintain the relation between areas.
|
||||||
|
// Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
|
||||||
|
static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
|
||||||
|
AreaGraph *area_graph_ptr = new (std::nothrow) AreaGraph(node_groups);
|
||||||
|
if (!area_graph_ptr) return nullptr;
|
||||||
|
auto area_graph = AreaGraphPtr(area_graph_ptr);
|
||||||
|
if (!area_graph->TopoSort()) {
|
||||||
|
MS_LOG(WARNING) << "The groups have a cycle.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return area_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split the graph to multiple areas, and reconnect the edges between the areas.
|
||||||
|
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
|
||||||
|
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
|
||||||
|
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
|
||||||
|
std::vector<size_t> *cnode_group_id, std::function<void(Area *)> expand_callback) {
|
||||||
|
main_cnodes->clear();
|
||||||
|
main_cnodes->resize(areas_.size(), nullptr);
|
||||||
|
|
||||||
|
for (auto &area : this->areas_) {
|
||||||
|
expand_callback(&area);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto index : topo_order_) {
|
||||||
|
auto ¤t_area = areas_[index];
|
||||||
|
auto sub_func_graph = std::make_shared<FuncGraph>();
|
||||||
|
std::unordered_map<ParameterPtr, AnfNodePtr> param_node_map;
|
||||||
|
|
||||||
|
current_area.CreateParameters(sub_func_graph, ¶m_node_map);
|
||||||
|
current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
|
||||||
|
auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
|
||||||
|
(*main_cnodes)[index] = new_main_cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
SortCNodes(main_cnodes);
|
||||||
|
cnode_group_id->swap(topo_order_); // The topo_order is not used anymore.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
|
||||||
|
for (size_t i = 0; i < node_groups.size(); ++i) {
|
||||||
|
areas_.emplace_back(node_groups[i]);
|
||||||
|
for (const auto &node : node_groups[i]) {
|
||||||
|
node_area_map_[node] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto &area : areas_) {
|
||||||
|
for (auto &spy : area.spy_cnodes_) {
|
||||||
|
auto cnode = spy->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
size_t v = node_area_map_[spy];
|
||||||
|
for (auto &in_node : cnode->inputs()) {
|
||||||
|
if (!in_node->isa<CNode>()) continue;
|
||||||
|
// area edge u -> v
|
||||||
|
size_t u = node_area_map_[in_node];
|
||||||
|
if (u == v) continue;
|
||||||
|
areas_[u].AddTraitor(in_node);
|
||||||
|
if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
|
||||||
|
edge_prev_[v].push_back(u);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topological sort the areas.
|
||||||
|
bool TopoSort() {
|
||||||
|
std::vector<int> out_degree(edge_prev_.size(), 0);
|
||||||
|
std::queue<size_t> que;
|
||||||
|
for (auto &prev : edge_prev_) {
|
||||||
|
for (size_t i : prev) {
|
||||||
|
out_degree[i]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < out_degree.size(); ++i) {
|
||||||
|
if (out_degree[i] == 0) que.push(i);
|
||||||
|
}
|
||||||
|
while (!que.empty()) {
|
||||||
|
size_t u = que.front();
|
||||||
|
que.pop();
|
||||||
|
topo_order_.push_back(u);
|
||||||
|
for (size_t i : edge_prev_[u]) {
|
||||||
|
if (--out_degree[i] == 0) que.push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::reverse(topo_order_.begin(), topo_order_.end());
|
||||||
|
return topo_order_.size() == areas_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a CNode in main graph to hold the sub_func_graph.
|
||||||
|
CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
|
||||||
|
const std::vector<CNodePtr> &main_cnodes,
|
||||||
|
const std::unordered_map<ParameterPtr, AnfNodePtr> ¶m_node_map) {
|
||||||
|
AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
|
||||||
|
for (const auto ¶m : sub_func_graph->parameters()) {
|
||||||
|
// assert the param exists.
|
||||||
|
const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
|
||||||
|
size_t input_area = node_area_map_[input_node];
|
||||||
|
// if the input node is in a tuple, then we need to create a GetItem fot it.
|
||||||
|
if (node_index_in_returned_tuple_.count(input_node) != 0) {
|
||||||
|
int idx_val = SizeToInt(node_index_in_returned_tuple_[input_node]);
|
||||||
|
auto idx = NewValueNode(idx_val);
|
||||||
|
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
|
||||||
|
AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
|
||||||
|
auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
|
||||||
|
getitem_node->set_abstract(main_cnodes[input_area]->abstract());
|
||||||
|
main_cnode_inputs.push_back(getitem_node);
|
||||||
|
} else {
|
||||||
|
main_cnode_inputs.push_back(main_cnodes[input_area]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
|
||||||
|
new_main_cnode->set_abstract(sub_func_graph->get_return()->abstract());
|
||||||
|
return new_main_cnode;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SortCNodes(std::vector<CNodePtr> *main_cnodes) {
|
||||||
|
std::vector<CNodePtr> main_cnodes_sorted;
|
||||||
|
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
|
||||||
|
[main_cnodes](int index) { return main_cnodes->at(index); });
|
||||||
|
main_cnodes->swap(main_cnodes_sorted);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Areas in this subgraph
|
||||||
|
std::vector<Area> areas_;
|
||||||
|
// Adjacency table of areas
|
||||||
|
std::vector<std::vector<size_t>> edge_prev_;
|
||||||
|
// Topological order of areas
|
||||||
|
std::vector<size_t> topo_order_;
|
||||||
|
// Map AnfNode to Area id
|
||||||
|
std::unordered_map<AnfNodePtr, size_t> node_area_map_;
|
||||||
|
// Map the nodes to their index if there are multiple value in an area
|
||||||
|
std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Splitter {
|
||||||
|
public:
|
||||||
|
class SplitSchemer {
|
||||||
|
public:
|
||||||
|
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
|
||||||
|
virtual bool NeedInline(size_t group_id) const { return false; }
|
||||||
|
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<AnfNodePtrList> split_plan_;
|
||||||
|
};
|
||||||
|
using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
|
||||||
|
using SplitterPtr = std::shared_ptr<Splitter>;
|
||||||
|
|
||||||
|
bool Split() {
|
||||||
|
GenParamMap();
|
||||||
|
auto ori_sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
|
||||||
|
if (!split_schemer_->Split(ori_sub_func_graph)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
|
||||||
|
if (area_graph == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
|
||||||
|
std::vector<size_t> cnodes_group_id;
|
||||||
|
std::function<void(Area *)> expand_callback = std::bind(&Splitter::AreaExpand, this, std::placeholders::_1);
|
||||||
|
area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, expand_callback);
|
||||||
|
|
||||||
|
RebuildGraph(cnodes_group_id);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) {
|
||||||
|
MS_EXCEPTION_IF_NULL(main_cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
|
||||||
|
MS_EXCEPTION_IF_NULL(split_schemer);
|
||||||
|
return SplitterPtr(new Splitter(main_cnode, split_schemer));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
|
||||||
|
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
|
||||||
|
|
||||||
|
// Maintain new subgraphs in main graph.
|
||||||
|
void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
|
||||||
|
BindFuncGraph();
|
||||||
|
RecoverParameter();
|
||||||
|
ConnectToMainGraph(cnodes_group_id);
|
||||||
|
UpdateSubGraphInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rebind nodes to its new sub_func_graph
|
||||||
|
void BindFuncGraph() {
|
||||||
|
for (const auto &cnode : new_subgraph_cnodes_) {
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||||
|
auto callback = [&sub_func_graph, this](const AnfNodePtr &node) {
|
||||||
|
if (!node->isa<ValueNode>()) {
|
||||||
|
node->set_func_graph(sub_func_graph);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
TraverseFuncGraph(sub_func_graph, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recover the original subgraph's parameter if the new graph needs it
|
||||||
|
void RecoverParameter() {
|
||||||
|
for (const auto &cnode : new_subgraph_cnodes_) {
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||||
|
auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
|
||||||
|
auto param = node->cast<ParameterPtr>();
|
||||||
|
if (param == nullptr) return;
|
||||||
|
auto it = this->param_to_main_graph_node_map_.find(param);
|
||||||
|
if (it != this->param_to_main_graph_node_map_.end()) {
|
||||||
|
cnode->add_input(it->second);
|
||||||
|
sub_func_graph->add_parameter(param);
|
||||||
|
// Avoid repeating parameters.
|
||||||
|
this->param_to_main_graph_node_map_.erase(it);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
TraverseFuncGraph(sub_func_graph, callback);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
|
||||||
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(main_node);
|
||||||
|
const auto &inputs = main_node->inputs();
|
||||||
|
auto output = func_graph->output()->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
const auto ¶meters = func_graph->parameters();
|
||||||
|
std::unordered_map<AnfNodePtr, AnfNodePtr> param_input;
|
||||||
|
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||||
|
param_input[parameters[i]] = inputs[i + 1];
|
||||||
|
}
|
||||||
|
auto sub_nodes = TopoSort(func_graph->get_return());
|
||||||
|
for (auto node : sub_nodes) {
|
||||||
|
if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
|
||||||
|
cnode->set_func_graph(main_func_graph_);
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||||
|
auto iter = param_input.find(cnode->input(i));
|
||||||
|
if (iter != param_input.end()) {
|
||||||
|
cnode->set_input(i, iter->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the new sub_func_graph node as input of nodes original main graph.
|
||||||
|
void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
|
||||||
|
// For single output kernel, the last area contains the original output node (return node),
|
||||||
|
// to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
|
||||||
|
// For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
|
||||||
|
// a new FuncGraph, just inline the last MakeTuple node.
|
||||||
|
std::vector<CNodePtr> tmp_subgraph_cnodes;
|
||||||
|
std::unordered_map<AnfNodePtr, AnfNodePtr> replace_map;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
|
||||||
|
if (split_schemer_->NeedInline(cnodes_group_id[i])) {
|
||||||
|
// Connect the sub_graph's inner node to main_graph
|
||||||
|
auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
|
||||||
|
if (i + 1 == new_subgraph_cnodes_.size()) {
|
||||||
|
replace_map[this->old_subgraph_cnode_] = output;
|
||||||
|
} else {
|
||||||
|
replace_map[new_subgraph_cnodes_[i]] = output;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (i + 1 == new_subgraph_cnodes_.size()) {
|
||||||
|
replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
|
||||||
|
}
|
||||||
|
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
|
||||||
|
|
||||||
|
TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) return;
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||||
|
auto input_node = cnode->input(i);
|
||||||
|
auto iter = replace_map.find(input_node);
|
||||||
|
if (iter != replace_map.end()) {
|
||||||
|
cnode->set_input(i, iter->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateSubGraphInfo() {
|
||||||
|
auto graph_manager = main_func_graph_->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||||
|
|
||||||
|
for (auto cnode : new_subgraph_cnodes_) {
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||||
|
// add new sub_func_graph to manager
|
||||||
|
graph_manager->AddFuncGraph(sub_func_graph);
|
||||||
|
|
||||||
|
// set GraphKernel attr
|
||||||
|
auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
|
||||||
|
sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
|
||||||
|
|
||||||
|
// set kernel info
|
||||||
|
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||||
|
AnfNodePtrList outputs;
|
||||||
|
kernel::GetFuncGraphOutputNodes(sub_func_graph, &outputs);
|
||||||
|
SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_subgraph_cnode_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy all Parameter and ValueNode that the area used.
|
||||||
|
void AreaExpand(Area *area) {
|
||||||
|
std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
|
||||||
|
for (auto sub_node : area->nodes_) {
|
||||||
|
auto sub_cnode = sub_node->cast<CNodePtr>();
|
||||||
|
if (sub_cnode == nullptr) continue;
|
||||||
|
for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
|
||||||
|
auto in_node = sub_cnode->input(i);
|
||||||
|
if (in_node->isa<CNode>()) continue;
|
||||||
|
auto it = old_valuenode_and_param_map.find(in_node);
|
||||||
|
if (it != old_valuenode_and_param_map.end()) {
|
||||||
|
sub_cnode->set_input(i, it->second);
|
||||||
|
} else {
|
||||||
|
if (in_node->isa<Parameter>()) {
|
||||||
|
auto param = in_node->cast<ParameterPtr>();
|
||||||
|
auto cp_param = this->ParameterClone(param, in_node->func_graph());
|
||||||
|
old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
|
||||||
|
sub_cnode->set_input(i, cp_param);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GenParamMap() {
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
|
||||||
|
auto ¶m_arr = sub_func_graph->parameters();
|
||||||
|
for (size_t i = 0; i < param_arr.size(); ++i) {
|
||||||
|
auto param = param_arr[i]->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ParameterPtr ParameterClone(const ParameterPtr ¶m, const FuncGraphPtr &func) {
|
||||||
|
ParameterPtr param_c = std::make_shared<Parameter>(func);
|
||||||
|
param_c->set_name(param->name());
|
||||||
|
param_c->set_abstract(param->abstract());
|
||||||
|
param_to_main_graph_node_map_[param_c] = param_to_main_graph_node_map_[param];
|
||||||
|
return param_c;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr main_func_graph_;
|
||||||
|
CNodePtr old_subgraph_cnode_; // The cnode that holds the original sub_func_graph
|
||||||
|
std::vector<CNodePtr> new_subgraph_cnodes_; // The cnode list that hold the new sub_func_graph
|
||||||
|
SplitSchemerPtr split_schemer_;
|
||||||
|
std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CostModelSplitSchemer : public Splitter::SplitSchemer {
|
||||||
|
public:
|
||||||
|
bool Split(const FuncGraphPtr &func_graph) override {
|
||||||
|
if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||||
|
MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
|
||||||
|
}
|
||||||
|
func_graph_ = func_graph;
|
||||||
|
this->Run();
|
||||||
|
return split_plan_.size() > 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NeedInline(size_t group_id) const override {
|
||||||
|
if (group_id >= need_inline_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The group_id " << group_id << " should be less than the group num " << need_inline_.size();
|
||||||
|
}
|
||||||
|
return need_inline_[group_id] != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual bool SplitByCostModel() {
|
||||||
|
// Use an address map to record the anf node address when converting to json,
|
||||||
|
// it will recover the original node after split.
|
||||||
|
std::map<std::string, AnfNodePtr> address_node_map;
|
||||||
|
|
||||||
|
// convert anf-ir to json
|
||||||
|
nlohmann::json json_desc;
|
||||||
|
DumpOption dump_option;
|
||||||
|
dump_option.is_before_select_kernel = false;
|
||||||
|
dump_option.save_ptr_address = true;
|
||||||
|
if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
|
||||||
|
MS_LOG(ERROR) << "Collect json desc failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// call costmodel split function.
|
||||||
|
auto json_desc_str = json_desc.dump();
|
||||||
|
MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str;
|
||||||
|
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str);
|
||||||
|
if (ret.is(py::none())) {
|
||||||
|
MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
|
||||||
|
<< json_desc_str;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string split_graphs_str = py::cast<std::string>(ret);
|
||||||
|
if (split_graphs_str.empty()) {
|
||||||
|
MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
|
||||||
|
<< json_desc_str;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// recover json to anf-ir.
|
||||||
|
split_plan_.clear();
|
||||||
|
if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
|
||||||
|
MS_LOG(ERROR) << "Failed to decode split graphs.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The info should be returned from costmodel.
|
||||||
|
need_inline_.assign(split_plan_.size(), 0);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void Run() {
|
||||||
|
auto mng = func_graph_->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph_, true);
|
||||||
|
func_graph_->set_manager(mng);
|
||||||
|
}
|
||||||
|
GetValidKernelNodes();
|
||||||
|
// call CostModel to get a split plan.
|
||||||
|
if (!SplitByCostModel() || split_plan_.size() <= 1) {
|
||||||
|
split_plan_.clear();
|
||||||
|
need_inline_.clear();
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "CostModel split successed. The kernel is split to " << split_plan_.size() << " parts.";
|
||||||
|
}
|
||||||
|
MapNodeGroup();
|
||||||
|
GroupReturnNode();
|
||||||
|
GroupVirtualNodes();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual bool IsValidKernelNode(const AnfNodePtr &node) const {
|
||||||
|
if (!node->isa<CNode>()) return false;
|
||||||
|
if (AnfAlgo::IsRealKernel(node)) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void GetValidKernelNodes() {
|
||||||
|
topo_all_nodes_ = TopoSort(func_graph_->get_return());
|
||||||
|
topo_valid_nodes_.clear();
|
||||||
|
std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
|
||||||
|
[this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void MapNodeGroup() {
|
||||||
|
node_group_.clear();
|
||||||
|
for (size_t i = 0; i < split_plan_.size(); ++i) {
|
||||||
|
for (const auto &node : split_plan_[i]) {
|
||||||
|
node_group_[node] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// group the return node and last MakeTuple node (if exists).
|
||||||
|
virtual void GroupReturnNode() {
|
||||||
|
AnfNodePtrList outputs;
|
||||||
|
kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
|
||||||
|
auto ret_node = func_graph_->get_return();
|
||||||
|
auto output = func_graph_->output();
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
|
||||||
|
if (IsValidKernelNode(output)) {
|
||||||
|
auto group_id = node_group_[ret_node] = node_group_[output];
|
||||||
|
split_plan_[group_id].push_back(ret_node);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// assign the make_tuple node to a new group.
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
|
||||||
|
auto group_id = split_plan_.size();
|
||||||
|
split_plan_.push_back({output, ret_node});
|
||||||
|
need_inline_.push_back(1);
|
||||||
|
node_group_[ret_node] = node_group_[output] = group_id;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign virtual node to the same group of its input.
|
||||||
|
virtual void GroupVirtualNodes() {
|
||||||
|
for (const auto &node : topo_all_nodes_) {
|
||||||
|
if (node_group_.count(node)) continue;
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) continue;
|
||||||
|
bool found = false;
|
||||||
|
for (const auto &input : cnode->inputs()) {
|
||||||
|
auto iter = node_group_.find(input);
|
||||||
|
if (iter != node_group_.end()) {
|
||||||
|
node_group_[node] = iter->second;
|
||||||
|
split_plan_[iter->second].push_back(node);
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found) {
|
||||||
|
MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FuncGraph> func_graph_;
|
||||||
|
AnfNodePtrList topo_all_nodes_;
|
||||||
|
AnfNodePtrList topo_valid_nodes_;
|
||||||
|
std::unordered_map<AnfNodePtr, size_t> node_group_;
|
||||||
|
std::vector<int> need_inline_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Eliminate the redundant MakeTuple-GetItem operations.
|
||||||
|
void EliminateTupleGetItem(const FuncGraphPtr &func_graph) {
|
||||||
|
auto callback = [](const AnfNodePtr &node) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) return;
|
||||||
|
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||||
|
auto getitem = cnode->input(i);
|
||||||
|
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
|
||||||
|
auto getitem_cnode = getitem->cast<CNodePtr>();
|
||||||
|
auto maketuple = getitem_cnode->input(kRealInputNodeIndexInTupleGetItem);
|
||||||
|
if (!AnfAlgo::CheckPrimitiveType(maketuple, prim::kPrimMakeTuple)) continue;
|
||||||
|
auto maketuple_cnode = maketuple->cast<CNodePtr>();
|
||||||
|
int getitem_idx =
|
||||||
|
GetValue<int>(getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value());
|
||||||
|
cnode->set_input(i, maketuple_cnode->input(getitem_idx + 1));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
TraverseFuncGraph(func_graph, callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool TrySplit(const CNodePtr &sub_root_cnode) {
|
||||||
|
MS_LOG(INFO) << "Split process node: " << sub_root_cnode->fullname_with_scope();
|
||||||
|
auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
|
||||||
|
MS_EXCEPTION_IF_NULL(splitter);
|
||||||
|
bool result = splitter->Split();
|
||||||
|
MS_LOG(INFO) << "Split node completed, result: " << result;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph, true);
|
||||||
|
func_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
|
||||||
|
// Split subgraphs in reversed topo order,
|
||||||
|
// since the nodes behind the processing node may be modified when spliting.
|
||||||
|
bool changed = false;
|
||||||
|
for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
|
||||||
|
auto node = (*iter)->cast<CNodePtr>();
|
||||||
|
if (node != nullptr && AnfAlgo::IsGraphKernel(node)) {
|
||||||
|
changed = TrySplit(node) || changed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (changed) {
|
||||||
|
EliminateTupleGetItem(func_graph);
|
||||||
|
}
|
||||||
|
mng->RemoveRoots();
|
||||||
|
mng->KeepRoots({func_graph});
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
|
||||||
|
#include <memory>
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class GraphKernelSplitter : public Pass {
|
||||||
|
public:
|
||||||
|
GraphKernelSplitter() : Pass("graph_kernel_splitter") {}
|
||||||
|
~GraphKernelSplitter() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph);
|
||||||
|
};
|
||||||
|
using GraphKernelSplitterPtr = std::shared_ptr<GraphKernelSplitter>;
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
|
|
@ -1,560 +0,0 @@
|
||||||
|
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "backend/optimizer/pass/fuse_graph_kernel.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <map>
|
|
||||||
#include <set>
|
|
||||||
#include <queue>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "base/core_ops.h"
|
|
||||||
#include "utils/utils.h"
|
|
||||||
#include "ir/graph_utils.h"
|
|
||||||
#include "backend/optimizer/common/helper.h"
|
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
|
||||||
#include "vm/segment_runner.h"
|
|
||||||
#include "debug/anf_ir_dump.h"
|
|
||||||
#include "ir/func_graph_cloner.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) {
|
|
||||||
std::vector<PrimitivePtr> fusable_basic_ops = {
|
|
||||||
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
|
|
||||||
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
|
|
||||||
prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual};
|
|
||||||
if (!is_before_kernel_select) {
|
|
||||||
fusable_basic_ops.push_back(prim::kPrimCast);
|
|
||||||
}
|
|
||||||
return fusable_basic_ops;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<PrimitivePtr> get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) {
|
|
||||||
std::vector<PrimitivePtr> fusable_basic_ops_with_reduce;
|
|
||||||
if (!is_before_kernel_select) {
|
|
||||||
fusable_basic_ops_with_reduce.push_back(prim::kPrimCast);
|
|
||||||
}
|
|
||||||
return fusable_basic_ops_with_reduce;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<PrimitivePtr> get_reduce_ops() {
|
|
||||||
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
|
|
||||||
prim::kPrimReduceMax, prim::kPrimReduceAll};
|
|
||||||
return reduce_ops;
|
|
||||||
}
|
|
||||||
|
|
||||||
void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) {
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto reduce_ops = get_reduce_ops();
|
|
||||||
const auto &nodes = fg->nodes();
|
|
||||||
info->op_type = ELEWISE;
|
|
||||||
info->cal_step = -1;
|
|
||||||
info->reduce_op_num = 0;
|
|
||||||
for (auto node : nodes) {
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
if (cnode == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
info->cal_step++;
|
|
||||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
||||||
if (prim != nullptr) {
|
|
||||||
bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) {
|
|
||||||
return op->hash() == prim->hash() && op->name() == prim->name();
|
|
||||||
});
|
|
||||||
if (is_reudce) {
|
|
||||||
info->op_type = REDUCE;
|
|
||||||
info->reduce_op_num++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
|
|
||||||
auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select);
|
|
||||||
auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select);
|
|
||||||
bool is_fusable = false;
|
|
||||||
if (info.op_type == REDUCE &&
|
|
||||||
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
|
|
||||||
is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(),
|
|
||||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
||||||
} else {
|
|
||||||
is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
|
|
||||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
||||||
}
|
|
||||||
|
|
||||||
return is_fusable;
|
|
||||||
}
|
|
||||||
|
|
||||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
|
|
||||||
const AnfNodePtr &node) {
|
|
||||||
if (cur_node == node) {
|
|
||||||
return FOLLOW;
|
|
||||||
}
|
|
||||||
if (!IsPrimitiveCNode(node)) {
|
|
||||||
return EXCLUDE;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_fusable = IsFuse(info, node);
|
|
||||||
return is_fusable ? FOLLOW : EXCLUDE;
|
|
||||||
}
|
|
||||||
|
|
||||||
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
|
|
||||||
const AnfNodePtr &node) {
|
|
||||||
if (cur_node == node) {
|
|
||||||
return FOLLOW;
|
|
||||||
}
|
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
|
|
||||||
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg_attr_val);
|
|
||||||
auto fg_attr = GetValue<std::string>(fg_attr_val);
|
|
||||||
if (fg_attr == kApplyMomentumOpName) {
|
|
||||||
return FOLLOW;
|
|
||||||
}
|
|
||||||
return EXCLUDE;
|
|
||||||
}
|
|
||||||
if (!IsPrimitiveCNode(node)) {
|
|
||||||
return EXCLUDE;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_fusable = IsFuse(info, node);
|
|
||||||
return is_fusable ? FOLLOW : EXCLUDE;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
|
||||||
std::set<AnfNodePtr> *cached_unconnected_set) {
|
|
||||||
if (!check_node->isa<CNode>() || AnfAlgo::IsGraphKernel(check_node)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cnode = check_node->cast<CNodePtr>();
|
|
||||||
const auto &inputs = cnode->inputs();
|
|
||||||
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
|
||||||
bool has_circle = false;
|
|
||||||
for (auto input : inputs) {
|
|
||||||
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
|
||||||
std::set<AnfNodePtr> done;
|
|
||||||
std::vector<AnfNodePtr> todos = {input};
|
|
||||||
while (!todos.empty()) {
|
|
||||||
auto node = todos.back();
|
|
||||||
todos.pop_back();
|
|
||||||
if (done.count(node) || cached_unconnected_set->count(node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
done.insert(node);
|
|
||||||
if (fused_op_set.count(node)) {
|
|
||||||
has_circle = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->isa<CNode>()) {
|
|
||||||
auto cnode_ptr = node->cast<CNodePtr>();
|
|
||||||
for (auto it : cnode_ptr->inputs()) {
|
|
||||||
if (it->isa<CNode>()) {
|
|
||||||
todos.push_back(it);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_circle) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
cached_unconnected_set->insert(done.begin(), done.end());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
|
|
||||||
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
|
||||||
auto &inputs = out->cast<CNodePtr>()->inputs();
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
real_outs->push_back(inputs[i]);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) {
|
|
||||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out);
|
|
||||||
auto fg_out = fg->output();
|
|
||||||
if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
|
|
||||||
auto inputs = fg_out->cast<CNodePtr>()->inputs();
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
real_outs->push_back(inputs[i]);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
|
|
||||||
std::set<AnfNodePtr> cached_unconnected_set;
|
|
||||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
|
||||||
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
|
||||||
if (fused_op_set.count(node)) {
|
|
||||||
return FOLLOW;
|
|
||||||
}
|
|
||||||
return EXCLUDE;
|
|
||||||
};
|
|
||||||
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
|
||||||
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
|
|
||||||
// delete the circle node and the node which depend on the circle node in fused op
|
|
||||||
if (has_circle) {
|
|
||||||
auto mng = (*iter)->func_graph()->manager();
|
|
||||||
std::vector<AnfNodePtr> erase_nodes;
|
|
||||||
if (is_backward) {
|
|
||||||
erase_nodes = DeepUsersSearch(*iter, include, mng);
|
|
||||||
} else {
|
|
||||||
erase_nodes = DeepLinkedGraphSearch(*iter, include);
|
|
||||||
}
|
|
||||||
for (auto erase_node : erase_nodes) {
|
|
||||||
fused_op_set.erase(erase_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> res;
|
|
||||||
for (auto node : fused_op) {
|
|
||||||
if (fused_op_set.count(node)) {
|
|
||||||
res.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
|
||||||
if (lst->size() < 2) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> res;
|
|
||||||
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
|
|
||||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> ins;
|
|
||||||
std::map<AnfNodePtr, std::set<AnfNodePtr>> outs;
|
|
||||||
std::queue<AnfNodePtr> q;
|
|
||||||
for (auto node : *lst) {
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
for (auto input : cnode->inputs()) {
|
|
||||||
if (!node_sets.count(input)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// out_degree
|
|
||||||
outs[input].insert(node);
|
|
||||||
// in_degree
|
|
||||||
ins[node].insert(input);
|
|
||||||
}
|
|
||||||
if (!ins.count(node)) {
|
|
||||||
ins[node] = {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto p : ins) {
|
|
||||||
if (p.second.size() == 0) {
|
|
||||||
q.push(p.first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
while (!q.empty()) {
|
|
||||||
auto node = q.front();
|
|
||||||
q.pop();
|
|
||||||
res.push_back(node);
|
|
||||||
if (!outs.count(node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (auto out : outs[node]) {
|
|
||||||
if (!ins.count(out)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
ins[out].erase(node);
|
|
||||||
if (ins[out].size() == 0) {
|
|
||||||
q.push(out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lst->assign(res.begin(), res.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
|
|
||||||
auto func_graph = cnode->func_graph();
|
|
||||||
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
|
|
||||||
GraphKernelInfo info;
|
|
||||||
info.is_before_kernel_select = is_before_kernel_select;
|
|
||||||
GetGraphKernelInfo(graph_kernel_g, &info);
|
|
||||||
auto mng = func_graph->manager();
|
|
||||||
// Search fusable nodes according input direction.
|
|
||||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
|
|
||||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
|
||||||
std::reverse(used_nodes.begin(), used_nodes.end());
|
|
||||||
// Search fusable nodes according output direction.
|
|
||||||
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1);
|
|
||||||
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
|
|
||||||
|
|
||||||
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
|
||||||
if (used_nodes.size() > 1) {
|
|
||||||
used_nodes = RemoveCircle(used_nodes);
|
|
||||||
}
|
|
||||||
TopoSortForNodeList(&used_nodes);
|
|
||||||
return used_nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
|
|
||||||
auto out_spec = node->abstract();
|
|
||||||
if (out_spec->isa<abstract::AbstractTuple>()) {
|
|
||||||
return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
|
|
||||||
}
|
|
||||||
return out_spec;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr CreateNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const FuncGraphPtr &fg,
|
|
||||||
const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
|
|
||||||
bool is_before_kernel_select) {
|
|
||||||
auto func_node = NewValueNode(fg);
|
|
||||||
std::vector<AnfNodePtr> fn_inputs;
|
|
||||||
fn_inputs.push_back(func_node);
|
|
||||||
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
|
|
||||||
auto fuse_cnode = kernel_graph->NewCNode(fn_inputs);
|
|
||||||
// Set output abstract
|
|
||||||
if (outputs.size() > 1) {
|
|
||||||
std::vector<AbstractBasePtr> out_specs;
|
|
||||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
||||||
out_specs.push_back(outputs[i]->abstract());
|
|
||||||
}
|
|
||||||
auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
|
|
||||||
fuse_cnode->set_abstract(out_spec);
|
|
||||||
} else {
|
|
||||||
fuse_cnode->set_abstract(outputs[0]->abstract());
|
|
||||||
}
|
|
||||||
// Set parameter abstract.
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
||||||
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
|
||||||
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
fg->parameters()[i]->set_abstract(input_abs);
|
|
||||||
if (is_before_kernel_select) {
|
|
||||||
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Set kernel info.
|
|
||||||
if (!is_before_kernel_select) {
|
|
||||||
std::vector<std::string> graph_input_format;
|
|
||||||
std::vector<TypeId> graph_input_type;
|
|
||||||
std::vector<std::string> graph_output_format;
|
|
||||||
std::vector<TypeId> graph_output_type;
|
|
||||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
|
||||||
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
|
||||||
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
graph_input_format.push_back(input_format);
|
|
||||||
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
graph_input_type.push_back(input_type);
|
|
||||||
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
fg->parameters()[i]->set_abstract(input_abs);
|
|
||||||
}
|
|
||||||
auto new_outputs = outputs;
|
|
||||||
if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
|
|
||||||
std::vector<AnfNodePtr> real_outs;
|
|
||||||
if (IsMakeTupleOut(outputs[0], &real_outs)) {
|
|
||||||
new_outputs = real_outs;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < new_outputs.size(); ++i) {
|
|
||||||
auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
|
|
||||||
auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
|
||||||
graph_output_format.push_back(output_format);
|
|
||||||
graph_output_type.push_back(output_type);
|
|
||||||
}
|
|
||||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
|
||||||
graph_info_builder.SetInputsFormat(graph_input_format);
|
|
||||||
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
|
||||||
graph_info_builder.SetOutputsFormat(graph_output_format);
|
|
||||||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
|
||||||
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
|
|
||||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
|
||||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
|
||||||
auto graph_selected_info = graph_info_builder.Build();
|
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get());
|
|
||||||
}
|
|
||||||
return fuse_cnode;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReplaceNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
|
||||||
const AnfNodePtrList &outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
||||||
auto mng = kernel_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
// single out
|
|
||||||
if (outputs.size() == 1) {
|
|
||||||
mng->Replace(outputs[0], new_fuse_cnode);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> fn_inputs;
|
|
||||||
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
|
|
||||||
AnfNodePtrList real_outs;
|
|
||||||
// not make tuple out, replace
|
|
||||||
if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
|
|
||||||
fn_inputs.clear();
|
|
||||||
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
|
||||||
fn_inputs.push_back(new_fuse_cnode);
|
|
||||||
fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx))));
|
|
||||||
auto new_out = kernel_graph->NewCNode(fn_inputs);
|
|
||||||
new_out->set_abstract(outputs[out_idx]->abstract());
|
|
||||||
mng->Replace(outputs[out_idx], new_out);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// the out is make tuple , modify the get_item node's value
|
|
||||||
auto users = mng->node_users()[outputs[out_idx]];
|
|
||||||
for (auto &user : users) {
|
|
||||||
auto use_node = user.first;
|
|
||||||
if (use_node->isa<CNode>() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) {
|
|
||||||
auto get_item_cnode = use_node->cast<CNodePtr>();
|
|
||||||
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
|
||||||
MS_EXCEPTION_IF_NULL(value_input);
|
|
||||||
auto value_node = value_input->cast<ValueNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(value_node);
|
|
||||||
int item_idx = GetValue<int>(value_node->value());
|
|
||||||
int new_item_idx = SizeToInt(out_idx) + item_idx;
|
|
||||||
fn_inputs.clear();
|
|
||||||
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
|
||||||
fn_inputs.push_back(new_fuse_cnode);
|
|
||||||
fn_inputs.push_back(NewValueNode(new_item_idx));
|
|
||||||
auto new_out = kernel_graph->NewCNode(fn_inputs);
|
|
||||||
new_out->set_abstract(get_item_cnode->abstract());
|
|
||||||
mng->Replace(get_item_cnode, new_out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
|
|
||||||
AnfNodePtrList outs;
|
|
||||||
auto out_node = (*fg)->output();
|
|
||||||
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
|
|
||||||
std::vector<AnfNodePtr> output_args;
|
|
||||||
auto out_cnode = out_node->cast<CNodePtr>();
|
|
||||||
for (auto out : out_cnode->inputs()) {
|
|
||||||
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
|
||||||
auto inputs = out->cast<CNodePtr>()->inputs();
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
output_args.push_back(inputs[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
output_args.push_back(out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (output_args.size() != out_cnode->inputs().size()) {
|
|
||||||
auto new_out = (*fg)->NewCNode(output_args);
|
|
||||||
(*mng)->Replace(out_node, new_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 1; i < output_args.size(); ++i) {
|
|
||||||
outs.push_back(output_args[i]);
|
|
||||||
}
|
|
||||||
return outs;
|
|
||||||
}
|
|
||||||
|
|
||||||
outs.push_back(out_node);
|
|
||||||
return outs;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
|
|
||||||
AnfNodePtrList res;
|
|
||||||
if (outs.size() <= 1) {
|
|
||||||
return outs;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto out : outs) {
|
|
||||||
AnfNodePtrList real_outs;
|
|
||||||
if (IsMakeTupleOut(out, &real_outs)) {
|
|
||||||
res.insert(res.end(), real_outs.begin(), real_outs.end());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
res.push_back(out);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuseGraphKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
||||||
auto mng = kernel_graph->manager();
|
|
||||||
if (mng == nullptr) {
|
|
||||||
mng = Manage(kernel_graph, true);
|
|
||||||
kernel_graph->set_manager(mng);
|
|
||||||
}
|
|
||||||
auto &todos = kernel_graph->execution_order();
|
|
||||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
|
||||||
auto node = *iter;
|
|
||||||
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
|
||||||
if (fg_attr != nullptr) {
|
|
||||||
auto fg_name = GetValue<std::string>(fg_attr);
|
|
||||||
if (graph_kernel_black_list.count(fg_name) != 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
|
|
||||||
if (fuse_nodes.size() <= 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphPtr fg;
|
|
||||||
AnfNodePtrList inputs;
|
|
||||||
AnfNodePtrList outputs;
|
|
||||||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
|
||||||
|
|
||||||
// Remove nest make tuple in outs
|
|
||||||
auto expand_out = GetExpandOuts(outputs);
|
|
||||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select);
|
|
||||||
|
|
||||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
|
||||||
|
|
||||||
// Inline origin graphkernel
|
|
||||||
auto cnodes = fg->GetOrderedCnodes();
|
|
||||||
for (const auto &n : cnodes) {
|
|
||||||
if (!AnfAlgo::IsGraphKernel(n)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
|
|
||||||
AnfNodePtrList ins;
|
|
||||||
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
|
|
||||||
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
|
|
||||||
mng->Replace(n, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
EliminateMakeTuple(&fg, &mng);
|
|
||||||
// Set graphkernel flag
|
|
||||||
auto ori_fg = GetValueNode<FuncGraphPtr>(node->input(kAnfPrimitiveIndex));
|
|
||||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
|
@ -41,6 +41,7 @@
|
||||||
#include "utils/config_manager.h"
|
#include "utils/config_manager.h"
|
||||||
#include "utils/base_ref_extends.h"
|
#include "utils/base_ref_extends.h"
|
||||||
#include "debug/tensor_load.h"
|
#include "debug/tensor_load.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace session {
|
namespace session {
|
||||||
|
|
|
@ -39,6 +39,10 @@
|
||||||
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
#include "backend/optimizer/gpu/insert_format_transform_op.h"
|
||||||
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
|
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
|
||||||
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
|
#include "backend/optimizer/gpu/remove_redundant_format_transform.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||||
#include "runtime/device/kernel_runtime_manager.h"
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "common/trans.h"
|
#include "common/trans.h"
|
||||||
|
@ -104,6 +108,22 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
|
||||||
kernel_graph->SetExecOrderByDefault();
|
kernel_graph->SetExecOrderByDefault();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
|
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
||||||
|
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||||
|
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
||||||
|
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
|
kernel_graph->SetExecOrderByDefault();
|
||||||
|
}
|
||||||
|
|
||||||
void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
device::gpu::AssignGpuStream(kernel_graph);
|
device::gpu::AssignGpuStream(kernel_graph);
|
||||||
|
@ -218,6 +238,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
|
||||||
SelectKernel(graph);
|
SelectKernel(graph);
|
||||||
// Graph optimization relevant to device data format
|
// Graph optimization relevant to device data format
|
||||||
HardwareOptimize(graph);
|
HardwareOptimize(graph);
|
||||||
|
// Graph kernel fusion optimization
|
||||||
|
GraphKernelOptimize(graph);
|
||||||
// Dump .pb graph after graph optimization
|
// Dump .pb graph after graph optimization
|
||||||
if (save_graphs) {
|
if (save_graphs) {
|
||||||
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
|
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
|
||||||
|
|
|
@ -51,6 +51,8 @@ class GPUSession : public SessionBasic {
|
||||||
|
|
||||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||||
|
|
||||||
|
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||||
|
|
||||||
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
|
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||||
|
|
||||||
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#endif
|
#endif
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iomanip>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
|
@ -446,13 +447,30 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string AddGlobalId(const std::string &filename) {
|
||||||
|
static size_t g_id = 0;
|
||||||
|
std::ostringstream s;
|
||||||
|
auto i = filename.rfind('/');
|
||||||
|
if (i == string::npos) {
|
||||||
|
s << std::setfill('0') << std::setw(4) << g_id << "_";
|
||||||
|
s << filename;
|
||||||
|
} else {
|
||||||
|
s << filename.substr(0, i + 1);
|
||||||
|
s << std::setfill('0') << std::setw(4) << g_id << "_";
|
||||||
|
s << filename.substr(i + 1);
|
||||||
|
}
|
||||||
|
++g_id;
|
||||||
|
return s.str();
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_DUMP_IR
|
#ifdef ENABLE_DUMP_IR
|
||||||
void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name) {
|
void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name) {
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (filename.size() > PATH_MAX) {
|
auto real_filename = AddGlobalId(filename);
|
||||||
MS_LOG(ERROR) << "File path " << filename << " is too long.";
|
if (real_filename.size() > PATH_MAX) {
|
||||||
|
MS_LOG(ERROR) << "File path " << real_filename << " is too long.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
char real_path[PATH_MAX] = {0};
|
char real_path[PATH_MAX] = {0};
|
||||||
|
@ -461,8 +479,8 @@ void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_fu
|
||||||
MS_LOG(DEBUG) << "dir " << filename << " does not exit.";
|
MS_LOG(DEBUG) << "dir " << filename << " does not exit.";
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
if (nullptr == realpath(filename.c_str(), real_path)) {
|
if (nullptr == realpath(real_filename.c_str(), real_path)) {
|
||||||
MS_LOG(DEBUG) << "Dir " << filename << " does not exit.";
|
MS_LOG(DEBUG) << "Dir " << real_filename << " does not exit.";
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,9 @@
|
||||||
#include "runtime/device/gpu/gpu_kernel_build.h"
|
#include "runtime/device/gpu/gpu_kernel_build.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
|
||||||
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
|
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
|
||||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "backend/session/kernel_build_client.h"
|
#include "backend/session/kernel_build_client.h"
|
||||||
|
@ -56,16 +56,16 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
|
||||||
}
|
}
|
||||||
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
|
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
|
||||||
if (!gpu_kernel_ptr) {
|
if (!gpu_kernel_ptr) {
|
||||||
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed";
|
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel->fullname_with_scope() << "] failed";
|
||||||
}
|
}
|
||||||
session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get());
|
session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get());
|
||||||
} else {
|
} else {
|
||||||
auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel);
|
auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel);
|
||||||
if (!gpu_kernel_ptr) {
|
if (!gpu_kernel_ptr) {
|
||||||
MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed";
|
MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel->fullname_with_scope() << "] failed";
|
||||||
}
|
}
|
||||||
if (!gpu_kernel_ptr->Init(kernel)) {
|
if (!gpu_kernel_ptr->Init(kernel)) {
|
||||||
MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed.";
|
MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel->fullname_with_scope() << "] failed.";
|
||||||
}
|
}
|
||||||
session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get());
|
session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get());
|
||||||
}
|
}
|
||||||
|
|
|
@ -392,9 +392,13 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
|
||||||
mem_manager_->FreeDeviceMemory();
|
mem_manager_->FreeDeviceMemory();
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(bin_map);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
bin_map->RemoveKernelCache();
|
if (!(context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG))) {
|
||||||
|
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(bin_map);
|
||||||
|
bin_map->RemoveKernelCache();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
|
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
|
||||||
|
|
|
@ -234,6 +234,67 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
||||||
*origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format");
|
*origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> node_list;
|
||||||
|
std::vector<AnfNodePtr> input_list;
|
||||||
|
std::vector<AnfNodePtr> output_list;
|
||||||
|
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
|
||||||
|
|
||||||
|
std::vector<std::string> graph_input_format;
|
||||||
|
std::vector<TypeId> graph_input_type;
|
||||||
|
// set graph kernel inputs kernel info.
|
||||||
|
for (size_t i = 0; i < input_list.size(); ++i) {
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||||
|
std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
|
||||||
|
std::vector<TypeId> outputs_device_type = {AnfAlgo::GetOutputInferDataType(input_list[i], 0)};
|
||||||
|
graph_input_format.push_back(kOpFormat_DEFAULT);
|
||||||
|
graph_input_type.push_back(AnfAlgo::GetOutputInferDataType(input_list[i], 0));
|
||||||
|
builder.SetOutputsFormat(outputs_format);
|
||||||
|
builder.SetOutputsDeviceType(outputs_device_type);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// set graph kernel innner nodes kernel info.
|
||||||
|
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||||
|
const auto &anf_node = node_list[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
SetKernelInfo(cnode, KernelType::AKG_KERNEL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// set graph kernel node kernel info.
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(func_graph, true);
|
||||||
|
func_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
|
||||||
|
std::vector<std::string> graph_output_format;
|
||||||
|
std::vector<TypeId> graph_output_type;
|
||||||
|
for (size_t i = 0; i < output_index.size(); ++i) {
|
||||||
|
auto const &output = output_index[i];
|
||||||
|
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
|
||||||
|
graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(output.first, output.second));
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||||
|
graph_info_builder.SetInputsFormat(graph_input_format);
|
||||||
|
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
||||||
|
graph_info_builder.SetOutputsFormat(graph_output_format);
|
||||||
|
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||||
|
graph_info_builder.SetProcessor(kernel::Processor::CUDA);
|
||||||
|
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||||
|
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||||
|
auto graph_selected_info = graph_info_builder.Build();
|
||||||
|
MS_EXCEPTION_IF_NULL(graph_selected_info);
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
|
||||||
|
SetTensorDeviceInfo(*graph_selected_info, kernel_node);
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||||
|
@ -266,7 +327,14 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s
|
||||||
format_transform_ = false;
|
format_transform_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
|
if (AnfAlgo::IsGraphKernel(kernel_node)) {
|
||||||
|
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
SetGraphKernelInfo(kernel_node, func_graph);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<TypeId> inputs_type;
|
std::vector<TypeId> inputs_type;
|
||||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||||
|
@ -291,13 +359,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
builder->SetOutputsFormat(outputs_format);
|
builder->SetOutputsFormat(outputs_format);
|
||||||
builder->SetOutputsDeviceType(outputs_type);
|
builder->SetOutputsDeviceType(outputs_type);
|
||||||
|
|
||||||
bool result =
|
bool result = false;
|
||||||
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
|
KernelType res_kernel_type = UNKNOWN_KERNEL_TYPE;
|
||||||
KernelType kernel_type = UNKNOWN_KERNEL_TYPE;
|
if (kernel_type == UNKNOWN_KERNEL_TYPE) {
|
||||||
|
result =
|
||||||
|
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
|
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||||
|
res_kernel_type = AKG_KERNEL;
|
||||||
|
}
|
||||||
|
} else if (kernel_type == AKG_KERNEL) {
|
||||||
result = SelectAkgKernel(kernel_node, builder->Build());
|
result = SelectAkgKernel(kernel_node, builder->Build());
|
||||||
kernel_type = AKG_KERNEL;
|
res_kernel_type = AKG_KERNEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
|
@ -314,7 +388,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
<< "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
|
<< "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
|
||||||
<< ", but get " << build_type;
|
<< ", but get " << build_type;
|
||||||
}
|
}
|
||||||
builder->SetKernelType(kernel_type);
|
builder->SetKernelType(res_kernel_type);
|
||||||
builder->SetProcessor(kernel::Processor::CUDA);
|
builder->SetProcessor(kernel::Processor::CUDA);
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||||
SetTensorDeviceInfo(*(builder->Build()), kernel_node);
|
SetTensorDeviceInfo(*(builder->Build()), kernel_node);
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -59,7 +60,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
||||||
{prim::kPrimAddN->name(), {{}, {0}}},
|
{prim::kPrimAddN->name(), {{}, {0}}},
|
||||||
};
|
};
|
||||||
|
|
||||||
void SetKernelInfo(const CNodePtr &kernel_node);
|
void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||||
|
|
||||||
class FormatTransformChecker {
|
class FormatTransformChecker {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -194,6 +194,26 @@ constexpr auto kPaddingOpName = "Padding";
|
||||||
constexpr auto kAvgPoolOpName = "AvgPool";
|
constexpr auto kAvgPoolOpName = "AvgPool";
|
||||||
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
|
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
|
||||||
constexpr auto kTensorAddOpName = "TensorAdd";
|
constexpr auto kTensorAddOpName = "TensorAdd";
|
||||||
|
constexpr auto kCastOpName = "Cast";
|
||||||
|
constexpr auto kGreaterEqualOpName = "GreaterEqual";
|
||||||
|
constexpr auto kAbsOpName = "Abs";
|
||||||
|
constexpr auto kExpOpName = "Exp";
|
||||||
|
constexpr auto kNegOpName = "Neg";
|
||||||
|
constexpr auto kMinimumOpName = "Minimum";
|
||||||
|
constexpr auto kMaximumOpName = "Maximum";
|
||||||
|
constexpr auto kMulOpName = "Mul";
|
||||||
|
constexpr auto kSubOpName = "Sub";
|
||||||
|
constexpr auto kLogOpName = "Log";
|
||||||
|
constexpr auto kPowOpName = "Pow";
|
||||||
|
constexpr auto kReciprocalOpName = "Reciprocal";
|
||||||
|
constexpr auto kEqualOpName = "Equal";
|
||||||
|
constexpr auto kLessOpName = "Less";
|
||||||
|
constexpr auto kLessEqualOpName = "LessEqual";
|
||||||
|
constexpr auto kSquareOpName = "Square";
|
||||||
|
constexpr auto kSelectOpName = "Select";
|
||||||
|
constexpr auto kReduceSumOpName = "ReduceSum";
|
||||||
|
constexpr auto kReduceMinOpName = "ReduceMin";
|
||||||
|
constexpr auto kReduceMaxOpName = "ReduceMax";
|
||||||
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
|
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
|
||||||
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
|
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";
|
||||||
|
|
||||||
|
|
|
@ -206,6 +206,11 @@ inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
|
||||||
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
|
||||||
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
|
||||||
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
|
||||||
|
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
|
||||||
|
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
|
||||||
|
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
|
||||||
|
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
|
||||||
|
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
|
||||||
|
|
||||||
// Statements
|
// Statements
|
||||||
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
|
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
|
||||||
|
|
|
@ -290,7 +290,7 @@ class Parameter : public ANode {
|
||||||
std::string DebugString(int recursive_level = 1) const override;
|
std::string DebugString(int recursive_level = 1) const override;
|
||||||
std::string name() const { return name_; }
|
std::string name() const { return name_; }
|
||||||
void set_name(const std::string &name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
std::string fullname_with_scope() override { return name(); };
|
std::string fullname_with_scope() override { return name(); }
|
||||||
|
|
||||||
bool has_default() const { return has_default_; }
|
bool has_default() const { return has_default_; }
|
||||||
void set_default_param(ValuePtr param) {
|
void set_default_param(ValuePtr param) {
|
||||||
|
|
|
@ -273,7 +273,8 @@ class Lamb(Optimizer):
|
||||||
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
self.global_step = Parameter(initializer(0, [1]), name='global_step')
|
||||||
self.assignadd = P.AssignAdd()
|
self.assignadd = P.AssignAdd()
|
||||||
self.hyper_map = C.HyperMap()
|
self.hyper_map = C.HyperMap()
|
||||||
self.enable_graph_kernel = context.get_context("enable_graph_kernel")
|
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
|
||||||
|
context.get_context("device_target") == "Ascend"
|
||||||
|
|
||||||
def construct(self, gradients):
|
def construct(self, gradients):
|
||||||
lr = self.get_lr()
|
lr = self.get_lr()
|
||||||
|
|
|
@ -13,24 +13,44 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""__init__"""
|
"""__init__"""
|
||||||
|
from .abs import _abs_akg
|
||||||
|
from .add import _add_akg
|
||||||
|
from .add_n import _addn_akg
|
||||||
from .cast import _cast_akg
|
from .cast import _cast_akg
|
||||||
from .equal import _equal_akg
|
from .equal import _equal_akg
|
||||||
from .mean import _simple_mean_akg
|
from .exp import _exp_akg
|
||||||
from .mean_grad import _simple_mean_grad_akg
|
from .expand_dims import _expand_dims_akg
|
||||||
from .mul import _mul_akg
|
from .greater_equal import _greater_equal_akg
|
||||||
from .relu6 import _relu6_akg
|
|
||||||
from .relu6_grad import _relu6_grad_akg
|
|
||||||
from .squeeze import _squeeze_akg
|
|
||||||
from .squeeze_grad import _squeeze_grad_akg
|
|
||||||
from .tile import _tile_akg
|
|
||||||
from .hsigmoid import _hsigmoid_akg
|
from .hsigmoid import _hsigmoid_akg
|
||||||
from .hsigmoid_grad import _hsigmoid_grad_akg
|
from .hsigmoid_grad import _hsigmoid_grad_akg
|
||||||
from .hswish import _hswish_akg
|
from .hswish import _hswish_akg
|
||||||
from .hswish_grad import _hswish_grad_akg
|
from .hswish_grad import _hswish_grad_akg
|
||||||
from .sub import _sub_akg
|
from .lessequal import _lessequal_akg
|
||||||
|
from .log import _log_akg
|
||||||
from .logical_and import _logical_and_akg
|
from .logical_and import _logical_and_akg
|
||||||
from .logical_not import _logical_not_akg
|
from .logical_not import _logical_not_akg
|
||||||
from .logical_or import _logical_or_akg
|
from .logical_or import _logical_or_akg
|
||||||
from .lessequal import _lessequal_akg
|
from .maximum import _maximum_akg
|
||||||
|
from .mean import _simple_mean_akg
|
||||||
|
from .mean_grad import _simple_mean_grad_akg
|
||||||
|
from .minimum import _minimum_akg
|
||||||
|
from .mul import _mul_akg
|
||||||
|
from .neg import _neg_akg
|
||||||
from .notequal import _notequal_akg
|
from .notequal import _notequal_akg
|
||||||
from .greater_equal import _greater_equal_akg
|
from .pow import _pow_akg
|
||||||
|
from .real_div import _real_div_akg
|
||||||
|
from .reciprocal import _reciprocal_akg
|
||||||
|
from .reduce_max import _reduce_max_akg
|
||||||
|
from .reduce_sum import _reduce_sum_akg
|
||||||
|
from .relu6 import _relu6_akg
|
||||||
|
from .relu6_grad import _relu6_grad_akg
|
||||||
|
from .reshape import _reshape_akg
|
||||||
|
from .round import _round_akg
|
||||||
|
from .rsqrt import _rsqrt_akg
|
||||||
|
from .sqrt import _sqrt_akg
|
||||||
|
from .squeeze import _squeeze_akg
|
||||||
|
from .squeeze_grad import _squeeze_grad_akg
|
||||||
|
from .sub import _sub_akg
|
||||||
|
from .tile import _tile_akg
|
||||||
|
|
||||||
|
# Please insert op register in lexicographical order of the filename.
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Abs op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Abs") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _abs_akg():
|
||||||
|
"""Abs Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""TensorAdd op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("TensorAdd") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.input(1, "y") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _add_akg():
|
||||||
|
"""TensorAdd Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""AddN op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("AddN") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "inputs", "dynamic") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _addn_akg():
|
||||||
|
"""AddN Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Exp op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Exp") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _exp_akg():
|
||||||
|
"""Exp Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ExpandDims op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("ExpandDims") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.attr("axis", "required", "int") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _expand_dims_akg():
|
||||||
|
"""ExpandDims Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Log op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Log") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _log_akg():
|
||||||
|
"""Log Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Maximum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Maximum") \
|
||||||
|
.fusion_type("COMMREDUCE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.input(1, "y") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _maximum_akg():
|
||||||
|
"""Maximum Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Minimum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Minimum") \
|
||||||
|
.fusion_type("COMMREDUCE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.input(1, "y") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _minimum_akg():
|
||||||
|
"""Minimum Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Neg op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Neg") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _neg_akg():
|
||||||
|
"""Neg Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Pow op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Pow") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.input(1, "y") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _pow_akg():
|
||||||
|
"""Pow Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RealDiv op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("RealDiv") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.input(1, "y") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _real_div_akg():
|
||||||
|
"""RealDiv Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Reciprocal op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Reciprocal") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _reciprocal_akg():
|
||||||
|
"""Reciprocal Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceMax op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("ReduceMax") \
|
||||||
|
.fusion_type("COMMREDUCE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.attr("axis", "required", "listInt") \
|
||||||
|
.attr("keep_dims", "required", "bool") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _reduce_max_akg():
|
||||||
|
"""ReduceMax Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceMin op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("ReduceMin") \
|
||||||
|
.fusion_type("COMMREDUCE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.attr("axis", "required", "listInt") \
|
||||||
|
.attr("keep_dims", "required", "bool") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _reduce_min_akg():
|
||||||
|
"""ReduceMin Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""ReduceSum op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("ReduceSum") \
|
||||||
|
.fusion_type("COMMREDUCE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.attr("axis", "required", "listInt") \
|
||||||
|
.attr("keep_dims", "required", "bool") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _reduce_sum_akg():
|
||||||
|
"""ReduceSum Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Reshape op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Reshape") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "y") \
|
||||||
|
.attr("shape", "required", "listInt") \
|
||||||
|
.dtype_format(DT.BOOL_Default, DT.BOOL_Default) \
|
||||||
|
.dtype_format(DT.I8_Default, DT.I8_Default) \
|
||||||
|
.dtype_format(DT.I16_Default, DT.I16_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default) \
|
||||||
|
.dtype_format(DT.I64_Default, DT.I64_Default) \
|
||||||
|
.dtype_format(DT.U8_Default, DT.U8_Default) \
|
||||||
|
.dtype_format(DT.U16_Default, DT.U16_Default) \
|
||||||
|
.dtype_format(DT.U32_Default, DT.U32_Default) \
|
||||||
|
.dtype_format(DT.U64_Default, DT.U64_Default) \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.F64_Default, DT.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _reshape_akg():
|
||||||
|
"""Reshape Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Round op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Round") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.dtype_format(DT.I32_Default, DT.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _round_akg():
|
||||||
|
"""Round Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Rsqrt op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Rsqrt") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _rsqrt_akg():
|
||||||
|
"""Rsqrt Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Sqrt op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
|
||||||
|
|
||||||
|
op_info = AkgGpuRegOp("Sqrt") \
|
||||||
|
.fusion_type("ELEMWISE") \
|
||||||
|
.input(0, "x") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DT.F16_Default, DT.F16_Default) \
|
||||||
|
.dtype_format(DT.F32_Default, DT.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(op_info)
|
||||||
|
def _sqrt_akg():
|
||||||
|
"""Sqrt Akg register"""
|
||||||
|
return
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
from mindspore.nn.graph_kernels import ReLU
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.add = P.TensorAdd()
|
||||||
|
self.sub = P.Sub()
|
||||||
|
self.mul = P.Mul()
|
||||||
|
self.relu = ReLU()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
sub_res = self.sub(x, y)
|
||||||
|
mul_res = self.mul(sub_res, x)
|
||||||
|
relu_res = self.relu(mul_res)
|
||||||
|
square_res = P.Square()(relu_res)
|
||||||
|
add_res = self.add(relu_res, square_res)
|
||||||
|
add1_res = self.add(add_res, add_res)
|
||||||
|
return self.add(add1_res, add1_res)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_basic():
|
||||||
|
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||||
|
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||||
|
sub_res = input_x - input_y
|
||||||
|
mul_res = sub_res * input_x
|
||||||
|
relu_res = np.maximum(mul_res, 0)
|
||||||
|
square_res = np.square(relu_res)
|
||||||
|
add_res = relu_res + square_res
|
||||||
|
add1_res = add_res + add_res
|
||||||
|
expect = add1_res + add1_res
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
result = net(Tensor(input_x), Tensor(input_y))
|
||||||
|
|
||||||
|
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
||||||
|
assert res
|
|
@ -0,0 +1,77 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.layernorm = P.LayerNorm(1, 1)
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
return self.layernorm(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_basic():
|
||||||
|
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
|
||||||
|
gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
|
||||||
|
beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
|
||||||
|
shape_x = [2, 3, 4, 3]
|
||||||
|
begin_norm_axis = 1
|
||||||
|
|
||||||
|
in_rank = len(shape_x)
|
||||||
|
if begin_norm_axis < 0:
|
||||||
|
norm_axis = begin_norm_axis + in_rank
|
||||||
|
else:
|
||||||
|
norm_axis = begin_norm_axis
|
||||||
|
norm_axes = tuple(range(norm_axis, in_rank))
|
||||||
|
mean = np.mean(input_x, axis=norm_axes, keepdims=True)
|
||||||
|
mean_b = np.broadcast_to(mean, shape_x)
|
||||||
|
diff = input_x - mean_b
|
||||||
|
square = np.square(diff)
|
||||||
|
smean = np.mean(square, axis=norm_axes, keepdims=True)
|
||||||
|
smean_b = np.broadcast_to(smean, shape_x)
|
||||||
|
meps = smean_b + 1e-5
|
||||||
|
logs = np.log(meps)
|
||||||
|
mul = logs * (-0.5)
|
||||||
|
rsqrt = np.exp(mul)
|
||||||
|
out = diff * rsqrt
|
||||||
|
bn = out * gamma + beta
|
||||||
|
expect = (bn, mean, smean)
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
|
||||||
|
net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
|
||||||
|
if isinstance(net_result, tuple) and len(net_result) == 3:
|
||||||
|
result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
|
||||||
|
res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
|
||||||
|
assert res0
|
||||||
|
res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
||||||
|
assert res1
|
||||||
|
res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
|
||||||
|
assert res2
|
||||||
|
else:
|
||||||
|
assert False
|
|
@ -115,6 +115,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc"
|
"../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc"
|
||||||
"../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc"
|
"../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc"
|
||||||
"../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc"
|
"../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc"
|
||||||
|
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/*.cc"
|
||||||
"../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc"
|
"../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc"
|
||||||
"../../../mindspore/ccsrc/backend/session/ascend_session.cc"
|
"../../../mindspore/ccsrc/backend/session/ascend_session.cc"
|
||||||
"../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc"
|
"../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc"
|
||||||
|
|
Loading…
Reference in New Issue