!5783 GraphKernel supports GPU

Merge pull request !5783 from DeshiChen/graph_kernel_1.0
This commit is contained in:
mindspore-ci-bot 2020-09-10 09:15:21 +08:00 committed by Gitee
commit 7152fe04be
85 changed files with 6213 additions and 1642 deletions

2
akg

@ -1 +1 @@
Subproject commit 3bb6264188d0b1d6ff776a35a571bc7190df0800 Subproject commit d237aa7d8e9d3fb709bda9f30205b02129bc2b59

View File

@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init"""
from .splitter import split_with_json
from .expander import get_op_expander

View File

@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate json desc for graph kernel ops"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders
def get_op_expander(json_str: str):
"""get op expander by json info"""
try:
kernel_info = json.loads(json_str)
expand_info = kernel_info['expand_info']
if 'name' not in expand_info:
logger.error("expand info have no op name")
return None
if 'process' not in expand_info:
logger.error("expand info have no processor info")
return None
processor = expand_info['process']
op_name = str(expand_info['name']).lower()
expand_op_func_name = 'expand_' + op_name
if not hasattr(expanders, expand_op_func_name):
logger.error("Generator do not support op: {}".format(op_name))
return None
expand_op_func = getattr(expanders, expand_op_func_name)
# generate graph desc.
graph = expand_op_func(expand_info)
if graph is None:
logger.error("Failed to generate graph of: {}".format(op_name))
return None
graph.set_processor(processor)
# dump graph to json desc.
desc = graph.dump()
return json.dumps(desc)
except jd.JSONDecodeError:
logger.error("Failed to generate graph kernel op")
logger.error(traceback.format_exc())
return None

View File

@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""expanders init"""
from .gelu import expand_gelu
from .layernorm import expand_layernorm
from .softmax import expand_softmax
from .square import expand_square

View File

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for gelu"""
from mindspore._extends.graph_kernel.model import model_builder as builder
CSVALUE = 0.044715
CSVALUE_A = 1.5957691 # 2*np.sqrt(2/np.pi)
def expand_gelu(expand_info):
"""Gelu expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
dtype = input_x.dtype
if dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
# cal tanh.
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
right_mul = graph_builder.emit('Exp', [mul_0_min])
mul_0_abs = graph_builder.emit('Abs', [mul_0])
const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
result = graph_builder.emit('Mul', [left_mul, right_mul])
if dtype == 'float16':
result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for LayerNorm"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_layernorm(expand_info):
"""LayerNorm expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
begin_norm_axis = None
epsilon = None
for item in attrs:
if 'begin_norm_axis' in item:
begin_norm_axis = item['begin_norm_axis']
if 'epsilon' in item:
epsilon = item['epsilon']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_gamma = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_beta = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
# Calculate the scaling ratio of the average
shape_x = input_desc_0['shape']
if begin_norm_axis < 0:
begin_norm_axis += len(shape_x)
reduce_axis = ()
for i, _ in enumerate(shape_x):
if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,)
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= shape_x[i]
mean_cof = 1.0 / reduce_elts
mean_cof_v = graph_builder.value(input_x.dtype, mean_cof, input_x.data_format)
# Calculate mean
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub])
variance_red = graph_builder.emit('ReduceSum', [variance_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
# Calculate normalize
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
epsilon_v = graph_builder.value(input_x.dtype, epsilon, input_x.data_format)
normalize_add = graph_builder.emit('TensorAdd', [variance, epsilon_v])
normalize_log = graph_builder.emit('Log', [normalize_add])
input_y = graph_builder.value(input_x.dtype, -0.5, input_x.data_format)
normalize_log_mul = graph_builder.emit('Mul', [normalize_log, input_y])
normalize_exp = graph_builder.emit('Exp', [normalize_log_mul])
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normalize_exp])
# Calculate scale and translate
scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
res = graph_builder.emit('TensorAdd', [scale_mul, input_beta])
# set graph output.
graph_scope.set_output(res, mean, variance)
graph = graph_builder.get()[0]
return graph

View File

@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for softmax"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_softmax(expand_info):
"""Softmax expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
attrs = expand_info['attr']
axis = None
for item in attrs:
if 'axis' in item:
axis = item['axis']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# cal softmax.
if input_x.dtype == 'float32':
input_x_cast = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x = graph_builder.emit('ReduceMax', [input_x_cast], attrs={'reduce_axis': axis, 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': 'float32'})
else:
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
data_sub = graph_builder.emit('Sub', [input_x, max_x])
data_exp = graph_builder.emit('Exp', [data_sub])
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for square"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_square(expand_info):
"""Square expander"""
# get op info.
input_desc = expand_info['input_desc'][0]
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
# create op.
result = graph_builder.emit('Mul', [input_x, input_x])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph

View File

@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model init"""
from .graph_split import split
from .model_builder import GraphBuilder, load_composite

View File

@ -0,0 +1,153 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Cost model splitter"""
from .model import PrimLib, Graph
class GraphSplitByPattern:
"""Graph split by pattern"""
def __init__(self, graph):
self.graph = graph
self.groups = []
self.op_group = {}
for op in self.graph.ops:
g = [op]
self.groups.append(g)
self.op_group[op] = g
self.ids = {}
for i, op in enumerate(graph.ops):
self.ids[op] = i
self.doms = self.post_dom(graph.ops)
_, outputs = graph.deduce_parameters()
self.outputs = set(outputs)
def post_dom(self, ops):
"""Post dom"""
doms, i_doms = {}, {}
for i in range(len(ops) - 1, -1, -1):
op = ops[i]
doms[op] = {op}
i_dom = None
if op.output.to_ops:
suc_dom = set(doms[op.output.to_ops[0]])
for to in op.output.to_ops[1:]:
suc_dom.intersection_update(doms[to])
doms[op].update(suc_dom)
for dom in suc_dom:
if i_dom is None or self.ids[dom] < self.ids[i_dom]:
i_dom = dom
i_doms[op] = i_dom
return i_doms
def get_pattern(self, op, i):
"""Get pattern"""
pattern = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for pat in elem_relation:
if pat and pat > pattern:
pattern = pat
return pattern
def fuse(self, check_fun):
"""Fuse ops"""
def _get_path(op, dom):
path_ops, visited = [], set()
def _get_path_depth(p):
visited.add(p)
if self.op_group[p][0] == p:
path_ops.append(p)
for to in p.output.to_ops:
if to != dom and to not in visited:
_get_path_depth(to)
_get_path_depth(op)
return path_ops
changed = True
while changed:
for group in self.groups:
op = group[0]
dom = self.doms[op]
if dom is None or op.output in self.outputs:
continue
ops = _get_path(op, dom)
if check_fun(op, dom, ops):
dom_group = self.op_group[dom]
fused = []
for fop in ops:
f_group = self.op_group[fop]
for p in f_group:
self.op_group[p] = dom_group
fused.append(f_group)
dom_group += f_group
for g in fused:
self.groups.remove(g)
break
else:
changed = False
def to_subgraphs(self):
"""Transform op groups to subgraphs"""
subgraphs = []
for i, group in enumerate(self.groups):
group.sort(key=lambda op: self.ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
return subgraphs
def split(self):
"""Split graph"""
def _buddy(op, dom, path_ops):
"""Fuse buddy together"""
# pylint: disable=unused-argument
group = self.op_group[op]
for p in group:
# p is buddy
if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
return True
# p's output is buddy
for to in p.output.to_ops:
if to.output.buddy is not None and to not in group:
return True
return False
def _injective(pattern, limit):
def _checker(op, dom, path_ops):
# pylint: disable=unused-argument
for p in op.output.to_ops:
if p not in self.op_group[dom]:
return False
if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
for i, t in enumerate(dom.inputs):
if t == op.output:
return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
return False
return _checker
def _diamond(op, dom, path_ops):
if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
return False
return len(path_ops) == 1 and op.output not in dom.inputs
self.fuse(_buddy)
self.fuse(_injective(PrimLib.ELEMWISE, 100))
self.fuse(_injective(PrimLib.BROADCAST, 6))
self.fuse(_injective(PrimLib.REDUCE, 6))
self.fuse(_diamond)
return self.to_subgraphs()
def split(graph):
return GraphSplitByPattern(graph).split()

View File

@ -0,0 +1,473 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model"""
class Utils:
"""Model utils"""
@staticmethod
def get_attr_type(attr):
"""Get attr type"""
if isinstance(attr, bool):
return 'bool'
if isinstance(attr, str):
return 'str'
if isinstance(attr, int):
return 'int'
if isinstance(attr, float):
return 'bool'
if isinstance(attr, (list, tuple)):
if not attr:
raise ValueError("Length of attr is 0")
if isinstance(attr[0], int):
return 'listInt'
if isinstance(attr[0], str):
return 'listStr'
raise ValueError("Unknown type of attr: {}".format(attr))
class DataFormat:
"""DataFormat"""
DEFAULT = "DefaultFormat"
NC1KHKWHWC0 = "NC1KHKWHWC0"
ND = "ND"
NCHW = "NCHW"
NHWC = "NHWC"
HWCN = "HWCN"
NC1HWC0 = "NC1HWC0"
FRAC_Z = "FracZ"
FRAC_NZ = "FRACTAL_NZ"
C1HWNCOC0 = "C1HWNCoC0"
NC1HWC0_C04 = "NC1HWC0_C04"
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
NDHWC = "NDHWC"
class Config:
R0 = 8.0
UB_SIZE = 256 * 1024
MAX_BLOCK = 32
class PrimLib:
"""Prim lib"""
UNKNOWN = 0
ELEMWISE = 1
BROADCAST = 2
REDUCE = 3
TRANSFORM = 4
CONTROL = 5
class Prim:
"""Prim"""
def __init__(self, iter_type, calibrate=1, relation_func=None):
self.iter_type = iter_type
self.calibrate = calibrate
self.relation_func = relation_func
if relation_func is None:
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
def default_elemwise_broadcast_relation(self, op, input_idx):
"""Process elemwise and broadcast relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
assert len(out_shape) >= len(in_shape)
axis_relation, elem_relation = [], []
delta = len(out_shape) - len(in_shape)
if delta > 0:
for i in range(0, delta):
axis_relation.append(None)
elem_relation.append(None)
for i, _ in enumerate(in_shape):
axis_relation.append(i)
elem_relation.append(
PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
return axis_relation, elem_relation
def default_reduce_relation(self, op, input_idx):
"""Process reduce relation"""
axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
for i in op.attrs['reduce_axis']:
elem_relation[i] = PrimLib.REDUCE
return axis_relation, elem_relation
def unknown_relation(self, op, input_idx):
"""Process unknown relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
all_relation = list(range(len(in_shape)))
axis_relation = [all_relation for i in range(0, len(out_shape))]
elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
return axis_relation, elem_relation
default_relation_func = [
unknown_relation,
default_elemwise_broadcast_relation,
default_elemwise_broadcast_relation,
default_reduce_relation,
unknown_relation,
unknown_relation,
]
primtives = {
'TensorAdd': Prim(ELEMWISE),
'Abs': Prim(ELEMWISE),
'Neg': Prim(ELEMWISE),
'Mul': Prim(ELEMWISE),
'Sub': Prim(ELEMWISE),
'Log': Prim(ELEMWISE),
'Exp': Prim(ELEMWISE),
'Rsqrt': Prim(ELEMWISE),
'Sqrt': Prim(ELEMWISE),
'RealDiv': Prim(ELEMWISE),
'Cast': Prim(ELEMWISE),
'Pow': Prim(ELEMWISE),
'Minimum': Prim(ELEMWISE),
'Maximum': Prim(ELEMWISE),
'Reciprocal': Prim(ELEMWISE),
'Equal': Prim(ELEMWISE),
'Greater': Prim(ELEMWISE),
'GreaterEqual': Prim(ELEMWISE),
'Less': Prim(ELEMWISE),
'LessEqual': Prim(ELEMWISE),
'Square': Prim(ELEMWISE),
'AddN': Prim(ELEMWISE),
'Select': Prim(ELEMWISE, 8),
'ReduceSum': Prim(REDUCE),
'ReduceMax': Prim(REDUCE),
'ReduceMin': Prim(REDUCE),
'make_tuple': Prim(CONTROL),
'ControlDepend': Prim(CONTROL),
'@ReduceInit': Prim(ELEMWISE),
}
default_primtive = Prim(UNKNOWN)
@classmethod
def get_prim(cls, op):
prim = cls.primtives.get(op.prim, None)
if prim is None:
print('[WARN] primtive is not registered: ' + op.prim)
prim = cls.default_primtive
return prim
@classmethod
def input_relation(cls, op, input_idx):
return cls.get_prim(op).relation_func(op, input_idx)
@classmethod
def iter_type(cls, op):
return cls.get_prim(op).iter_type
@classmethod
def is_reduce(cls, op):
return cls.get_prim(op).iter_type == cls.REDUCE
@classmethod
def calibrate_iter_size(cls, op, iter_size):
return cls.get_prim(op).calibrate * iter_size
@classmethod
def dtype_bytes(cls, dtype):
bits, unit = 1, 1
for i in range(len(dtype) - 1, 0, -1):
if dtype[i].isdecimal():
bits += int(dtype[i]) * unit
unit *= 10
else:
break
return bits // 8
@classmethod
def inplace_reuse(cls, op, input_idx, start_axis=0):
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
return False
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
for i in range(start_axis, len(elem_relation)):
if elem_relation[i] != cls.ELEMWISE:
return False
return True
class Tensor:
"""Tensor"""
PARA_NONE = 0
PARA_INPUT = 1
PARA_OUTPUT = 2
class Buddy:
def __init__(self, leader):
self.members = [leader]
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
self.name = name
self.shape = shape
self.dtype = dtype
self.data_format = data_format
self.para_type = para_type
self.op = None
self.to_ops = []
self.buddy = None
def __str__(self):
return self.name + str(list(self.shape))
def __repr__(self):
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
def get_size(self):
"""Get size"""
size = PrimLib.dtype_bytes(self.dtype)
for i in self.shape:
size *= i
return size
def add_buddy(self, tensor):
"""Add buddy"""
if self.buddy is None:
self.buddy = self.Buddy(self)
self.buddy.members.append(tensor)
tensor.buddy = self.buddy
class Value:
"""Value"""
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
self.name = name
self.shape = [1]
self.dtype = dtype
self.value = value
self.data_format = data_format
def __str__(self):
return self.name + str(list(self.shape)) + str(self.value)
def __repr__(self):
return "%s.%s%s%s" % (self.name, self.dtype, str(list(self.shape)), str(self.value))
def get_size(self):
return 1
class Operator:
"""Operator"""
def __init__(self, primtive, inputs, output, attrs):
self.prim = primtive
self.inputs = inputs
self.output = output
self.attrs = attrs
for t in inputs:
t.to_ops.append(self)
if output.op is None:
output.op = self
self.all_inputs = [] # include Tensor inputs and Value inputs.
def __str__(self):
args = ', '.join([str(t) for t in self.all_inputs])
expr = "%s = %s.%s(%s)" % (
str(self.output), self.prim, self.output.dtype, args)
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
def __repr__(self):
return str(self)
class Graph:
"""Graph"""
def __init__(self, name, ops):
self.name = name
self.ops = ops # in topo order, can not use set
self.outputs = []
def set_processor(self, processor):
"""Set processor"""
self.processor = processor
def add(self, ops):
"""Add ops"""
if isinstance(ops, Operator):
self.ops.append(ops)
else:
self.ops.extend(ops)
def extract_subgraph(self, graph_name, tensor_names, difference=False):
"""Extract subgraph from this graph"""
graph = Graph(graph_name, [])
outputs = set(tensor_names)
if difference:
for op in self.ops:
if op.output.name not in outputs:
graph.add(op)
else:
for op in self.ops:
if op.output.name in outputs:
graph.add(op)
outputs.remove(op.output.name)
for name in outputs:
raise ValueError("invalid input tensor : " + name)
return graph
def deduce_parameters(self):
"""Deduce parameters"""
inputs, outputs = [], []
for op in self.ops:
for t in op.inputs:
if t not in inputs and t.op not in self.ops:
inputs.append(t)
if op.output not in outputs:
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output)
else:
for d in op.output.to_ops:
if d not in self.ops:
outputs.append(op.output)
break
if self.outputs:
outputs = self.outputs
return inputs, outputs
def __str__(self):
inputs, outputs = self.deduce_parameters()
para_str = ', '.join([repr(t) for t in inputs])
out_str = ', '.join([repr(t) for t in outputs])
lines = []
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
for op in self.ops:
lines.append(' ' + str(op))
lines.append('}')
return '\n'.join(lines)
def __repr__(self):
return str(self)
def dump(self):
"""Dump Graph to json"""
attr_name = {'reduce_axis': 'axis'}
inputs, outputs = self.deduce_parameters()
input_desc, output_desc, op_desc = [], [], []
for t in inputs:
input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
for t in outputs:
output_desc.append({'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format})
for op in self.ops:
attrs, in_desc = [], []
for a in op.attrs:
name = attr_name.get(a, a)
attrs.append(
{'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
for t in op.all_inputs:
if isinstance(t, Tensor):
in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
else:
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
'tensor_name': op.output.name, 'format': t.data_format}]
op_desc.append({'attr': attrs, 'impl_path': '',
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
'platform': 'AKG', 'process': self.processor}
return graph_desc
class GraphVisitor:
"""Graph visitor"""
def __init__(self, forward=True, once_mode=True):
self.forward = forward
self.once_mode = once_mode
if self.once_mode:
self.visited = set()
def visit_graph(self, graph):
"""Visit graph"""
inputs, outputs = graph.deduce_parameters()
if self.forward:
for tensor in inputs:
for op in tensor.to_ops:
self.visit(op)
else:
for tensor in outputs:
if not tensor.to_ops:
self.visit(tensor.op)
def visit(self, op):
"""Visit op"""
next_ops = op.output.to_ops if self.forward else [
t.op for t in op.inputs if t.op is not None]
if self.once_mode:
self.visited.add(op)
for n in next_ops:
if n not in self.visited:
self.visit(n)
else:
for n in next_ops:
self.visit(n)
class AlignShape(GraphVisitor):
"""Align shape"""
def __init__(self):
super().__init__(once_mode=False)
def visit(self, op):
prim = PrimLib.get_prim(op)
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
out_dim = len(op.output.shape)
align_dim = out_dim
for t in op.inputs:
if len(t.shape) > align_dim:
align_dim = len(t.shape)
if align_dim > out_dim:
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
super().visit(op)
class AddControlBuddy(GraphVisitor):
"""Add control buddy"""
def __init__(self):
super().__init__()
self.buddies = {} # {op : [ctrl_op]}
def visit(self, op):
if PrimLib.iter_type(op) == PrimLib.CONTROL:
assert len(op.output.to_ops) == 1
owner = op.output.to_ops[0]
if owner in self.buddies:
self.buddies[owner].append(op)
else:
self.buddies[owner] = [op]
if op in self.buddies:
ops = self.buddies.pop(op)
self.buddies[owner].extend(ops)
super().visit(op)
def visit_graph(self, graph):
super().visit_graph(graph)
for owner in self.buddies:
for op in self.buddies[owner]:
owner.add_buddy(op.output)

View File

@ -0,0 +1,292 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel model builder"""
import copy
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
class OpInfer:
"""Op infer"""
@staticmethod
def default_reduce_infer(inputs, attrs):
shape = copy.deepcopy(inputs[0].shape)
for i in attrs['reduce_axis']:
shape[i] = 1
return shape
default_infer_shape_func = [
None,
lambda inputs, attrs: max([t.shape for t in inputs]),
lambda inputs, attrs: max([t.shape for t in inputs]),
default_reduce_infer.__func__,
None,
lambda inputs, attrs: [1], # control op
]
@staticmethod
def default_infer_dtype_func(inputs, attrs):
"""Infer dtype"""
# pylint: disable=unused-argument
return inputs[0].dtype
@staticmethod
def default_infer_format_func(inputs, attrs):
"""Infer format"""
# pylint: disable=unused-argument
return inputs[0].data_format
infer_shape_func = {
# add special infer func here
}
infer_dtype_func = {
# add special infer func here
'Cast': lambda inputs, attrs: attrs['dst_type'],
}
infer_format_func = {
# add special infer func here
}
@classmethod
def infer(cls, prim_name, inputs, attrs):
prim = PrimLib.primtives[prim_name]
infer_shape = cls.infer_shape_func.get(
prim_name, cls.default_infer_shape_func[prim.iter_type])
infer_dtype = cls.infer_dtype_func.get(
prim_name, cls.default_infer_dtype_func)
infer_format = cls.infer_format_func.get(
prim_name, cls.default_infer_format_func)
return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
class GraphBuilder:
"""Graph builder"""
class GraphWrapper:
def __init__(self, name):
self.graph = Graph(name, [])
def set_output(self, *para):
for t in para:
t.para_type = Tensor.PARA_OUTPUT
self.graph.outputs.append(t)
def __init__(self):
self.graphs = []
self.current = None
self.name_id = 0
def _alloc_tensor_name(self):
tid = self.name_id
self.name_id += 1
return "t%d" % (tid)
def graph_scope(self, name):
"""The graph scope to be processed"""
class GraphScope:
def __init__(self, gb):
self.gb = gb
def __enter__(self):
return self.gb.current
def __exit__(self, ptype, value, trace):
self.gb.graphs.append(self.gb.current.graph)
self.gb.current = None
assert self.current is None
self.current = self.GraphWrapper(name)
return GraphScope(self)
def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
"""Create a new Tensor"""
if name in (None, ''):
name = self._alloc_tensor_name()
if not shape:
shape = [1]
return Tensor(name, shape, dtype, data_format, para_type=para_type)
def value(self, dtype, value, data_format, name=None):
"""Create a new Value"""
if name in (None, ''):
name = self._alloc_tensor_name()
return Value(name, dtype, value, data_format)
def op(self, prim, output, inputs, attrs=None):
"""Insert an operator into graph"""
if attrs is None:
attrs = {}
if isinstance(inputs, Tensor):
inputs = [inputs]
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
node = Operator(prim, tensor_inputs, output, attrs)
node.all_inputs = inputs
self.current.graph.add(node)
def emit(self, prim, inputs, name=None, attrs=None):
"""Emit a new operation"""
if attrs is None:
attrs = {}
if isinstance(inputs, Tensor):
inputs = [inputs]
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
output = self.tensor(out_shape, out_dtype, out_format, name)
self.op(prim, output, inputs, attrs)
return output
def get(self):
return self.graphs
class CompositeGraph:
"""Composite Graph"""
def __init__(self):
self.graph = None
self.desc = None
self.tensors = {} # name : Tensor
def refine(self):
"""Refine Graph"""
AlignShape().visit_graph(self.graph)
AddControlBuddy().visit_graph(self.graph)
def load(self, desc):
"""Load Graph from json"""
def _attr_of(op, inputs, output):
attr = {}
if op['name'] not in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
return attr
for a in op['attr']:
if a['name'] == 'axis':
red_axis, dim_size = [], len(inputs[0].shape)
if not a['value']:
assert len(output.shape) == len(inputs[0].shape)
for i in range(len(output.shape)):
if output.shape[i] == 1 and inputs[0].shape[i] > 1:
red_axis.append(i)
else:
for i in a['value']:
red_axis.append(i if i >= 0 else dim_size + i)
attr['reduce_axis'] = red_axis
break
return attr
builder = GraphBuilder()
with builder.graph_scope(desc['op']):
for in_desc in desc['input_desc']:
name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
for out_desc in desc['output_desc']:
name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
'shape'], out_desc['data_type'], out_desc['format']
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
cur_fusion = None
for op in desc['op_desc']:
inputs = [self.tensors[d[0]['tensor_name']]
for d in op['input_desc'] if 'value' not in d[0]]
out_desc = op['output_desc']
name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
if op['name'] == 'InplaceAssign':
inputs[0].add_buddy(inputs[1])
inputs[1].para_type = Tensor.PARA_OUTPUT
output = inputs[2]
self.tensors[name] = output
else:
output = self.tensors.get(name, None)
if not output:
output = builder.tensor(
shape, dtype, data_format, name=name)
self.tensors[name] = output
builder.op(op['name'], output, inputs,
attrs=_attr_of(op, inputs, output))
if 'fusion' in op:
if cur_fusion is None:
cur_fusion = output
else:
cur_fusion.add_buddy(output)
if op['fusion'].endswith('_end'):
cur_fusion = None
self.graph = builder.get()[0]
self.desc = desc
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
inputs, outputs = subgraph.deduce_parameters()
graph_ops = set(subgraph.ops)
inplace_assign = {} # y_name, output_name
inplace_assign_z = None
for op in self.desc['op_desc']:
if op['name'] == 'InplaceAssign':
inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
if inplace_assign:
for t in outputs:
if t.name not in inplace_assign:
inplace_assign_z = t
for key in self.desc:
if key == 'input_desc':
desc[key] = [
[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc':
out_desc = []
for t in outputs:
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
out_desc.append(
{'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
else:
out_desc.append(
{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
desc[key] = out_desc
elif key == 'op_desc':
op_desc = []
for d in self.desc[key]:
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
op_desc.append(inplace_desc)
else:
op = self.tensors[d['output_desc'][0]['tensor_name']].op
if op in graph_ops:
op_desc.append(d)
desc[key] = op_desc
elif key == 'op':
desc[key] = subgraph.name
else:
desc[key] = self.desc[key]
return desc
def load_composite(desc):
"""Load composite kernel"""
composite = CompositeGraph()
composite.load(desc)
composite.refine()
return composite

View File

@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GraphKernel splitter"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
from . import model
def split_with_json(json_str: str):
"""Call costmodel to split GraphKernel"""
try:
graph_desc = json.loads(json_str)
comp = model.load_composite(graph_desc)
graph_split = model.split(comp.graph)
is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split))
result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
return json.dumps(result)
except jd.JSONDecodeError:
logger.error(traceback.format_exc())
return None

View File

@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
export PYTHONPATH

View File

@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""graph kernel split"""
import json
import getopt
import sys
import model
def print_usage():
print('Usage: graph_kernel_split.py [OPTION] <JSON_FILE>')
print('Options:')
print(' -s <config/auto>\tsplit graph with config')
print(' -e \t\testimate graph')
print(' -i \t\tnaive estimate')
print(' -o <prefix>\toutput split graphs')
print(' -v \t\tverbose mode')
print(' -h \t\tprint this help')
print('Report bugs to xiong.gao@huawei.com')
class Option:
"""Options"""
def __init__(self):
self.split = None
self.estimate = False
self.estimate_naive = False
self.output = None
self.verbose = False
self.help = False
def parse(self, options):
"""parse options"""
for name, val in options:
if name == '-h':
self.help = True
elif name == '-v':
self.verbose = True
elif name == '-o':
self.output = val
elif name == '-e':
self.estimate = True
elif name == '-s':
self.split = val
elif name == '-i':
self.estimate_naive = True
opt = Option()
def estimate(graph_in, parts_in, naive):
"""estimate graphs costs"""
def _print_cost(name, c):
print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
(name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
main_cost, _ = model.estimate(graph_in, naive)
split_cost, sub_costs = model.estimate(parts_in, naive) if parts_in else (None, None)
_print_cost("MainGraph:", main_cost)
if parts_in:
_print_cost("Subgraphs:", split_cost)
if opt.verbose:
for i, sub_cost in enumerate(sub_costs):
_print_cost(" |_%d:\t" % (i), sub_cost)
def split_graph(graph_in, config):
"""split graph"""
if config == 'auto':
return model.split(graph_in)
subgraphs = []
all_tensors = []
subgraph_idx = 0
config_parts = config.split('|')
for part in config_parts:
tensor_names = part.split(',')
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
g = graph_in.extract_subgraph(graph_name, tensor_names)
assert len(g.ops) == len(tensor_names)
subgraphs.append(g)
all_tensors += tensor_names
subgraph_idx += 1
if len(all_tensors) < len(graph_in.ops):
graph_name = "%s_%d" % (graph_in.name, subgraph_idx)
g = graph_in.extract_subgraph(graph_name, all_tensors, True)
subgraphs.append(g)
return subgraphs
def main():
opts, args = getopt.getopt(sys.argv[1:], 'heivo:s:')
opt.parse(opts)
if len(args) != 1 or opt.help:
print_usage()
sys.exit(0)
in_file = args[0]
with open(in_file, 'r') as f:
desc = json.loads(f.read())
comp = model.load_composite(desc)
graph = comp.graph
parts = []
# 1. split sub-graphs
if opt.split is not None:
parts = split_graph(graph, opt.split)
if opt.verbose:
print('----------- main graph --------------')
print(graph)
for i, _ in enumerate(parts):
print('---------------- sub graph %d ---------------' % (i))
print(parts[i])
# 2. estimate cost
if opt.estimate:
print('------------- cost --------------')
estimate(graph, parts, False)
if opt.estimate_naive:
print('------------- naive cost --------------')
estimate(graph, parts, True)
# 3. output parts
if opt.output is not None:
for graph_part in parts:
desc = comp.dump(graph_part)
s_desc = json.dumps(desc)
fname = "%s_%s.json" % (opt.output, graph_part.name)
with open(fname, 'w', encoding='utf-8') as of:
of.write(s_desc)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""test split"""
import model
def graph_1():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", c, 'd')
gb.emit("TensorAdd", [b, d], "e")
return gb.get()[0]
def graph_2():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
gb.emit("Sqrt", d, 'e')
return gb.get()[0]
def test_split_by_pattern():
def _test(graph):
print("***************** main graph ***************")
print(graph)
subgraphs = model.split(graph)
for i, g in enumerate(subgraphs):
print('------------- subgraph {} --------------'.format(i))
print(g)
_test(graph_2())
if __name__ == '__main__':
test_split_by_pattern()

View File

@ -71,7 +71,8 @@ if(ENABLE_GPU)
"runtime/device/gpu/*.cu" "runtime/device/gpu/*.cu"
"backend/kernel_compiler/gpu/*.cu" "backend/kernel_compiler/gpu/*.cu"
"backend/kernel_compiler/akg/gpu/*.cc" "backend/kernel_compiler/akg/gpu/*.cc"
"backend/kernel_compiler/akg/akg_kernel_build.cc" "backend/kernel_compiler/akg/akg_kernel_json_generator.cc"
"backend/kernel_compiler/akg/akg_kernel_json_decoder.cc"
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc" "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
) )

View File

@ -10,7 +10,8 @@ if (ENABLE_D)
"kernel_query.cc" "kernel_query.cc"
"kernel_fusion.cc" "kernel_fusion.cc"
"akg/ascend/*.cc" "akg/ascend/*.cc"
"akg/akg_kernel_build.cc" "akg/akg_kernel_json_generator.cc"
"akg/akg_kernel_json_decoder.cc"
"akg/akg_kernel_attrs_process.cc" "akg/akg_kernel_attrs_process.cc"
"akg/akg_kernel_metadata.cc" "akg/akg_kernel_metadata.cc"
"tbe/*.cc" "tbe/*.cc"
@ -49,7 +50,8 @@ if (ENABLE_GPU)
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"gpu/*.cu" "gpu/*.cu"
"akg/gpu/*.cc" "akg/gpu/*.cc"
"akg/akg_kernel_build.cc" "akg/akg_kernel_json_generator.cc"
"akg/akg_kernel_json_decoder.cc"
"akg/akg_kernel_attrs_process.cc" "akg/akg_kernel_attrs_process.cc"
) )

View File

@ -24,7 +24,6 @@
#include <climits> #include <climits>
#include "runtime/device/kernel_runtime.h" #include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" #include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include "proto/tensor.pb.h" #include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h" #include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h" #include "proto/attr.pb.h"
@ -33,6 +32,7 @@
#include "backend/kernel_compiler/aicpu/aicpu_util.h" #include "backend/kernel_compiler/aicpu/aicpu_util.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {

View File

@ -15,13 +15,20 @@
*/ */
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <algorithm> #include <algorithm>
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace {
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) { void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
// The x and output are akg op input and output param. // The x and output are akg op input and output param.
@ -169,5 +176,29 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node); AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node);
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
} }
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
{kFive2FourOpName, SetAkgAttrsForFive2Four},
{kCastOpName, SetAkgAttrsForCast},
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
{kConvBN1OpName, SetAkgAttrsForConvBN1},
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
};
} // namespace
void SetAkgKernelAttrs(const AnfNodePtr &anf_node) {
auto it = kAkgKernelAttrsProcessMap.find(AnfAlgo::GetCNodeName(anf_node));
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -16,43 +16,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include "ir/anf.h" #include "ir/anf.h"
#include "utils/utils.h"
#include "base/core_ops.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node);
void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node);
void SetAkgAttrsForCast(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node);
void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node);
void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node);
void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node);
void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node);
void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node);
const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = { void SetAkgKernelAttrs(const AnfNodePtr &anf_node);
{kFour2FiveOpName, SetAkgAttrsForFour2Five},
{kFive2FourOpName, SetAkgAttrsForFive2Four},
{"Cast", SetAkgAttrsForCast},
{kBNGrad1OpName, SetAkgAttrsForBNGrad1},
{kBNGrad2OpName, SetAkgAttrsForBNGrad2},
{kBNGrad3OpName, SetAkgAttrsForBNGrad3},
{kFusedBN1OpName, SetAkgAttrsForFusedBN1},
{kFusedBN2OpName, SetAkgAttrsForFusedBN2},
{kFusedBN3OpName, SetAkgAttrsForFusedBN3},
{kConvBN1OpName, SetAkgAttrsForConvBN1},
{kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
{kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
};
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_ATTRS_PROCESS_H

View File

@ -1,573 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include <unistd.h>
#include <dirent.h>
#include <memory>
#include <map>
#include <utility>
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <unordered_set>
#include "utils/convert_utils.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore {
namespace kernel {
// json key
constexpr auto kOpDesc = "op_desc";
constexpr auto kInputDesc = "input_desc";
constexpr auto kShape = "shape";
constexpr auto kDataType = "data_type";
constexpr auto kOutputDesc = "output_desc";
constexpr auto kName = "name";
constexpr auto kTensorName = "tensor_name";
constexpr auto kValue = "value";
constexpr auto KDynInputSizes = "dyn_input_sizes";
constexpr auto KInputNames = "input_names";
constexpr auto KInput = "input";
constexpr auto KDtype = "dtype";
namespace {
template <typename T>
std::string Vector2Str(const std::vector<T> &inputs) {
if (!inputs.empty()) {
std::ostringstream oss;
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
oss << inputs.back();
return oss.str();
}
return "";
}
} // namespace
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position) {
if (node_json.count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
return "";
}
auto const &tag_desc = node_json[tag];
nlohmann::json first_index;
if (tag == kOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
return "";
} else {
first_index = tag_desc[position.first];
}
if (!first_index.is_array() || first_index.size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
return "";
}
auto const &second_index = first_index[position.second];
if (second_index.count(kTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "].";
return "";
}
return second_index[kTensorName];
}
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(node_json);
if (node_json->count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
return;
}
nlohmann::json *tag_desc = &((*node_json)[tag]);
nlohmann::json *first_index;
if (tag == kOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
return;
} else {
first_index = &((*tag_desc)[position.first]);
}
if (!first_index->is_array() || first_index->size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
return;
}
nlohmann::json *second_index = &((*first_index)[position.second]);
if (second_index->count(kTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "].";
return;
}
(*second_index)[kTensorName] = new_name;
return;
}
int AkgKernelBuild::op_cnt_ = 0;
std::mutex AkgKernelBuild::op_cnt_mtx_;
std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string device;
switch (AnfAlgo::GetProcessor(anf_node)) {
case Processor::AICORE:
device = kProcessorAiCore;
break;
case Processor::AICPU:
device = kProcessorAiCpu;
break;
case Processor::CUDA:
device = kProcessorCuda;
break;
default:
MS_LOG(ERROR) << "Unknown processor type.";
break;
}
return device;
}
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size) {
if (input_size == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "input size or output size is nullptr";
return false;
}
input_size->clear();
output_size->clear();
for (size_t i = 0; i < node_json[kInputDesc].size(); i++) {
for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) {
std::string dtype = node_json[kInputDesc][i][m][kDataType];
size_t nbyte = GetDtypeNbyte(dtype);
size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(),
node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies<size_t>());
input_size->push_back(size_i);
}
}
for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) {
std::string dtype = node_json[kOutputDesc][i][kDataType];
size_t nbyte = GetDtypeNbyte(dtype);
size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(),
nbyte, std::multiplies<size_t>());
output_size->push_back(size_i);
}
return true;
}
int AkgKernelBuild::GetOpCntInc() {
op_cnt_mtx_.lock();
int cnt = op_cnt_++;
op_cnt_mtx_.unlock();
return cnt;
}
bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(inputs_json);
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
if (op_info == nullptr) {
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr";
return false;
}
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
if (inputs_ptr.empty()) {
MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info";
return true;
}
auto op_info_input_num = inputs_ptr.size();
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
size_t real_input_index = 0;
std::vector<nlohmann::json> input_list;
for (size_t i = 0; i < op_info_input_num; i++) {
size_t input_tensor_num;
std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
std::string op_input_name;
if (input_ptr == nullptr) {
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] regist input[" << i << "] is nullptr";
return false;
}
op_input_name = input_ptr->name();
if (dyn_input_sizes.empty()) {
input_tensor_num = 1;
} else {
input_tensor_num = IntToSize(dyn_input_sizes[i]);
}
input_list.clear();
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
// dtype : float16
auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
std::string dtype = TypeId2String(type_id);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. ";
return false;
}
nlohmann::json input_desc_json;
input_desc_json[kDataType] = dtype;
input_desc_json[kName] = op_input_name;
input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
<< "], value: " << input_desc_json[kValue];
input_shape.clear();
}
if (input_shape.empty()) {
input_shape.push_back(1);
}
input_desc_json[kShape] = input_shape;
input_list.emplace_back(input_desc_json);
real_input_index++;
}
inputs_json->emplace_back(input_list);
}
return true;
}
bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(outputs_json);
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
auto outputs = op_info_ptr->outputs_ptr();
for (size_t i = 0; i < output_tensor_num; i++) {
nlohmann::json output_json;
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
std::string dtype = TypeId2String(type_id);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. ";
return false;
}
std::string output_name = outputs[i]->name();
output_json[kDataType] = dtype;
output_json[kName] = output_name;
output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
outputs_json->push_back(output_json);
}
return true;
}
void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_attr);
MS_EXCEPTION_IF_NULL(attr_json);
std::string type = op_attr->type();
if (type == "int") {
(*attr_json)[kValue] = GetValue<int>(attr_value);
} else if (type == "str") {
(*attr_json)[kValue] = GetValue<std::string>(attr_value);
} else if (type == "bool") {
(*attr_json)[kValue] = GetValue<bool>(attr_value);
} else if (type == "float") {
(*attr_json)[kValue] = GetValue<float>(attr_value);
} else if (type == "listInt") {
(*attr_json)[kValue] = GetValue<std::vector<int>>(attr_value);
} else if (type == "listStr") {
std::vector<std::string> data_format;
if (op_attr->name() == kArgDataformat) {
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i);
data_format.push_back(input_format);
}
} else {
data_format = GetValue<std::vector<std::string>>(attr_value);
}
(*attr_json)[kValue] = data_format;
} else {
MS_LOG(WARNING) << "attr type:" << type;
}
}
bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(attrs_json);
MS_EXCEPTION_IF_NULL(op_info);
std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
if (attrs.empty()) {
MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty";
return true;
}
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
if (inputs.empty()) {
MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty";
return false;
}
// create input name list for atch "x_shape" in att with "x" in primitive.
std::map<size_t, std::string> op_info_shape_name;
for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
std::string input_name = inputs[op_info_input_i]->name();
std::string x_shape_name = input_name + "_shape";
(void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name));
}
for (const auto &op_attr : attrs) {
nlohmann::json attr_json;
ValuePtr attr_value = primitive->GetAttr(op_attr->name());
if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
if (op_attr->param_type() == "required") {
// match "x_shape" in att with "x" in primitive.
std::string attr_name = op_attr->name();
auto find_item = std::find_if(
op_info_shape_name.begin(), op_info_shape_name.end(),
[attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
if (find_item != op_info_shape_name.end()) {
if (!dyn_input_sizes.empty()) {
if (find_item->first >= dyn_input_sizes.size() - 1) {
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
return false;
}
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
attr_json[kName] = op_attr->name();
attrs_json->push_back(attr_json);
tensor_idx++;
}
} else {
attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
attr_json[kName] = op_attr->name();
attrs_json->push_back(attr_json);
}
} else {
MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name();
return false;
}
}
continue;
}
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
attr_json[kName] = op_attr->name();
attrs_json->push_back(attr_json);
}
return true;
}
bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(node_json);
int op_cnt = GetOpCntInc();
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
MS_EXCEPTION_IF_NULL(op_info_ptr);
// get basic params from currentNodeOpDesc
(*node_json)[kName] = op_name;
(*node_json)["impl_path"] = op_info_ptr->impl_path();
(*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node);
(*node_json)["composite"] = false;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
ValuePtr input_names_v = primitive->GetAttr(KInputNames);
if (input_names_v == nullptr) {
MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "].";
return false;
}
std::vector<std::string> prim_input_names = GetValue<const std::vector<std::string>>(input_names_v);
std::string inputs_name;
for (const auto &prim_input_name : prim_input_names) {
(void)inputs_name.append("_input_").append(prim_input_name).append("_");
}
// input desc
nlohmann::json inputs_json;
if (!CreateInputDescJson(anf_node, &inputs_json)) {
MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "].";
return false;
}
(*node_json)[kInputDesc] = inputs_json;
MS_LOG(INFO) << "Akg create input desc json success.";
std::string inputs_shape = "inputs_shape_";
for (auto &i : inputs_json) {
for (auto &m : i) {
std::string data_type = m[kDataType];
(void)inputs_shape.append("_").append(data_type).append("_");
for (auto &j : m[kShape]) {
size_t n = j;
(void)inputs_shape.append(std::to_string(n)).append("_");
}
}
}
// output desc
nlohmann::json outputs_json;
if (!CreateOutputDescJson(anf_node, &outputs_json)) {
MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "].";
return false;
}
(*node_json)[kOutputDesc] = outputs_json;
MS_LOG(INFO) << "Akg create output desc json success.";
std::string outputs_shape = "outputs_shape_";
for (auto &i : outputs_json) {
std::string data_type = i[kDataType];
(void)outputs_shape.append("_").append(data_type).append("_");
for (auto &j : i[kShape]) {
size_t m = j;
(void)outputs_shape.append(std::to_string(m)).append("_");
}
}
// attribute desc
nlohmann::json attrs_json;
if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) {
MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "].";
return false;
}
(*node_json)["attr"] = attrs_json;
std::string json_str = node_json->dump();
size_t hash_id = std::hash<std::string>()(json_str);
json_name_ = op_name + "_";
(void)json_name_.append(std::to_string(hash_id));
MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_;
json_info_ = json_str;
(*node_json)["id"] = op_cnt;
(*node_json)["op"] = json_name_;
MS_LOG(INFO) << "Akg create node desc json success.";
return true;
}
KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
auto processor = AkgKernelBuild::GetProcessor(anf_node);
auto cached_kernel_pack = SearchCache(json_name_, processor);
if (cached_kernel_pack != nullptr) {
MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return cached_kernel_pack;
}
(void)alarm(AUTODIFF_COMPILE_OVERTIME);
auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(node_json);
(void)alarm(0);
if (!res) {
MS_LOG(ERROR) << "Akg compile failed, json: " << node_json;
return nullptr;
}
auto new_kernel_pack = InsertCache(json_name_, processor);
kernel::SaveJsonInfo(json_name_, json_info_);
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return nullptr;
}
return new_kernel_pack;
}
KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto it = kAkgKernelAttrsProcessMap.find(op_name);
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
}
std::string json_str = node_json.dump();
auto kernel_pack = OpBuild(json_str, anf_node);
if (kernel_pack == nullptr) {
MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str;
return nullptr;
}
if (!GetIOSize(node_json, input_size, output_size)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return nullptr;
}
MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node)
<< "]";
return kernel_pack;
}
size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
}
auto input_node = cnode->input(input_idx + 1);
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
size_t index = input_tensor_idx_.size();
input_tensor_idx_[input_node] = index;
}
return input_tensor_idx_[input_node];
}
size_t AkgKernelBuild::GetOutputTensorIdxInc() {
size_t idx = output_tensor_idx_++;
return idx;
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,76 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_
#include <unordered_map>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include "backend/kernel_compiler/kernel.h"
#include "ir/dtype.h"
#include "ir/primitive.h"
#include <nlohmann/json.hpp>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore {
namespace kernel {
class AkgKernelBuild {
public:
AkgKernelBuild() {
input_tensor_idx_ = {};
output_tensor_idx_ = 0;
}
~AkgKernelBuild() = default;
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
static std::string GetProcessor(const AnfNodePtr &anf_node);
protected:
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
nlohmann::json *const node_json);
static int op_cnt_;
// lock for variable fusionOpCnt in singleton mode
static std::mutex op_cnt_mtx_;
std::string json_name_;
std::string json_info_;
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
size_t output_tensor_idx_;
};
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json);
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKGKERNELBUILD_H_

View File

@ -0,0 +1,415 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include <string>
#include <memory>
#include <vector>
#include <sstream>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/meta_tensor.h"
#include "ir/manager.h"
#include "ir/dtype.h"
#include "frontend/operator/ops.h"
#include "utils/convert_utils.h"
#include "utils/convert_utils_py.h"
#include "utils/utils.h"
#include "ir/graph_utils.h"
#include "runtime/device/kernel_info.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace kernel {
namespace {
ValuePtr ParseValue(const nlohmann::json &attr_json, const std::string &type) {
if (type == "str") {
std::string value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "int") {
int value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "bool") {
bool value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "float") {
float value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listInt") {
std::vector<int> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else if (type == "listStr") {
std::vector<std::string> value = attr_json[kJsonKeyValue];
return MakeValue(value);
} else {
MS_LOG(ERROR) << "Unknown type of attr: " << type << ", json: \n" << attr_json;
return nullptr;
}
}
bool DecodeAttrs(const nlohmann::json &attrs_json, std::map<std::string, ValuePtr> *attrs) {
MS_EXCEPTION_IF_NULL(attrs);
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// decode attrs.
if (attrs_json.find(kJsonKeyAttr) == attrs_json.end() || attrs_json[kJsonKeyAttr].is_null()) {
// attrs maybe empty
return true;
}
std::vector<nlohmann::json> attr_descs = attrs_json[kJsonKeyAttr];
for (const auto &attr_desc : attr_descs) {
std::string name = attr_desc[kJsonKeyName];
std::string type = attr_desc[kJsonKeyDataType];
auto value = ParseValue(attr_desc, type);
if (value == nullptr) {
return false;
}
(*attrs)[name] = value;
}
return true;
}
// python utils.
constexpr auto kGetPythonOpFunc = "_get_python_op";
constexpr auto kParallelUtilsModule = "mindspore.parallel._utils";
// almost all ops are defined in this path.
constexpr auto kOperationsModule = "mindspore.ops.operations";
const std::map<std::string, std::vector<std::string>> op_attrs_map = {
{kReduceSumOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMaxOpName, std::vector<std::string>{kAttrKeepDims}},
{kReduceMinOpName, std::vector<std::string>{kAttrKeepDims}},
};
ValuePtr CreatOpInstance(const std::string &op_name, const std::vector<ValuePtr> &attrs) {
py::module mod = py::module::import(kOperationsModule);
if (!py::hasattr(mod, op_name.c_str())) {
MS_LOG(ERROR) << kOperationsModule << " don't have attr: " << op_name;
return nullptr;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const ValuePtr &attr) { return ValuePtrToPyData(attr); });
py::object obj = parse::python_adapter::CallPyFn(kParallelUtilsModule, kGetPythonOpFunc, op_name, kOperationsModule,
op_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Get python op " << op_name << " from " << kOperationsModule << " failed.";
return nullptr;
}
return op_instance;
}
PrimitivePtr GetPrimitive(const std::string &op_name, const std::map<std::string, ValuePtr> &attrs_val) {
PrimitivePtr primitive{nullptr};
if (op_attrs_map.count(op_name) == 0) {
// no attrs for op instance.
primitive = CreatOpInstance(op_name, std::vector<ValuePtr>{})->cast<PrimitivePtr>();
} else {
// make attrs for op instance.
std::vector<ValuePtr> op_attrs;
const auto &attr_names = op_attrs_map.at(op_name);
for (const auto &attr_name : attr_names) {
if (attrs_val.count(attr_name) == 0) {
MS_LOG(ERROR) << "Attr: " << attr_name << " for: " << op_name << " not found.";
return nullptr;
}
op_attrs.push_back(attrs_val.at(attr_name));
}
primitive = CreatOpInstance(op_name, op_attrs)->cast<PrimitivePtr>();
}
if (primitive != nullptr) {
for (const auto &attr : attrs_val) {
primitive->AddAttr(attr.first, attr.second);
}
}
return primitive;
}
} // namespace
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
ScalarPtr AkgKernelJsonDecoder::DecodeScalar(const nlohmann::json &scalar_json) {
auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
switch (type_id) {
case kNumberTypeFloat16:
case kNumberTypeFloat32:
return std::make_shared<FP32Imm>(scalar_json[kJsonKeyValue]);
case kNumberTypeInt32:
return std::make_shared<Int32Imm>(scalar_json[kJsonKeyValue]);
default:
MS_LOG(ERROR) << "Unknown type: " << scalar_json[kJsonKeyDataType];
break;
}
return nullptr;
}
ValueNodePtr AkgKernelJsonDecoder::DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode value node, " << value_json;
auto scalar = DecodeScalar(value_json);
auto tensor = ScalarToTensor(scalar);
auto value_node = std::make_shared<ValueNode>(tensor);
value_node->set_abstract(tensor->ToAbstract());
// create kernel_info fo new value node.
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// layout info.
builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
func_graph->AddValueNode(value_node);
MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
return value_node;
}
ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &parameter_json,
const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "start decode parameter, " << parameter_json;
ParameterPtr new_parameter = func_graph->add_parameter();
std::string name = parameter_json[kJsonKeyTensorName];
new_parameter->set_name(name);
auto kernel_info = std::make_shared<device::KernelInfo>();
new_parameter->set_kernel_info(kernel_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetOutputsFormat(std::vector<std::string>{parameter_json[kJsonKeyFormat]});
builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(parameter_json[kJsonKeyDataType])});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_parameter.get());
nodes_map_[name] = new_parameter;
return new_parameter;
}
CNodePtr AkgKernelJsonDecoder::DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph,
const std::string &processor) {
Processor p = kernel::GetProcessor(processor);
MS_LOG(DEBUG) << "start decode cnode, " << cnode_json;
// decode attrs.
std::map<std::string, ValuePtr> cnode_attrs;
if (!DecodeAttrs(cnode_json, &cnode_attrs)) {
MS_LOG(ERROR) << "Error decode attrs.";
return nullptr;
}
std::string op_name = cnode_json[kJsonKeyName];
// new primitive.
auto primitive = GetPrimitive(op_name, cnode_attrs);
if (primitive == nullptr) {
MS_LOG(ERROR) << "Create primitive failed.";
return nullptr;
}
// data layout info.
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
// collect inputs.
auto primitive_v = NewValueNode(primitive);
func_graph->AddValueNode(primitive_v);
std::vector<AnfNodePtr> inputs{primitive_v};
std::vector<nlohmann::json> input_descs = cnode_json[kJsonKeyInputDesc];
for (size_t i = 0; i < input_descs.size(); ++i) {
nlohmann::json input_desc = input_descs[i][0];
std::string name = input_desc[kJsonKeyTensorName];
if (input_desc.find(kJsonKeyValue) != input_desc.end()) {
inputs.push_back(DecodeValueNode(input_desc, func_graph));
} else if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Input: " << name << " of: " << op_name << " not found.";
return nullptr;
} else {
inputs.push_back(nodes_map_[name]);
}
input_formats.push_back(input_desc[kJsonKeyFormat]);
input_types.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
}
MS_LOG(DEBUG) << "decode inputs success.";
// new cnode.
auto cnode = func_graph->NewCNode(inputs);
func_graph->AddNode(cnode);
// decode outputs.
std::vector<nlohmann::json> output_descs = cnode_json[kJsonKeyOutputDesc];
AbstractBasePtr abstract(nullptr);
if (output_descs.empty()) {
MS_LOG(ERROR) << "No outputs found.";
return nullptr;
} else if (output_descs.size() == 1) {
// single output.
nlohmann::json output_desc = output_descs[0];
output_formats.push_back(output_desc[kJsonKeyFormat]);
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
nodes_map_[output_desc[kJsonKeyTensorName]] = cnode;
} else {
// multi outputs.
for (size_t j = 0; j < output_descs.size(); ++j) {
nlohmann::json output_desc = output_descs[j];
output_formats.push_back(output_desc[kJsonKeyFormat]);
output_types.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
auto get_item = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, NewValueNode(SizeToInt(j))});
func_graph->AddNode(get_item);
nodes_map_[output_desc[kJsonKeyTensorName]] = get_item;
}
}
MS_LOG(DEBUG) << "decode outputs success.";
// create kernel_info.
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = AnfAlgo::VisitKernel(inputs[index], 0);
if (AnfAlgo::IsFeatureMapOutput(node.first)) {
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
cnode->set_kernel_info(kernel_info);
// create kernel_build_info.
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
builder->SetProcessor(p);
builder->SetKernelType(KernelType::AKG_KERNEL);
builder->SetFusionType(kernel::FusionType::OPAQUE);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode.get());
return cnode;
}
FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const nlohmann::json &kernel_json) {
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// clear cache.
nodes_map_.clear();
// create a graph.
auto graph = std::make_shared<FuncGraph>();
// decode parameters.
std::vector<nlohmann::json> input_descs = kernel_json[kJsonKeyInputDesc];
if (input_descs.empty()) {
MS_LOG(ERROR) << "Error decode parameter, no inputs for graph.";
return nullptr;
}
for (size_t i = 0; i < input_descs.size(); ++i) {
std::vector<nlohmann::json> input_desc = input_descs[i];
auto parameter = DecodeParameter(input_desc[0], graph);
if (parameter == nullptr) {
MS_LOG(ERROR) << "Error decode parameter.";
return nullptr;
}
}
MS_LOG(DEBUG) << "decode parameters success.";
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode cnodes, no cnodes for graph.";
return nullptr;
}
for (const auto &op_desc : op_node_descs) {
auto op_node = DecodeCNode(op_desc, graph, kernel_json[kJsonKeyProcess]);
if (op_node == nullptr) {
MS_LOG(ERROR) << "Error decode cnode.";
return nullptr;
}
}
MS_LOG(DEBUG) << "decode cnodes success.";
// decode outputs of graph.
std::vector<nlohmann::json> output_descs = kernel_json[kJsonKeyOutputDesc];
if (output_descs.empty()) {
MS_LOG(ERROR) << "Error decode outputs, no outputs for graph.";
return nullptr;
}
std::vector<AnfNodePtr> outputs{NewValueNode(prim::kPrimMakeTuple)};
for (const auto &output_desc : output_descs) {
std::string name = output_desc[kJsonKeyTensorName];
if (nodes_map_.count(name) == 0) {
MS_LOG(ERROR) << "Output: " << name << " of graph not found.";
return nullptr;
}
outputs.push_back(nodes_map_[name]);
}
if (outputs.size() == 2) {
graph->set_output(outputs[1]);
} else {
auto output = graph->NewCNode(outputs);
graph->AddNode(output);
graph->set_output(output);
}
MS_LOG(DEBUG) << "decode success, " << kernel_json;
return graph;
}
FuncGraphPtr AkgKernelJsonDecoder::DecodeFusedNodes(const std::string &kernel_json_str) {
auto kernel_json = nlohmann::json::parse(kernel_json_str);
return DecodeFusedNodes(kernel_json);
}
bool AkgKernelJsonDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs) {
MS_EXCEPTION_IF_NULL(res_graphs);
MS_LOG(DEBUG) << "start decode, " << kernel_json;
// decode cnodes in graph.
std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
if (op_node_descs.empty()) {
MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
return false;
}
for (const auto &op_desc : op_node_descs) {
if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
return false;
}
std::string ptr_address = op_desc[kJsonKeyPtrAddress];
if (address_node_map.count(ptr_address) == 0) {
MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
return false;
}
res_graphs->push_back(address_node_map.at(ptr_address));
}
MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_
#include <string>
#include <vector>
#include <map>
#include <nlohmann/json.hpp>
#include "ir/scalar.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace kernel {
class AkgKernelJsonDecoder {
public:
AkgKernelJsonDecoder() { nodes_map_.clear(); }
~AkgKernelJsonDecoder() = default;
FuncGraphPtr DecodeFusedNodes(const nlohmann::json &kernel_json);
FuncGraphPtr DecodeFusedNodes(const std::string &kernel_json_str);
bool DecodeSplitNodes(const nlohmann::json &kernel_json, const std::map<std::string, AnfNodePtr> &address_node_map,
AnfNodePtrList *res_graphs);
private:
ScalarPtr DecodeScalar(const nlohmann::json &scalar_json);
ValueNodePtr DecodeValueNode(const nlohmann::json &value_json, const FuncGraphPtr &func_graph);
ParameterPtr DecodeParameter(const nlohmann::json &parameter_json, const FuncGraphPtr &func_graph);
CNodePtr DecodeCNode(const nlohmann::json &cnode_json, const FuncGraphPtr &func_graph, const std::string &processor);
std::map<std::string, AnfNodePtr> nodes_map_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_DECODER_H_

View File

@ -0,0 +1,630 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <algorithm>
#include <functional>
#include <map>
#include <sstream>
#include <tuple>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
namespace {
std::vector<int> GetDynInputSize(const AnfNodePtr &anf_node) {
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->HasAttr(kAttrDynInputSizes)) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
return dyn_input_sizes;
}
} // namespace
int AkgKernelJsonGenerator::op_cnt_ = 0;
std::mutex AkgKernelJsonGenerator::op_cnt_mtx_;
int AkgKernelJsonGenerator::GetOpCntInc() {
op_cnt_mtx_.lock();
int cnt = op_cnt_++;
op_cnt_mtx_.unlock();
return cnt;
}
inline TypeId AkgKernelJsonGenerator::GetInputDataType(const AnfNodePtr &anf_node, size_t real_index) {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index)
: AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
}
inline std::vector<size_t> AkgKernelJsonGenerator::GetInputShape(const AnfNodePtr &anf_node, size_t real_index) {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index)
: AnfAlgo::GetInputDeviceShape(anf_node, real_index);
}
inline std::string AkgKernelJsonGenerator::GetInputFormat(const AnfNodePtr &anf_node, size_t real_index) {
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetInputFormat(anf_node, real_index);
}
inline TypeId AkgKernelJsonGenerator::GetOutputDataType(const AnfNodePtr &anf_node, size_t index) {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferDataType(anf_node, index)
: AnfAlgo::GetOutputDeviceDataType(anf_node, index);
}
inline std::vector<size_t> AkgKernelJsonGenerator::GetOutputShape(const AnfNodePtr &anf_node, size_t index) {
return dump_option_.is_before_select_kernel ? AnfAlgo::GetOutputInferShape(anf_node, index)
: AnfAlgo::GetOutputDeviceShape(anf_node, index);
}
inline std::string AkgKernelJsonGenerator::GetOutputFormat(const AnfNodePtr &anf_node, size_t index) {
return dump_option_.is_before_select_kernel ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(anf_node, index);
}
bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const inputs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(inputs_json);
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
if (inputs_ptr.empty()) {
MS_LOG(DEBUG) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
return true;
}
// for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
auto dyn_input_sizes = GetDynInputSize(anf_node);
size_t real_input_index = 0;
std::vector<nlohmann::json> input_list;
for (size_t i = 0; i < inputs_ptr.size(); i++) {
std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
if (input_ptr == nullptr) {
MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist input[" << i << "] is nullptr";
return false;
}
auto op_input_name = input_ptr->name();
size_t input_tensor_num = dyn_input_sizes.empty() ? 1 : IntToSize(dyn_input_sizes[i]);
input_list.clear();
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
auto type_id = this->GetInputDataType(anf_node, real_input_index);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] input [" << real_input_index
<< "] data type is null. ";
return false;
}
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(anf_node, real_input_index);
input_desc_json[kJsonKeyName] = op_input_name;
input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
auto input_shape = this->GetInputShape(anf_node, real_input_index);
if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
<< "], value: " << input_desc_json[kJsonKeyValue];
input_shape.clear();
}
if (input_shape.empty()) {
input_shape.push_back(1);
}
input_desc_json[kJsonKeyShape] = input_shape;
input_list.emplace_back(input_desc_json);
real_input_index++;
}
inputs_json->emplace_back(input_list);
}
return true;
}
bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const outputs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(outputs_json);
size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
auto outputs = op_info->outputs_ptr();
for (size_t i = 0; i < output_tensor_num; i++) {
nlohmann::json output_json;
auto type_id = this->GetOutputDataType(anf_node, i);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] output [" << i << "] data type is null. ";
return false;
}
std::string output_name = outputs[i]->name();
output_json[kJsonKeyDataType] = dtype;
output_json[kJsonKeyFormat] = this->GetOutputFormat(anf_node, i);
output_json[kJsonKeyName] = output_name;
output_json[kJsonKeyTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
output_json[kJsonKeyShape] = this->GetOutputShape(anf_node, i);
outputs_json->push_back(output_json);
}
return true;
}
void AkgKernelJsonGenerator::GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json,
const ValuePtr &attr_value) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_attr);
MS_EXCEPTION_IF_NULL(attr_json);
std::string type = op_attr->type();
(*attr_json)[kJsonKeyDataType] = type;
if (type == "int") {
(*attr_json)[kJsonKeyValue] = GetValue<int>(attr_value);
} else if (type == "str") {
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
} else if (type == "bool") {
(*attr_json)[kJsonKeyValue] = GetValue<bool>(attr_value);
} else if (type == "float") {
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
} else if (type == "listInt") {
(*attr_json)[kJsonKeyValue] = GetValue<std::vector<int>>(attr_value);
} else if (type == "listStr") {
std::vector<std::string> data_format;
if (op_attr->name() == kArgDataformat) {
size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
auto input_format = this->GetInputFormat(anf_node, format_i);
data_format.push_back(input_format);
}
} else {
data_format = GetValue<std::vector<std::string>>(attr_value);
}
(*attr_json)[kJsonKeyValue] = data_format;
} else {
MS_LOG(WARNING) << "No valid json value for attr type: " << type;
}
}
bool AkgKernelJsonGenerator::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const attrs_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(attrs_json);
std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
if (attrs.empty()) {
MS_LOG(INFO) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info attrs is empty";
return true;
}
std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
if (inputs.empty()) {
MS_LOG(ERROR) << "Apply kernel [" << anf_node->fullname_with_scope() << "] op info inputs is empty";
return false;
}
// create input name list for "x_shape" in attr with "x" in primitive.
std::map<size_t, std::string> op_info_shape_name;
for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
std::string input_name = inputs[op_info_input_i]->name();
std::string x_shape_name = input_name + "_shape";
static_cast<void>(op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)));
}
for (const auto &op_attr : attrs) {
nlohmann::json attr_json;
ValuePtr attr_value = primitive->GetAttr(op_attr->name());
if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
if (op_attr->param_type() == "required") {
// match "x_shape" in att with "x" in primitive.
std::string attr_name = op_attr->name();
auto find_item = std::find_if(
op_info_shape_name.begin(), op_info_shape_name.end(),
[attr_name](const std::map<size_t, std::string>::value_type item) { return item.second == attr_name; });
if (find_item != op_info_shape_name.end()) {
if (!dyn_input_sizes.empty()) {
if (find_item->first >= dyn_input_sizes.size() - 1) {
MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first
<< " is out of range:" << dyn_input_sizes.size() - 1 << ".";
return false;
}
size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0));
for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
tensor_idx++;
}
} else {
attr_json[kJsonKeyValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
} else {
MS_LOG(ERROR) << "op [" << anf_node->fullname_with_scope() << "] should have attr :" << op_attr->name();
return false;
}
}
continue;
}
GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value);
attr_json[kJsonKeyName] = op_attr->name();
attrs_json->push_back(attr_json);
}
return true;
}
size_t AkgKernelJsonGenerator::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
}
auto input_node = cnode->input(input_idx + 1);
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
size_t index = input_tensor_idx_.size();
input_tensor_idx_[input_node] = index;
}
return input_tensor_idx_[input_node];
}
size_t AkgKernelJsonGenerator::GetOutputTensorIdxInc() {
size_t idx = output_tensor_idx_++;
return idx;
}
std::string AkgKernelJsonGenerator::GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position) {
if (node_json.count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
return "";
}
auto const &tag_desc = node_json[tag];
nlohmann::json first_index;
if (tag == kJsonKeyOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
return "";
} else {
first_index = tag_desc[position.first];
}
if (!first_index.is_array() || first_index.size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
return "";
}
auto const &second_index = first_index[position.second];
if (second_index.count(kJsonKeyTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kJsonKeyTensorName << "].";
return "";
}
return second_index[kJsonKeyTensorName];
}
void AkgKernelJsonGenerator::SetTensorName(const std::string &tag, const std::string &new_name,
const std::pair<size_t, size_t> &position, nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(node_json);
if (node_json->count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
return;
}
nlohmann::json *tag_desc = &((*node_json)[tag]);
nlohmann::json *first_index;
if (tag == kJsonKeyOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
return;
} else {
first_index = &((*tag_desc)[position.first]);
}
if (!first_index->is_array() || first_index->size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
return;
}
nlohmann::json *second_index = &((*first_index)[position.second]);
if (second_index->count(kJsonKeyTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kJsonKeyTensorName << "].";
return;
}
(*second_index)[kJsonKeyTensorName] = new_name;
return;
}
bool AkgKernelJsonGenerator::GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(node_json);
auto op_name = AnfAlgo::GetCNodeName(anf_node);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
MS_EXCEPTION_IF_NULL(op_info);
// get basic params from currentNodeOpDesc
(*node_json)[kJsonKeyName] = op_name;
(*node_json)[kJsonKeyImplPath] = op_info->impl_path();
if (dump_option_.save_ptr_address) {
std::ostringstream get_the_address;
get_the_address << anf_node.get();
auto address = get_the_address.str();
(*node_json)[kJsonKeyPtrAddress] = address;
address_node_map_[address] = anf_node;
}
// input desc
nlohmann::json inputs_json;
if (!CreateInputDescJson(anf_node, op_info, &inputs_json)) {
MS_LOG(ERROR) << "Create input desc json failed, op[" << anf_node->fullname_with_scope() << "].";
return false;
}
(*node_json)[kJsonKeyInputDesc] = inputs_json;
MS_LOG(DEBUG) << "Akg create input desc json success.";
// output desc
nlohmann::json outputs_json;
if (!CreateOutputDescJson(anf_node, op_info, &outputs_json)) {
MS_LOG(ERROR) << "Create output desc json failed, op[" << anf_node->fullname_with_scope() << "].";
return false;
}
(*node_json)[kJsonKeyOutputDesc] = outputs_json;
MS_LOG(DEBUG) << "Akg create output desc json success.";
// attribute desc
nlohmann::json attrs_json;
if (!CreateAttrDescJson(anf_node, op_info, &attrs_json)) {
MS_LOG(ERROR) << "Create attr desc json failed, op[" << anf_node->fullname_with_scope() << "].";
return false;
}
(*node_json)[kJsonKeyAttr] = attrs_json;
return true;
}
bool AkgKernelJsonGenerator::GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size) {
if (input_size == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "input size or output size is nullptr";
return false;
}
input_size->clear();
output_size->clear();
for (size_t i = 0; i < node_json[kJsonKeyInputDesc].size(); i++) {
for (size_t m = 0; m < node_json[kJsonKeyInputDesc][i].size(); m++) {
std::string dtype = node_json[kJsonKeyInputDesc][i][m][kJsonKeyDataType];
size_t nbyte = GetDtypeNbyte(dtype);
size_t size_i =
std::accumulate(node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].begin(),
node_json[kJsonKeyInputDesc][i][m][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
input_size->push_back(size_i);
}
}
for (size_t i = 0; i < node_json[kJsonKeyOutputDesc].size(); i++) {
std::string dtype = node_json[kJsonKeyOutputDesc][i][kJsonKeyDataType];
size_t nbyte = GetDtypeNbyte(dtype);
size_t size_i =
std::accumulate(node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].begin(),
node_json[kJsonKeyOutputDesc][i][kJsonKeyShape].end(), nbyte, std::multiplies<size_t>());
output_size->push_back(size_i);
}
return true;
}
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::json *const kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_json);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
MS_LOG(INFO) << "Akg start generate kernel json desc, full scope name is : " << anf_node->fullname_with_scope();
SetAkgKernelAttrs(anf_node);
if (!GenerateSingleKernelJson(anf_node, kernel_json)) {
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
return false;
}
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = op_name + "_";
(void)kernel_name_.append(std::to_string(hash_id));
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_node);
(*kernel_json)[kJsonKeyComposite] = false;
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
MS_LOG(INFO) << "Akg create kernel json desc success, full scope name is : " << anf_node->fullname_with_scope()
<< ", json info name is : " << kernel_name_;
return true;
}
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list,
nlohmann::json *const kernel_json) {
if (anf_nodes.empty() || input_list.empty()) {
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
<< "].";
return false;
}
MS_LOG(INFO) << "Fusion nodes: [" << output_list.size() << "], input_list: [" << anf_nodes.size()
<< "], output_list: [" << input_list.size() << "].";
std::map<AnfNodePtr, nlohmann::json> node_json_map;
for (auto const &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
return false;
}
SetAkgKernelAttrs(anf_node);
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, &node_json)) {
MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
return false;
}
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("fusion") != nullptr) {
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
}
node_json_map[anf_node] = node_json;
}
for (auto const &anf_node : anf_nodes) {
auto dyn_input_sizes = GetDynInputSize(anf_node);
bool is_dynamic_input = !dyn_input_sizes.empty();
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < input_num; ++i) {
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
for (size_t j = 0; j < input_tensor_num; ++j) {
auto tmp_input = GetKernelInput(anf_node, real_input_index);
std::string tensor_name = GetTensorName(node_json_map[anf_node], kJsonKeyInputDesc, std::make_pair(i, j));
if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
std::string new_tensor_name =
GetTensorName(node_json_map[tmp_input.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_input.second));
SetTensorName(kJsonKeyInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
} else {
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] is out input.";
}
real_input_index++;
}
}
}
std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
(*kernel_json)[kJsonKeyOpDesc] = node_json_desc;
nlohmann::json inputs_json;
auto input_index = GetInputIndex(anf_nodes, input_list);
for (size_t i = 0; i < input_index.size(); ++i) {
auto tmp_input = input_index[i];
auto type_id = this->GetInputDataType(tmp_input.first, tmp_input.second.first);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
nlohmann::json input_desc_json;
input_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map[tmp_input.first], kJsonKeyInputDesc, tmp_input.second);
input_desc_json[kJsonKeyDataType] = dtype;
input_desc_json[kJsonKeyFormat] = this->GetInputFormat(tmp_input.first, tmp_input.second.first);
input_desc_json[kJsonKeyShape] = this->GetInputShape(tmp_input.first, tmp_input.second.first);
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
}
(*kernel_json)[kJsonKeyInputDesc] = inputs_json;
nlohmann::json outputs_json;
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
std::map<size_t, std::vector<std::string>> sub_graphs;
std::map<size_t, size_t> dim_infos;
for (size_t i = 0; i < output_index.size(); ++i) {
auto tmp_output = output_index[i];
bool found = false;
nlohmann::json output_desc_json;
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
if (tmp_output.first == input_list[input_i]) {
output_desc_json = inputs_json[input_i][0];
found = true;
break;
}
}
if (!found) {
auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
output_desc_json[kJsonKeyTensorName] =
GetTensorName(node_json_map[tmp_output.first], kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
output_desc_json[kJsonKeyDataType] = dtype;
output_desc_json[kJsonKeyFormat] = this->GetOutputFormat(tmp_output.first, tmp_output.second);
auto output_shape = this->GetOutputShape(tmp_output.first, tmp_output.second);
if (output_shape.empty()) {
output_shape.push_back(1);
}
output_desc_json[kJsonKeyShape] = output_shape;
}
outputs_json.emplace_back(output_desc_json);
}
(*kernel_json)[kJsonKeyOutputDesc] = outputs_json;
auto processor = GetProcessorStr(anf_nodes[0]);
size_t hash_id = std::hash<std::string>()(kernel_json->dump());
kernel_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)kernel_name_.append(fg_attr).append("_");
}
(void)kernel_name_.append(std::to_string(hash_id));
(*kernel_json)[kJsonKeyId] = GetOpCntInc();
(*kernel_json)[kJsonKeyOp] = kernel_name_;
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = processor;
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node) {
kernel_json_ = nlohmann::json();
return CollectJson(anf_node, &kernel_json_);
}
bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
kernel_json_ = nlohmann::json();
return CollectFusedJson(anf_nodes, input_list, output_list, &kernel_json_);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,125 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_
#include <unordered_map>
#include <string>
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include <nlohmann/json.hpp>
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore {
namespace kernel {
// json key
constexpr auto kJsonKeyOpDesc = "op_desc";
constexpr auto kJsonKeyAttr = "attr";
constexpr auto kJsonKeyInputDesc = "input_desc";
constexpr auto kJsonKeyFormat = "format";
constexpr auto kJsonKeyInferDataType = "infer_data_type";
constexpr auto kJsonKeyInferShape = "infer_shape";
constexpr auto kJsonKeyShape = "shape";
constexpr auto kJsonKeyDataType = "data_type";
constexpr auto kJsonKeyOutputDesc = "output_desc";
constexpr auto kJsonKeyName = "name";
constexpr auto kJsonKeyTensorName = "tensor_name";
constexpr auto kJsonKeyValue = "value";
constexpr auto kJsonKeyImplPath = "impl_path";
constexpr auto kJsonKeyProcess = "process";
constexpr auto kJsonKeyComposite = "composite";
constexpr auto kJsonKeyId = "id";
constexpr auto kJsonKeyOp = "op";
constexpr auto kJsonKeyPtrAddress = "ptr_address";
constexpr auto kJsonKeyCompositeGraph = "composite_graph";
constexpr auto kJsonKeyPlatform = "platform";
constexpr auto kAttrInputNames = "input_names";
// dump option
struct DumpOption {
bool is_before_select_kernel = false;
bool save_ptr_address = false;
};
class AkgKernelJsonGenerator {
public:
AkgKernelJsonGenerator() { Clear(); }
explicit AkgKernelJsonGenerator(DumpOption dump_option) : dump_option_(dump_option) { Clear(); }
~AkgKernelJsonGenerator() = default;
bool CollectJson(const AnfNodePtr &anf_node, nlohmann::json *const kernel_json);
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list, nlohmann::json *const kernel_json);
bool CollectJson(const AnfNodePtr &anf_node);
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *const node_json);
std::string kernel_name() const { return kernel_name_; }
nlohmann::json kernel_json() const { return kernel_json_; }
std::string kernel_json_str() const { return kernel_json_.dump(); }
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
void Clear() {
input_tensor_idx_.clear();
address_node_map_.clear();
output_tensor_idx_ = 0;
}
void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
private:
bool CreateInputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const outputs_json);
void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::shared_ptr<OpInfo> &op_info,
nlohmann::json *const attrs_json);
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json);
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position);
TypeId GetInputDataType(const AnfNodePtr &anf_node, size_t real_index);
std::vector<size_t> GetInputShape(const AnfNodePtr &anf_node, size_t real_index);
std::string GetInputFormat(const AnfNodePtr &anf_node, size_t real_index);
TypeId GetOutputDataType(const AnfNodePtr &anf_node, size_t index);
std::vector<size_t> GetOutputShape(const AnfNodePtr &anf_node, size_t index);
std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
DumpOption dump_option_;
static int op_cnt_;
// lock for variable fusionOpCnt in singleton mode
static std::mutex op_cnt_mtx_;
std::string kernel_name_;
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
size_t output_tensor_idx_;
nlohmann::json kernel_json_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::map<std::string, AnfNodePtr> address_node_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_JSON_GENERATOR_H_

View File

@ -29,6 +29,7 @@
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h" #include "backend/session/kernel_build_client.h"
@ -38,287 +39,37 @@ namespace kernel {
constexpr int32_t PROCESS_NUM = 16; constexpr int32_t PROCESS_NUM = 16;
constexpr int32_t TIME_OUT = 300; constexpr int32_t TIME_OUT = 300;
constexpr auto kOpDesc = "op_desc"; bool AkgAscendKernelBuilder::AkgOpParallelBuild(
constexpr auto kShape = "shape"; const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args) {
constexpr auto kDataType = "data_type";
constexpr auto kInputDesc = "input_desc";
constexpr auto kOutputDesc = "output_desc";
constexpr auto kTensorName = "tensor_name";
namespace {
void UpdateTensorNameInJson(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
for (auto const &anf_node : anf_nodes) {
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
bool is_dynamic_input = !dyn_input_sizes.empty();
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < input_num; ++i) {
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
for (size_t j = 0; j < input_tensor_num; ++j) {
auto tmp_input = GetKernelInput(anf_node, real_input_index);
std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j));
if (node_json_map->find(tmp_input.first) != node_json_map->end()) {
std::string new_tensor_name =
GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second));
SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node]));
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
} else {
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] is out input.";
}
real_input_index++;
}
}
}
}
nlohmann::json GetInputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
nlohmann::json inputs_json;
auto input_index = GetInputIndex(anf_nodes, input_list);
for (size_t i = 0; i < input_index.size(); ++i) {
auto tmp_input = input_index[i];
auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first);
std::string dtype = TypeId2String(type_id);
nlohmann::json input_desc_json;
input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second);
input_desc_json[kDataType] = dtype;
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first);
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
}
return inputs_json;
}
nlohmann::json GetOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
nlohmann::json outputs_json;
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
for (size_t i = 0; i < output_index.size(); ++i) {
auto tmp_output = output_index[i];
bool found = false;
nlohmann::json output_desc_json;
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
if (tmp_output.first == input_list[input_i]) {
output_desc_json = inputs_json[input_i][0];
found = true;
break;
}
}
if (!found) {
auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second);
std::string dtype = TypeId2String(type_id);
output_desc_json[kTensorName] =
GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second));
output_desc_json[kDataType] = dtype;
auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second);
if (output_shape.empty()) {
output_shape.push_back(1);
}
output_desc_json[kShape] = output_shape;
}
outputs_json.emplace_back(output_desc_json);
}
return outputs_json;
}
std::pair<std::vector<std::string>, std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>>> PreProcessJsonForBuild(
const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
std::vector<std::string> jsons; std::vector<std::string> jsons;
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes; std::unordered_set<std::string> kernel_name_set;
std::unordered_set<std::string> json_name_set; std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> repeat_nodes;
for (const auto &[builder, anf_node] : build_args) { for (const auto &[json_generator, anf_node] : build_args) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
auto json_name = builder.json_name(); auto kernel_name = json_generator.kernel_name();
MS_LOG(DEBUG) << "Akg start compile op: " << json_name; MS_LOG(DEBUG) << "Akg start compile op: " << kernel_name;
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
if (cached_kernel_pack != nullptr) { if (cached_kernel_pack != nullptr) {
MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" MS_LOG(DEBUG) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "]."; << anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack); auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
continue; continue;
} }
if (json_name_set.count(json_name) != 0) { if (kernel_name_set.count(kernel_name) != 0) {
repeat_nodes.push_back({builder, anf_node}); repeat_nodes.push_back({json_generator, anf_node});
continue; continue;
} }
json_name_set.insert(json_name); kernel_name_set.insert(kernel_name);
auto node_json = builder.kernel_json(); auto kernel_json = json_generator.kernel_json_str();
kernel::SaveJsonInfo(json_name, node_json); kernel::SaveJsonInfo(kernel_name, kernel_json);
jsons.push_back(node_json); jsons.push_back(kernel_json);
} }
return std::make_pair(jsons, repeat_nodes);
}
bool PostProcessAfterCompile(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args,
const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &repeat_nodes) {
for (const auto &[builder, anf_node] : build_args) {
auto json_name = builder.json_name();
auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return false;
}
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
}
for (const auto &[builder, anf_node] : repeat_nodes) {
auto node_json = builder.kernel_json();
auto json_name = builder.json_name();
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (cached_kernel_pack == nullptr) {
return false;
}
MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
}
return true;
}
} // namespace
bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
auto it = kAkgKernelAttrsProcessMap.find(op_name);
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
}
kernel_json_ = node_json.dump();
if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
for (auto const &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
return false;
}
auto it = kAkgKernelAttrsProcessMap.find(op_name);
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed.";
return false;
}
// No need for composite op.
node_json.erase("id");
node_json.erase("op");
node_json.erase("composite");
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("fusion") != nullptr) {
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
}
(*node_json_map)[anf_node] = node_json;
}
return true;
}
bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
if (anf_nodes.empty() || input_list.empty()) {
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
<< "].";
return false;
}
MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list ["
<< input_list.size() << "].";
std::map<AnfNodePtr, nlohmann::json> node_json_map;
if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) {
return false;
}
UpdateTensorNameInJson(anf_nodes, &node_json_map);
nlohmann::json fused_node_json;
std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
fused_node_json[kOpDesc] = node_json_desc;
fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map);
fused_node_json[kOutputDesc] =
GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map);
size_t hash_id = std::hash<std::string>()(fused_node_json.dump());
json_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)json_name_.append(fg_attr).append("_");
}
(void)json_name_.append(std::to_string(hash_id));
fused_node_json["composite_graph"] = fg->ToString();
fused_node_json["op"] = json_name_;
fused_node_json["platform"] = "AKG";
fused_node_json["process"] = "aicore";
fused_node_json["composite"] = true;
kernel_json_ = fused_node_json.dump();
if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args);
if (jsons.empty()) { if (jsons.empty()) {
return true; return true;
} }
@ -337,18 +88,43 @@ bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfN
return false; return false;
} }
if (!PostProcessAfterCompile(build_args, repeat_nodes)) { // All unique done here, cache them and set kernel.
return false; for (const auto &[json_generator, anf_node] : build_args) {
auto kernel_name = json_generator.kernel_name();
auto new_kernel_pack = tbe::TbeUtils::InsertCache(kernel_name, GetProcessorStr(anf_node));
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return false;
}
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
MS_LOG(DEBUG) << "Akg compile " << kernel_name << " kernel and insert cache successfully!";
}
// Handle repeated nodes.
for (const auto &[json_generator, anf_node] : repeat_nodes) {
auto kernel_name = json_generator.kernel_name();
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(kernel_name, GetProcessorStr(anf_node));
if (cached_kernel_pack == nullptr) return false;
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
} }
return true; return true;
} }
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) { bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node; std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> json_and_node;
for (const auto &anf_node : anf_nodes) { for (const auto &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
AkgAscendKernelBuilder akg_cce_kernel_builder; AkgKernelJsonGenerator akg_kernel_json_generator;
KernelPackPtr kernel_pack = nullptr; KernelPackPtr kernel_pack = nullptr;
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
@ -363,18 +139,17 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<AnfNodePtr> node_list; std::vector<AnfNodePtr> node_list;
std::vector<AnfNodePtr> input_list; std::vector<AnfNodePtr> input_list;
std::vector<AnfNodePtr> output_list; std::vector<AnfNodePtr> output_list;
std::string op_name = AnfAlgo::GetCNodeName(anf_node); MS_LOG(INFO) << "Akg start compile composite op[" << anf_node->fullname_with_scope() << "]";
MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]";
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { if (!akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list)) {
MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << anf_node->fullname_with_scope() << "].";
} }
} else { } else {
if (!akg_cce_kernel_builder.CollectJson(anf_node)) { if (!akg_kernel_json_generator.CollectJson(anf_node)) {
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; MS_EXCEPTION(UnknownError) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
} }
} }
json_and_node.push_back({akg_cce_kernel_builder, anf_node}); json_and_node.push_back({akg_kernel_json_generator, anf_node});
} }
if (json_and_node.empty()) { if (json_and_node.empty()) {
@ -382,7 +157,8 @@ bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
return true; return true;
} }
return AkgOpParallelBuild(json_and_node); AkgAscendKernelBuilder akg_ascend_kernel_builder;
return akg_ascend_kernel_builder.AkgOpParallelBuild(json_and_node);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -18,35 +18,21 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
#include <string> #include <string>
#include <memory> #include <utility>
#include <vector> #include <vector>
#include <map> #include <map>
#include "ir/anf.h" #include "ir/anf.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class AkgAscendKernelBuilder : public AkgKernelBuild { class AkgAscendKernelBuilder {
public: public:
AkgAscendKernelBuilder() = default; AkgAscendKernelBuilder() = default;
~AkgAscendKernelBuilder() = default; ~AkgAscendKernelBuilder() = default;
bool CollectJson(const AnfNodePtr &anf_node); bool AkgOpParallelBuild(const std::vector<std::pair<AkgKernelJsonGenerator, AnfNodePtr>> &build_args);
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
std::string json_name() const { return json_name_; }
std::string kernel_json() const { return kernel_json_; }
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
private:
bool GenJsonAndPreprocess4Fused(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map);
std::string kernel_json_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
}; };
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);

View File

@ -15,29 +15,116 @@
*/ */
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" #include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
#include <Python.h>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string>
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" #include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) { constexpr int32_t ARGS_SIZE = 1;
MS_EXCEPTION_IF_NULL(anf_node); constexpr auto kCompileWithJsonFunc = "compilewithjson";
AkgKernelBuild akg_kernel_build;
std::vector<size_t> input_size_list; KernelPackPtr AkgGpuKernelBuilder::OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
std::vector<size_t> output_size_list; MS_EXCEPTION_IF_NULL(anf_node);
KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list); auto processor = GetProcessorStr(anf_node);
MS_EXCEPTION_IF_NULL(kernel_pack); auto kernel_name = json_generator.kernel_name();
auto cached_kernel_pack = SearchCache(kernel_name, processor);
if (cached_kernel_pack != nullptr) {
MS_LOG(INFO) << "Use cached kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return cached_kernel_pack;
}
(void)alarm(AUTODIFF_COMPILE_OVERTIME);
auto kernel_json = json_generator.kernel_json_str();
auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(kernel_json);
(void)alarm(0);
if (!res) {
MS_LOG(ERROR) << "Akg compile failed, json: " << kernel_json;
return nullptr;
}
auto new_kernel_pack = InsertCache(kernel_name, processor);
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, kernel_name[" << kernel_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return nullptr;
}
return new_kernel_pack;
}
KernelModPtr AkgGpuKernelBuilder::BuildByJson(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "Akg start compile, op[" << anf_node->fullname_with_scope() << "]";
AkgKernelJsonGenerator json_generator;
if (!json_generator.CollectJson(anf_node)) {
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
}
auto kernel_pack = OpBuild(json_generator, anf_node);
if (kernel_pack == nullptr) {
MS_LOG(ERROR) << "Akg build failed op[" << anf_node->fullname_with_scope() << "].";
return nullptr;
}
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack); auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
kernel_mod_ptr->SetInputSizeList(input_size_list); kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
kernel_mod_ptr->SetOutputSizeList(output_size_list); kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
MS_LOG(INFO) << "Akg compile success, op[" << anf_node->fullname_with_scope() << "]";
return kernel_mod_ptr; return kernel_mod_ptr;
} }
KernelModPtr AkgGpuKernelBuilder::FuseByJson(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_LOG(INFO) << "Akg start compile, graph_kernel[" << anf_node->fullname_with_scope() << "]";
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, true);
fg->set_manager(mng);
}
AnfNodePtrList node_list;
AnfNodePtrList input_list;
AnfNodePtrList output_list;
GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
AkgKernelJsonGenerator json_generator;
if (!json_generator.CollectFusedJson(node_list, input_list, output_list)) {
MS_LOG(ERROR) << "Op[" << anf_node->fullname_with_scope() << "] create single kernel json failed.";
}
auto kernel_pack = OpBuild(json_generator, anf_node);
if (kernel_pack == nullptr) {
MS_LOG(ERROR) << "Akg build failed, graph_kernel[" << anf_node->fullname_with_scope() << "].";
return nullptr;
}
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
MS_LOG(INFO) << "Akg compile success, graph_kernel[" << anf_node->fullname_with_scope() << "]";
return kernel_mod_ptr;
}
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
AkgGpuKernelBuilder akg_gpu_kernel_builder;
if (AnfAlgo::IsGraphKernel(anf_node)) {
return akg_gpu_kernel_builder.FuseByJson(anf_node);
}
return akg_gpu_kernel_builder.BuildByJson(anf_node);
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -16,11 +16,25 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_
#include <string>
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "base/base.h" #include "base/base.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class AkgGpuKernelBuilder {
public:
AkgGpuKernelBuilder() = default;
~AkgGpuKernelBuilder() = default;
KernelModPtr BuildByJson(const AnfNodePtr &anf_node);
KernelModPtr FuseByJson(const AnfNodePtr &anf_node);
private:
KernelPackPtr OpBuild(const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node);
};
KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -205,10 +205,13 @@ TypeId DtypeToTypeId(const std::string &dtypes) {
} }
} }
std::string TypeId2String(TypeId type_id) { std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
auto iter = type_id_str_map.find(type_id); auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) { if (iter == type_id_str_map.end()) {
return std::string(TypeIdLabel(type_id)); if (!unknown_as_default) {
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
}
return "float32";
} }
return iter->second; return iter->second;
} }
@ -427,9 +430,9 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
return true; return true;
} }
void SaveJsonInfo(const std::string &json_name, const std::string &info) { void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
std::string path = kCceKernelMeta + json_name + kInfoSuffix; std::string path = base_path + json_name + kInfoSuffix;
if (path.size() > PATH_MAX) { if (path.size() > PATH_MAX) {
MS_LOG(DEBUG) << "file path " << path << " is too long."; MS_LOG(DEBUG) << "file path " << path << " is too long.";
return; return;
@ -458,6 +461,14 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info) {
} }
} }
Processor GetProcessor(const string &processor) {
if (processor == kProcessorAiCore) return Processor::AICORE;
if (processor == kProcessorAiCpu) return Processor::AICPU;
if (processor == kProcessorCuda) return Processor::CUDA;
MS_LOG(DEBUG) << "Unknown processor type.";
return Processor::UNKNOWN;
}
std::string GetProcessor(const AnfNodePtr &anf_node) { std::string GetProcessor(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
std::string device; std::string device;
@ -628,16 +639,21 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) { std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node_list); MS_EXCEPTION_IF_NULL(node_list);
MS_EXCEPTION_IF_NULL(input_list); MS_EXCEPTION_IF_NULL(input_list);
MS_EXCEPTION_IF_NULL(output_list);
MS_EXCEPTION_IF_NULL(func_graph);
GetValidKernelNodes(func_graph, node_list); GetValidKernelNodes(func_graph, node_list);
auto parameters = func_graph->parameters(); auto parameters = func_graph->parameters();
input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
GetFuncGraphOutputNodes(func_graph, output_list);
}
void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(output_list);
auto func_output = func_graph->output(); auto func_output = func_graph->output();
MS_EXCEPTION_IF_NULL(func_output); MS_EXCEPTION_IF_NULL(func_output);
if (func_output->isa<CNode>()) { if (func_output->isa<CNode>()) {
@ -780,5 +796,36 @@ std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode); AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
return axis; return axis;
} }
std::string GetProcessorStr(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string processor = kProcessorUnknown;
auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
// we may call this before kernel select.
if (build_info == nullptr) {
return processor;
}
switch (build_info->processor()) {
case Processor::AICORE:
processor = kProcessorAiCore;
break;
case Processor::AICPU:
processor = kProcessorAiCpu;
break;
case Processor::CUDA:
processor = kProcessorCuda;
break;
default:
MS_LOG(ERROR) << "Unknown processor type.";
break;
}
return processor;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -23,6 +23,7 @@
#include <unordered_set> #include <unordered_set>
#include <map> #include <map>
#include <string> #include <string>
#include <algorithm>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
@ -37,6 +38,7 @@ constexpr auto kGpuKernelMeta = "./cuda_meta";
constexpr auto kProcessorAiCore = "aicore"; constexpr auto kProcessorAiCore = "aicore";
constexpr auto kProcessorAiCpu = "aicpu"; constexpr auto kProcessorAiCpu = "aicpu";
constexpr auto kProcessorCuda = "cuda"; constexpr auto kProcessorCuda = "cuda";
constexpr auto kProcessorUnknown = "unknown";
constexpr auto kJsonSuffix = ".json"; constexpr auto kJsonSuffix = ".json";
constexpr auto kInfoSuffix = ".info"; constexpr auto kInfoSuffix = ".info";
constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
@ -76,12 +78,13 @@ KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &pro
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
TypeId DtypeToTypeId(const std::string &dtypes); TypeId DtypeToTypeId(const std::string &dtypes);
std::string Dtype2ShortType(const std::string &dtypes); std::string Dtype2ShortType(const std::string &dtypes);
std::string TypeId2String(TypeId type_id); std::string TypeId2String(TypeId type_id, bool unknown_as_default = false);
size_t GetDtypeNbyte(const std::string &dtypes); size_t GetDtypeNbyte(const std::string &dtypes);
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
void SaveJsonInfo(const std::string &json_name, const std::string &info); void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path = kCceKernelMeta);
std::string GetProcessor(const AnfNodePtr &anf_node); std::string GetProcessor(const AnfNodePtr &anf_node);
Processor GetProcessor(const string &processor);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
int Sign(float x); int Sign(float x);
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index); std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
@ -90,13 +93,26 @@ std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(cons
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list, std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list); const std::vector<AnfNodePtr> &output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list, void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list); std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list); void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list);
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list); void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
bool IsWeightBoundary(const AnfNodePtr &node); bool IsWeightBoundary(const AnfNodePtr &node);
std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode); std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode);
std::string GetProcessorStr(const AnfNodePtr &anf_node);
template <typename T>
inline std::string Vector2Str(const std::vector<T> &inputs) {
if (!inputs.empty()) {
std::ostringstream oss;
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
oss << inputs.back();
return oss.str();
}
return "";
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -16,14 +16,12 @@
#include <unistd.h> #include <unistd.h>
#include <fstream> #include <fstream>
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/convert_utils.h" #include "utils/convert_utils.h"
#include "utils/system/sha256.h" #include "utils/system/sha256.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace { namespace {

View File

@ -49,6 +49,7 @@ enum OpPattern {
// Backend processor // Backend processor
enum Processor { enum Processor {
UNKNOWN = -1,
AICORE = 0, AICORE = 0,
AICPU, AICPU,
CUDA, CUDA,

View File

@ -5,13 +5,19 @@ file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
) )
if (ENABLE_D) if (ENABLE_D)
file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) "ascend/*.cc"
"graph_kernel/*.cc"
)
list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST})
endif () endif ()
if (ENABLE_GPU) if (ENABLE_GPU)
file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc") file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST}) "gpu/*.cc"
"graph_kernel/*.cc"
)
list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
endif () endif ()
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)

View File

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include "backend/optimizer/ascend/ascend_backend_optimization.h" #include "backend/optimizer/ascend/ascend_backend_optimization.h"
#include <algorithm>
#include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
@ -68,8 +70,6 @@
#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" #include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "backend/optimizer/pass/eliminate_redundant_op.h" #include "backend/optimizer/pass/eliminate_redundant_op.h"
#include "backend/optimizer/pass/common_subexpression_elimination.h" #include "backend/optimizer/pass/common_subexpression_elimination.h"
#include "backend/optimizer/pass/fuse_graph_kernel.h"
#include "backend/optimizer/pass/fuse_basic.h"
#include "backend/optimizer/pass/add_atomic_clean.h" #include "backend/optimizer/pass/add_atomic_clean.h"
#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" #include "backend/optimizer/ascend/format_type/merge_cast_to_op.h"
#include "backend/optimizer/ascend/format_type/check_consistency.h" #include "backend/optimizer/ascend/format_type/check_consistency.h"
@ -106,6 +106,8 @@
#include "backend/optimizer/ascend/ir_fission/pack_fission.h" #include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h" #include "debug/dump_proto.h"
@ -406,7 +408,7 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
} }
// Fuse graph kernels with basic ops // Fuse graph kernels with basic ops
FuseGraphKernel(kernel_graph, is_before_kernel_select); static_cast<void>(FuseCompositeOps(kernel_graph, is_before_kernel_select));
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
@ -429,17 +431,17 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path = "."; save_graphs_path = ".";
} }
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + std::string file_path = save_graphs_path + "/" + "hwopt_fuse_basic_opt_before_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir"; ".ir";
DumpIR(file_path, kernel_graph, true); DumpIR(file_path, kernel_graph, true);
} }
// Fuse basic ops with basic ops // Fuse basic ops with basic ops
FuseBasic(kernel_graph, is_before_kernel_select); static_cast<void>(FuseBasicOps(kernel_graph, is_before_kernel_select));
if (save_graphs) { if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + std::string file_path = save_graphs_path + "/" + "hwopt_fuse_basic_opt_end_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir"; ".ir";
DumpIR(file_path, kernel_graph, true); DumpIR(file_path, kernel_graph, true);

View File

@ -601,6 +601,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
std::vector<std::string> new_input_names; std::vector<std::string> new_input_names;
auto primitive = AnfAlgo::GetCNodePrimitive(cnode); auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
primitive = primitive->Clone();
auto input_names = primitive->GetAttr(kAttrInputNames); auto input_names = primitive->GetAttr(kAttrInputNames);
if (input_names == nullptr) { if (input_names == nullptr) {
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
@ -631,6 +632,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
} }
if (need_update) { if (need_update) {
// Update cnode's inputs // Update cnode's inputs
new_inputs[0] = NewValueNode(primitive);
cnode->set_inputs(new_inputs); cnode->set_inputs(new_inputs);
// Update cnode's input_names attr // Update cnode's input_names attr
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));

View File

@ -73,7 +73,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
if (save_graphs) { if (save_graphs) {
auto dump_file_path = auto dump_file_path =
save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir"; save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir";
DumpIR(dump_file_path, func_graph); DumpIR(dump_file_path, func_graph, true);
} }
num++; num++;
} }

View File

@ -1,4 +1,3 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
@ -14,8 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/optimizer/pass/fuse_basic.h" #include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/pass/fuse_graph_kernel.h"
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
@ -31,17 +29,30 @@
#include "vm/segment_runner.h" #include "vm/segment_runner.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) { bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
#if ENABLE_D
std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
prim::kPrimExpandDims}; prim::kPrimExpandDims};
if (!is_before_kernel_select) { if (!is_before_kernel_select) {
fusable_basic_ops.push_back(prim::kPrimCast); fusable_basic_ops.push_back(prim::kPrimCast);
} }
return fusable_basic_ops; #elif ENABLE_GPU
std::vector<PrimitivePtr> fusable_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData};
#else
std::vector<PrimitivePtr> fusable_basic_ops;
#endif
return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
} }
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
@ -53,16 +64,14 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKe
return EXCLUDE; return EXCLUDE;
} }
auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); bool is_fusable = IsBasicOp(node, info.is_before_kernel_select);
bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
return is_fusable ? FOLLOW : EXCLUDE; return is_fusable ? FOLLOW : EXCLUDE;
} }
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
GraphKernelInfo info; GraphKernelInfo info;
info.is_before_kernel_select = is_before_kernel_select; info.is_before_kernel_select = is_before_kernel_select;
// Search fusable nodes according input direction. // Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
@ -170,8 +179,9 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
fg->set_output(fg_new_output, true); fg->set_output(fg_new_output, true);
} }
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::vector<AnfNodePtr> &todos, bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) { std::unordered_set<AnfNodePtr> *fused_ops, bool is_before_kernel_select) {
bool changed = false;
auto mng = kernel_graph->manager(); auto mng = kernel_graph->manager();
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>(); auto node = (*iter)->cast<CNodePtr>();
@ -181,9 +191,7 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
if (fused_ops->count(node)) { if (fused_ops->count(node)) {
continue; continue;
} }
auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select); bool is_basic_op = IsBasicOp(node, is_before_kernel_select);
bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
if (!is_basic_op || !kernel_graph->nodes().contains(node)) { if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
continue; continue;
} }
@ -193,12 +201,16 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
continue; continue;
} }
changed = true;
FuncGraphPtr fg; FuncGraphPtr fg;
AnfNodePtrList inputs; AnfNodePtrList inputs;
AnfNodePtrList outputs; AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
RemoveControlDependOut(fg, &outputs, mng); RemoveControlDependOut(fg, &outputs, mng);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select);
if (!is_before_kernel_select) {
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
}
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
@ -210,10 +222,12 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, const
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
} }
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
return changed;
} }
} // namespace } // namespace
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) { bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager(); auto mng = kernel_graph->manager();
if (mng == nullptr) { if (mng == nullptr) {
@ -223,7 +237,9 @@ void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool i
std::unordered_set<AnfNodePtr> fused_ops; std::unordered_set<AnfNodePtr> fused_ops;
auto todos = TopoSort(kernel_graph->get_return()); auto todos = TopoSort(kernel_graph->get_return());
std::reverse(todos.begin(), todos.end()); std::reverse(todos.begin(), todos.end());
FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select); return FuseBasicOps(kernel_graph, todos, &fused_ops, is_before_kernel_select);
} }
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph, false); }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
#include <memory> #include <memory>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
@ -23,7 +23,16 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
void FuseBasic(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select); bool FuseBasicOps(const FuncGraphPtr &kernel_graph, bool is_before_kernel_select);
class BasicOpsFusion : public Pass {
public:
BasicOpsFusion() : Pass("basic_ops_fusion") {}
~BasicOpsFusion() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_BASIC_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_

View File

@ -0,0 +1,385 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include <memory>
#include <string>
#include <algorithm>
#include <unordered_set>
#include <map>
#include <set>
#include <queue>
#include <vector>
#include "frontend/operator/ops.h"
#include "utils/utils.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
#if ENABLE_D
std::vector<PrimitivePtr> basic_ops = {
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
prim::kPrimExpandDims, prim::kPrimReciprocal, prim::kPrimLessEqual};
if (!is_before_kernel_select) {
basic_ops.push_back(prim::kPrimCast);
}
#elif ENABLE_GPU
std::vector<PrimitivePtr> basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData};
#else
std::vector<PrimitivePtr> basic_ops;
#endif
return std::any_of(basic_ops.begin(), basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
bool IsReduceOp(const AnfNodePtr &node) {
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
prim::kPrimReduceMax, prim::kPrimReduceAll};
return std::any_of(reduce_ops.begin(), reduce_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
void GetGraphKernelInfo(const FuncGraphPtr &fg, GraphKernelInfo *info) {
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
const auto &nodes = fg->nodes();
info->op_type = ELEWISE;
info->cal_step = -1;
info->reduce_op_num = 0;
for (auto node : nodes) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
info->cal_step++;
if (IsReduceOp(node)) {
info->op_type = REDUCE;
info->reduce_op_num++;
}
}
auto fg_flag = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
info->origin_composite_name = fg_name;
}
}
bool IsCompositeFuseBasic(const GraphKernelInfo &info, const AnfNodePtr &node) {
#if ENABLE_D
std::vector<PrimitivePtr> fusable_with_reduce;
if (!info.is_before_kernel_select) {
fusable_with_reduce.push_back(prim::kPrimCast);
}
if (info.op_type == REDUCE &&
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
return std::any_of(fusable_with_reduce.begin(), fusable_with_reduce.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
#endif
return IsBasicFuseOp(node, info.is_before_kernel_select);
}
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
// composite fuse composite op
if (AnfAlgo::IsGraphKernel(node)) {
#if ENABLE_D
return false;
#else
return true;
#endif
}
return IsCompositeFuseBasic(info, node);
}
void UpdateGraphKernelInfo(GraphKernelInfo *info, const AnfNodePtr &node) {
if (IsPrimitiveCNode(node)) {
info->cal_step++;
if (IsReduceOp(node)) {
info->op_type = REDUCE;
}
info->origin_composite_name += AnfAlgo::GetCNodePrimitive(node)->name() + "_";
} else if (AnfAlgo::IsGraphKernel(node)) {
auto cnode = node->cast<CNodePtr>();
auto composite_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
GraphKernelInfo fuse_info;
GetGraphKernelInfo(composite_g, &fuse_info);
info->cal_step += fuse_info.cal_step;
info->origin_composite_name += fuse_info.origin_composite_name;
}
}
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
#if ENABLE_D
if (!IsPrimitiveCNode(node)) {
return EXCLUDE;
}
#else
bool is_fuse_composite = AnfAlgo::IsGraphKernel(node);
if (!IsPrimitiveCNode(node) && !is_fuse_composite) {
return EXCLUDE;
}
#endif
bool is_fusable = IsFuse(*info, node);
if (is_fusable) {
UpdateGraphKernelInfo(info, node);
}
return is_fusable ? FOLLOW : EXCLUDE;
}
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelInfo *info, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (AnfAlgo::IsGraphKernel(node)) {
auto cnode = node->cast<CNodePtr>();
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
MS_EXCEPTION_IF_NULL(fg_attr_val);
auto fg_attr = GetValue<std::string>(fg_attr_val);
if (fg_attr == kApplyMomentumOpName) {
return FOLLOW;
}
return EXCLUDE;
}
if (!IsPrimitiveCNode(node)) {
return EXCLUDE;
}
bool is_fusable = IsFuse(*info, node);
if (is_fusable) {
UpdateGraphKernelInfo(info, node);
}
return is_fusable ? FOLLOW : EXCLUDE;
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set) {
if (!check_node->isa<CNode>()) {
return false;
}
auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
// there is a input not in fused_op_set, but the input depends on the fused_op_set
bool has_circle = false;
for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) {
std::set<AnfNodePtr> done;
std::vector<AnfNodePtr> todos = {input};
while (!todos.empty()) {
auto node = todos.back();
todos.pop_back();
if (done.count(node) || cached_unconnected_set->count(node)) {
continue;
}
done.insert(node);
if (fused_op_set.count(node)) {
has_circle = true;
break;
}
if (node->isa<CNode>()) {
auto cnode_ptr = node->cast<CNodePtr>();
for (auto it : cnode_ptr->inputs()) {
if (it->isa<CNode>()) {
todos.push_back(it);
}
}
}
}
if (has_circle) {
return true;
}
cached_unconnected_set->insert(done.begin(), done.end());
}
}
return false;
}
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
std::set<AnfNodePtr> cached_unconnected_set;
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto include = [&fused_op_set](const AnfNodePtr &node) {
if (fused_op_set.count(node)) {
return FOLLOW;
}
return EXCLUDE;
};
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes;
if (is_backward) {
erase_nodes = DeepUsersSearch(*iter, include, mng);
} else {
erase_nodes = DeepLinkedGraphSearch(*iter, include);
}
for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node);
}
}
}
std::vector<AnfNodePtr> res;
for (auto node : fused_op) {
if (fused_op_set.count(node)) {
res.push_back(node);
}
}
return res;
}
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
if (lst->size() < 2) {
return;
}
std::vector<AnfNodePtr> res;
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
OrderedMap<AnfNodePtr, std::set<AnfNodePtr>> ins;
OrderedMap<AnfNodePtr, OrderedSet<AnfNodePtr>> outs;
std::queue<AnfNodePtr> q;
for (auto node : *lst) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto input : cnode->inputs()) {
if (!node_sets.count(input)) {
continue;
}
// out_degree
outs[input].insert(node);
// in_degree
ins[node].insert(input);
}
if (!ins.count(node)) {
ins[node] = {};
}
}
for (auto p : ins) {
if (p.second.size() == 0) {
q.push(p.first);
}
}
while (!q.empty()) {
auto node = q.front();
q.pop();
res.push_back(node);
if (!outs.count(node)) {
continue;
}
for (auto out : outs[node]) {
if (!ins.count(out)) {
continue;
}
ins[out].erase(node);
if (ins[out].size() == 0) {
q.push(out);
}
}
}
lst->assign(res.begin(), res.end());
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
auto func_graph = cnode->func_graph();
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
GraphKernelInfo info;
info.is_before_kernel_select = is_before_kernel_select;
GetGraphKernelInfo(graph_kernel_g, &info);
auto mng = func_graph->manager();
// Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, &info, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
std::reverse(used_nodes.begin(), used_nodes.end());
// Search fusable nodes according output direction.
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, &info, std::placeholders::_1);
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes);
}
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
MS_EXCEPTION_IF_NULL(kernel_graph);
bool changed = false;
auto &todos = kernel_graph->execution_order();
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = *iter;
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
continue;
}
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_attr != nullptr) {
auto fg_name = GetValue<std::string>(fg_attr);
if (graph_kernel_black_list.count(fg_name) != 0) {
continue;
}
}
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
if (fuse_nodes.size() <= 1) {
continue;
}
changed = true;
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "", is_before_kernel_select);
}
return changed;
}
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph), false);
}
} // namespace opt
} // namespace mindspore

View File

@ -1,4 +1,3 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
@ -14,13 +13,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <limits>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
@ -31,18 +31,20 @@ enum GraphKernelType {
REDUCE, // contain reduce ops REDUCE, // contain reduce ops
CUBE, // contain cube ops CUBE, // contain cube ops
}; };
struct GraphKernelInfo { struct GraphKernelInfo {
GraphKernelType op_type = ELEWISE; GraphKernelType op_type = ELEWISE;
bool is_before_kernel_select = false; bool is_before_kernel_select = false;
int reduce_op_num = 0; int reduce_op_num = 0;
int cal_step = 0; int cal_step = 0;
std::string origin_composite_name = "";
}; };
// when reduce graph kernel's cal step is greater than this number, not fuse // when composite fuse composite the cal step is greate than this number, not fuse
#if ENABLE_D
const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5; const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
// when reduce graph kernel contain reduce op num is greater than this number, not fuse
const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2; const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
#endif
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
"LambNextMV", "LambUpdateWithLR"}; "LambNextMV", "LambUpdateWithLR"};
@ -50,14 +52,15 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst); void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
AnfNodePtr CreateNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const FuncGraphPtr &fg, bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
bool is_before_kernel_select);
void ReplaceNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const AnfNodePtr &new_fuse_cnode, class CompositeOpsFusion : public Pass {
const AnfNodePtrList &outputs); public:
CompositeOpsFusion() : Pass("composite_ops_fusion") {}
void FuseGraphKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false); ~CompositeOpsFusion() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using FuseGraphKernelPassPtr = std::shared_ptr<CompositeOpsFusion>;
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FUSE_GRAPH_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_

View File

@ -0,0 +1,206 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include <vector>
#include <string>
#include <unordered_set>
#include "backend/session/anf_runtime_algorithm.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "mindspore/core/ir/graph_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "vm/segment_runner.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kJsonKeyExpandInfo = "expand_info";
#define GET_VALUE_FOR_JSON(JSON, VALUE, VALUE_ELEM, TYPE_NAME, TYPE) \
if (VALUE_ELEM->isa<TYPE_NAME>()) { \
JSON = GetValue<TYPE>(VALUE); \
}
nlohmann::json ExpandAttrJsonInfo(const CNodePtr &cnode) {
nlohmann::json attrs_json;
if (auto prim = GetCNodePrimitive(cnode); prim != nullptr) {
auto attrs = prim->attrs();
for (const auto &[k, v] : attrs) {
nlohmann::json attr_json;
MS_LOG(DEBUG) << "attr key is : " << k << " and value type is : " << v->type_name();
GET_VALUE_FOR_JSON(attr_json[k], v, v, Int32Imm, int);
GET_VALUE_FOR_JSON(attr_json[k], v, v, Int64Imm, int64_t);
GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt32Imm, uint32_t);
GET_VALUE_FOR_JSON(attr_json[k], v, v, UInt64Imm, uint64_t);
GET_VALUE_FOR_JSON(attr_json[k], v, v, FP32Imm, float);
GET_VALUE_FOR_JSON(attr_json[k], v, v, FP64Imm, double);
GET_VALUE_FOR_JSON(attr_json[k], v, v, BoolImm, bool);
GET_VALUE_FOR_JSON(attr_json[k], v, v, StringImm, std::string);
if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
if (!vec.empty()) {
MS_LOG(DEBUG) << "value type is : " << vec[0]->type_name();
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int32Imm, std::vector<int>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], Int64Imm, std::vector<int64_t>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt32Imm, std::vector<uint32_t>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], UInt64Imm, std::vector<uint64_t>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP32Imm, std::vector<float>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], FP64Imm, std::vector<double>);
GET_VALUE_FOR_JSON(attr_json[k], v, vec[0], StringImm, std::vector<std::string>);
}
}
if (!attr_json.empty()) {
attrs_json.push_back(attr_json);
}
}
}
return attrs_json;
}
bool ExpandJsonInfo(const CNodePtr &cnode, nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(kernel_json);
if (kernel_json->find(kJsonKeyExpandInfo) != kernel_json->end()) {
return false;
}
nlohmann::json expand_info;
expand_info[kernel::kJsonKeyAttr] = ExpandAttrJsonInfo(cnode);
expand_info[kernel::kJsonKeyName] = AnfAlgo::GetCNodeName(cnode);
expand_info[kernel::kJsonKeyProcess] = kernel::GetProcessorStr(cnode);
std::vector<nlohmann::json> inputs_info;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) {
nlohmann::json input_info;
input_info[kernel::kJsonKeyFormat] = AnfAlgo::GetInputFormat(cnode, i);
input_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
input_info[kernel::kJsonKeyShape] = AnfAlgo::GetInputDeviceShape(cnode, i);
input_info[kernel::kJsonKeyInferDataType] =
kernel::TypeId2String(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, i));
input_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetInputDeviceDataType(cnode, i));
inputs_info.push_back(input_info);
}
expand_info[kernel::kJsonKeyInputDesc] = inputs_info;
std::vector<nlohmann::json> outputs_info;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(cnode); ++i) {
nlohmann::json output_info;
output_info[kernel::kJsonKeyFormat] = AnfAlgo::GetOutputFormat(cnode, i);
output_info[kernel::kJsonKeyInferShape] = AnfAlgo::GetOutputInferShape(cnode, i);
output_info[kernel::kJsonKeyShape] = AnfAlgo::GetOutputDeviceShape(cnode, i);
output_info[kernel::kJsonKeyInferDataType] = kernel::TypeId2String(AnfAlgo::GetOutputInferDataType(cnode, i));
output_info[kernel::kJsonKeyDataType] = kernel::TypeId2String(AnfAlgo::GetOutputDeviceDataType(cnode, i));
outputs_info.push_back(output_info);
}
expand_info[kernel::kJsonKeyOutputDesc] = outputs_info;
(*kernel_json)[kJsonKeyExpandInfo] = expand_info;
return true;
}
} // namespace
FuncGraphPtr GraphKernelExpander::CreateExpandFuncGraph(const CNodePtr &node) {
nlohmann::json kernel_json;
if (!ExpandJsonInfo(node, &kernel_json)) {
MS_LOG(ERROR) << "Expand json info to: " << node->DebugString(2) << " failed, ori_json:\n" << kernel_json.dump();
return nullptr;
}
auto node_desc_str = kernel_json.dump();
// call graph kernel ops generator.
MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
// parse result.
if (ret.is(py::none())) {
MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"
<< node_desc_str;
return nullptr;
}
std::string kernel_desc_str = py::cast<std::string>(ret);
if (kernel_desc_str.empty()) {
MS_LOG(ERROR) << "Jump expand node: " << node->fullname_with_scope();
return nullptr;
}
// decode json to func_graph.
std::vector<AnfNodePtr> ori_inputs(node->inputs().begin() + 1, node->inputs().end());
return JsonDescToAnf(kernel_desc_str, ori_inputs);
}
AnfNodePtr GraphKernelExpander::CreateExpandGraphKernel(const FuncGraphPtr &func_graph,
const FuncGraphPtr &new_func_graph, const CNodePtr &node) {
std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
AnfNodePtrList kernel_nodes;
AnfNodePtrList outputs;
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs, false);
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
std::string graph_kernel_flag;
std::for_each(kernel_nodes.begin(), kernel_nodes.end(), [&graph_kernel_flag](const AnfNodePtr &node) {
static_cast<void>(graph_kernel_flag.append(AnfAlgo::GetCNodeName(node)).append("_"));
});
MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope() << " with: " << graph_kernel_flag;
return graph_kernel_node;
}
bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
std::reverse(todos.begin(), todos.end());
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) {
auto node = n->cast<CNodePtr>();
if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) {
continue;
}
MS_LOG(INFO) << "Expand process node: " << node->fullname_with_scope();
auto new_func_graph = CreateExpandFuncGraph(node);
if (new_func_graph == nullptr) {
MS_LOG(ERROR) << "Decode fused nodes failed, " << node->fullname_with_scope();
continue;
}
mng->AddFuncGraph(new_func_graph);
MS_LOG(DEBUG) << "decode fused nodes success.";
auto graph_kernel_node = CreateExpandGraphKernel(func_graph, new_func_graph, node);
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(node)));
MS_LOG(INFO) << "create new cnode success.";
// replace origin node.
(void)mng->Replace(node, graph_kernel_node);
changed = true;
}
return changed;
}
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
expand_ops_ = GetExpandOps();
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
return DoExpand(func_graph);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_
#include <memory>
#include <unordered_set>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
class GraphKernelExpander : public Pass {
public:
GraphKernelExpander() : Pass("graph_kernel_expander") {}
~GraphKernelExpander() override = default;
bool Run(const FuncGraphPtr &func_graph);
private:
FuncGraphPtr CreateExpandFuncGraph(const CNodePtr &node);
bool DoExpand(const FuncGraphPtr &func_graph);
AnfNodePtr CreateExpandGraphKernel(const FuncGraphPtr &func_graph, const FuncGraphPtr &new_func_graph,
const CNodePtr &node);
bool CanExpand(const CNodePtr &node) {
return std::any_of(expand_ops_.begin(), expand_ops_.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
private:
std::unordered_set<PrimitivePtr> expand_ops_;
};
using GraphKernelExpanderPtr = std::shared_ptr<GraphKernelExpander>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_EXPANDER_H_

View File

@ -0,0 +1,674 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include <map>
#include <unordered_set>
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#ifdef ENABLE_D
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#endif
namespace mindspore {
namespace opt {
namespace {
void DebugDump(const FuncGraphPtr &graph, std::stringstream *buf) {
(*buf) << "Parameters: \n";
const auto &parameters = graph->parameters();
(*buf) << "size: " << parameters.size() << "\n";
for (const auto &p : parameters) {
(*buf) << "\t" << p->DebugString(2) << "\n";
}
(*buf) << "ValueNodes: \n";
const auto &value_nodes = graph->value_nodes();
(*buf) << "size: " << value_nodes.size() << "\n";
for (const auto &v : value_nodes) {
(*buf) << "\t" << v.first->DebugString(2) << "\n";
}
(*buf) << "CNodes: \n";
const auto &all_nodes = graph->nodes();
(*buf) << "size: " << all_nodes.size() << "\n";
for (const auto &n : all_nodes) {
(*buf) << "\t" << n->DebugString(2) << "\n";
}
}
bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
MS_EXCEPTION_IF_NULL(real_outs);
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
auto &inputs = out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
real_outs->push_back(inputs[i]);
}
return true;
}
if (auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
auto fg_out = fg->output();
if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
auto inputs = fg_out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
real_outs->push_back(inputs[i]);
}
return true;
}
}
return false;
}
AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
auto out_spec = node->abstract();
if (out_spec->isa<abstract::AbstractTuple>()) {
return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
}
return out_spec;
}
ValueNodePtr ProcessAttrsForCast(const CNodePtr &cnode, const std::string &attr_name) {
auto dst_type = AnfAlgo::GetNodeAttr<std::string>(cnode, attr_name);
auto type = TypeIdToType(kernel::DtypeToTypeId(dst_type));
auto type_val_node = NewValueNode(type);
return type_val_node;
}
const std::map<std::string, std::function<ValueNodePtr(const CNodePtr &cnode, const std::string &attr_name)>>
attrs_process_map = {
{kCastOpName, ProcessAttrsForCast},
};
ValueNodePtr ProcessAttrValue(const CNodePtr &cnode, const std::string &attr_name) {
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (attrs_process_map.count(op_name) != 0) {
return attrs_process_map.at(op_name)(cnode, attr_name);
}
auto attr_val = AnfAlgo::GetNodeAttr<ValuePtr>(cnode, attr_name);
auto attr_val_node = NewValueNode(attr_val);
return attr_val_node;
}
AnfNodePtr ConstAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::unordered_set<size_t> &input_attrs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2);
if (input_attrs.empty()) {
return nullptr;
}
auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names);
std::vector<AnfNodePtr> new_inputs;
std::vector<std::string> new_input_names;
const auto &inputs = cnode->inputs();
for (size_t i = 0; i < inputs.size() - 1; ++i) {
new_input_names.push_back(input_names[i]);
}
(void)new_inputs.insert(new_inputs.end(), inputs.begin(), inputs.end());
bool need_update = false;
for (size_t i = inputs.size() - 1; i < input_names.size(); ++i) {
auto attr_name = input_names[i];
if (input_attrs.find(i) == input_attrs.end()) {
MS_LOG(WARNING) << "Other type input between tensors and attrs, name: " << attr_name
<< ", node: " << cnode->DebugString(2);
new_input_names.push_back(attr_name);
continue;
}
if (!AnfAlgo::HasNodeAttr(attr_name, cnode)) {
MS_LOG(EXCEPTION) << "Attr: " << attr_name << " not found in node: " << cnode->DebugString(2);
}
// Hardcode. It should convert attrs value according to format, like op ReduceSum.
auto attr_val_node = ProcessAttrValue(cnode, attr_name);
new_inputs.push_back(attr_val_node);
new_input_names.push_back(attr_name);
need_update = true;
MS_LOG(DEBUG) << "convert attr: " << attr_name << " to input, value: " << attr_val_node;
}
MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names);
if (!need_update) {
return nullptr;
}
auto new_cnode = func_graph->NewCNode(new_inputs);
// we do not modify abstract and kernel info.
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_kernel_info(cnode->kernel_info_ptr());
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode);
return new_cnode;
}
AnfNodePtr DeleteAttrInInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::unordered_set<size_t> &input_attrs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "process node: " << cnode->DebugString(2);
if (input_attrs.empty()) {
return nullptr;
}
auto input_names = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
MS_LOG(DEBUG) << "ori_input_names: " << kernel::Vector2Str(input_names);
std::vector<AnfNodePtr> new_inputs;
std::vector<std::string> new_input_names;
const auto &inputs = cnode->inputs();
new_inputs.push_back(inputs[0]);
bool need_update = false;
for (size_t i = 0; i < inputs.size() - 1; ++i) {
auto input_node = inputs[i + 1];
MS_EXCEPTION_IF_NULL(input_node);
// The attrs counts from 0
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_LOG(DEBUG) << "delete attr input: " << i << " of node: " << cnode->DebugString(2);
if (i >= input_names.size()) {
MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size: " << input_names.size();
}
need_update = true;
} else {
new_inputs.push_back(input_node);
if (i < input_names.size()) {
new_input_names.push_back(input_names[i]);
}
}
}
MS_LOG(DEBUG) << "new_input_names: " << kernel::Vector2Str(new_input_names);
if (!need_update) {
return nullptr;
}
auto new_cnode = func_graph->NewCNode(new_inputs);
// we do not modify abstract and kernel info.
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_kernel_info(cnode->kernel_info_ptr());
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(new_input_names), new_cnode);
return new_cnode;
}
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
AnfNodePtrList outs;
auto out_node = (*fg)->output();
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> output_args;
auto out_cnode = out_node->cast<CNodePtr>();
for (auto out : out_cnode->inputs()) {
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
auto inputs = out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
output_args.push_back(inputs[i]);
}
} else {
output_args.push_back(out);
}
}
if (output_args.size() != out_cnode->inputs().size()) {
auto new_out = (*fg)->NewCNode(output_args);
(*mng)->Replace(out_node, new_out);
}
for (size_t i = 1; i < output_args.size(); ++i) {
outs.push_back(output_args[i]);
}
return outs;
}
outs.push_back(out_node);
return outs;
}
} // namespace
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs, kernel::Processor processor) {
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
for (size_t i = 0; i < inputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
graph_input_format.push_back(input_format);
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
graph_input_type.push_back(input_type);
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
fg->parameters()[i]->set_abstract(input_abs);
}
auto new_outputs = outputs;
if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
std::vector<AnfNodePtr> real_outs;
if (IsMakeTupleOut(outputs[0], &real_outs)) {
new_outputs = real_outs;
}
}
for (size_t i = 0; i < new_outputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
graph_output_format.push_back(output_format);
graph_output_type.push_back(output_type);
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(processor);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
}
void ConstAttrToInput(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(func_graph, &todos);
for (const auto &node : todos) {
ConstInputToAttrInfoRegister reg;
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), &reg)) {
continue;
}
auto new_node = ConstAttrToInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo());
if (new_node != nullptr && new_node != node) {
mng->Replace(node, new_node);
}
}
}
void DeleteAttrInInput(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(func_graph, &todos);
for (const auto &node : todos) {
ConstInputToAttrInfoRegister reg;
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(node), &reg)) {
continue;
}
auto new_node = DeleteAttrInInput(func_graph, node->cast<CNodePtr>(), reg.GetConstInputAttrInfo());
if (new_node != nullptr && new_node != node) {
mng->Replace(node, new_node);
}
}
}
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
AnfNodePtrList res;
if (outs.size() <= 1) {
return outs;
}
for (auto out : outs) {
AnfNodePtrList real_outs;
if (IsMakeTupleOut(out, &real_outs)) {
res.insert(res.end(), real_outs.begin(), real_outs.end());
continue;
}
res.push_back(out);
}
return res;
}
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs, bool is_before_kernel_select) {
auto func_node = NewValueNode(fg);
std::vector<AnfNodePtr> fn_inputs;
fn_inputs.push_back(func_node);
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
auto fuse_cnode = func_graph->NewCNode(fn_inputs);
// Set output abstract
if (outputs.size() > 1) {
std::vector<AbstractBasePtr> out_specs;
for (size_t i = 0; i < outputs.size(); ++i) {
out_specs.push_back(outputs[i]->abstract());
}
auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
fuse_cnode->set_abstract(out_spec);
} else {
fuse_cnode->set_abstract(outputs[0]->abstract());
}
// Set parameter abstract.
for (size_t i = 0; i < inputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
fg->parameters()[i]->set_abstract(input_abs);
if (is_before_kernel_select) {
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
}
}
return fuse_cnode;
}
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
// single out
if (outputs.size() == 1) {
mng->Replace(outputs[0], new_fuse_cnode);
return;
}
std::vector<AnfNodePtr> fn_inputs;
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
AnfNodePtrList real_outs;
// not make tuple out, replace
if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
fn_inputs.clear();
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
fn_inputs.push_back(new_fuse_cnode);
fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx))));
auto new_out = func_graph->NewCNode(fn_inputs);
new_out->set_abstract(outputs[out_idx]->abstract());
mng->Replace(outputs[out_idx], new_out);
continue;
}
// the out is make tuple , modify the get_item node's value
auto users = mng->node_users()[outputs[out_idx]];
for (auto &user : users) {
auto use_node = user.first;
if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
continue;
}
auto get_item_cnode = use_node->cast<CNodePtr>();
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(value_input);
auto value_node = value_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
int new_item_idx = SizeToInt(out_idx) + item_idx;
fn_inputs.clear();
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
fn_inputs.push_back(new_fuse_cnode);
fn_inputs.push_back(NewValueNode(new_item_idx));
auto new_out = func_graph->NewCNode(fn_inputs);
new_out->set_abstract(get_item_cnode->abstract());
mng->Replace(get_item_cnode, new_out);
}
}
}
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
bool is_before_kernel_select) {
if (fuse_nodes.empty()) {
return;
}
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
// Remove nest make tuple in outs
auto expand_out = GetExpandOuts(outputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select);
if (!is_before_kernel_select) {
SetNewKernelInfo(fuse_new_node, fg, inputs, expand_out, AnfAlgo::GetProcessor(fuse_nodes[0]));
}
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
// Inline origin graphkernel
auto cnodes = fg->GetOrderedCnodes();
for (const auto &n : cnodes) {
if (!AnfAlgo::IsGraphKernel(n)) {
continue;
}
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
AnfNodePtrList ins;
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
mng->Replace(n, out);
}
EliminateMakeTuple(&fg, &mng);
// set graphKernel attr
std::string fuse_op_name = "";
for (auto &fuse_node : fuse_nodes) {
if (IsPrimitiveCNode(fuse_node)) {
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
} else if (AnfAlgo::IsGraphKernel(fuse_node)) {
auto fuse_cnode = fuse_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(fuse_cnode);
auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
fuse_op_name += fuse_fg_name + "_";
}
}
fuse_op_name += postfix;
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
}
bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map) {
MS_EXCEPTION_IF_NULL(op_desc);
if (nodes.empty()) {
MS_LOG(ERROR) << "Input nodes is empty.";
return false;
}
bool has_graph_kernel =
std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) { return AnfAlgo::IsGraphKernel(node); });
bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
auto gen_json = [&dump_option, &op_desc, &address_node_map](const AnfNodePtrList &op_nodes,
const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs) -> bool {
kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, inputs, outputs)) {
MS_LOG(ERROR) << "Collect json desc failed.";
return false;
}
*op_desc = akg_kernel_json_generator.kernel_json();
if (address_node_map != nullptr) {
*address_node_map = akg_kernel_json_generator.address_node_map();
}
std::string fused_name;
std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
(void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
});
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
return true;
};
FuncGraphPtr fg;
AnfNodePtrList op_nodes;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
if (is_single_graph_kernel) {
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
return gen_json(op_nodes, inputs, outputs);
} else if (!has_graph_kernel) {
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
op_nodes = nodes;
return gen_json(op_nodes, inputs, outputs);
}
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
auto mng = Manage(fg, false);
fg->set_manager(mng);
// Inline origin graph kernel
auto fg_nodes = fg->GetOrderedCnodes();
for (auto const &n : fg_nodes) {
if (!AnfAlgo::IsGraphKernel(n)) {
continue;
}
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
AnfNodePtrList ins;
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
mng->Replace(n, out);
}
inputs.clear();
outputs.clear();
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
return gen_json(op_nodes, inputs, outputs);
}
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc) {
MS_EXCEPTION_IF_NULL(op_desc);
std::vector<nlohmann::json> graphs_desc;
for (auto const &graph_nodes : graphs) {
nlohmann::json desc;
if (!AnfToJsonDesc(graph_nodes, dump_option, &desc)) {
MS_LOG(ERROR) << "Collect json desc failed.";
return false;
}
graphs_desc.push_back(desc);
}
if (graphs_desc.empty()) {
MS_LOG(ERROR) << "Collect zero json desc.";
return false;
}
if (graphs_desc.size() > 1) {
nlohmann::json op_json_desc;
op_json_desc[kJsonKeyMultiGraph] = true;
op_json_desc[kJsonKeyGraphDesc] = graphs_desc;
*op_desc = op_json_desc;
return true;
}
*op_desc = graphs_desc[0];
return true;
}
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs) {
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
auto fg = akg_kernel_json_decoder.DecodeFusedNodes(json_desc);
if (fg == nullptr) {
MS_LOG(ERROR) << "Akg decode json to graph failed.";
return nullptr;
}
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
auto mng = resource->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->AddFuncGraph(fg);
ConstAttrToInput(fg);
std::stringstream buf;
buf << "===================== graph after ConstAttrToInput " << fg->ToString() << " =====================\n";
DebugDump(fg, &buf);
MS_LOG(DEBUG) << buf.str();
// Do infer and specialize.
AbstractBasePtrList args_spec_list;
std::for_each(inputs.begin(), inputs.end(),
[&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
auto infer_fg = pipeline::Renormalize(resource, fg, args_spec_list);
if (infer_fg == nullptr) {
MS_LOG(ERROR) << "Infer decoded graph failed.";
return nullptr;
}
buf.str("");
buf << "===================== graph after Renormalize " << infer_fg->ToString() << " =====================\n";
DebugDump(infer_fg, &buf);
MS_LOG(DEBUG) << buf.str();
// delete no use inputs(attrs), like op ReduceSum(axis).
DeleteAttrInInput(infer_fg);
buf.str("");
buf << "===================== graph after DeleteAttrInInput " << infer_fg->ToString() << " =====================\n";
DebugDump(infer_fg, &buf);
MS_LOG(DEBUG) << buf.str();
// clone a new graph.
auto new_fg = TransformableClone(infer_fg, std::make_shared<TraceTransform>("akg_decode"));
return new_fg;
}
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
std::vector<AnfNodePtrList> *res_graphs) {
MS_EXCEPTION_IF_NULL(res_graphs);
auto kernel_json = nlohmann::json::parse(json_desc);
if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
// not multi graphs.
MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
return false;
}
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
if (graph_descs.empty()) {
MS_LOG(ERROR) << "No sub graph found, " << json_desc;
return false;
}
for (size_t i = 0; i < graph_descs.size(); ++i) {
const auto &graph_desc = graph_descs[i];
AnfNodePtrList res_graph;
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
}
res_graphs->push_back(res_graph);
}
return true;
}
std::unordered_set<PrimitivePtr> GetExpandOps() {
std::unordered_set<PrimitivePtr> expand_ops = {
prim::kPrimSquare,
prim::kPrimGelu,
prim::kPrimSoftmax,
prim::kPrimLayerNorm,
};
return expand_ops;
}
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
std::stringstream name;
if (prefix != "") {
name << prefix << "_";
}
for (const auto &node : cnodes) {
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
name << AnfAlgo::GetCNodeName(node) << "_";
}
}
if (postfix != "") {
name << postfix;
}
return name.str();
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_set>
#include <nlohmann/json.hpp>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
namespace mindspore {
namespace opt {
using kernel::DumpOption;
constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
constexpr auto kGraphKernelSplitFunc = "split_with_json";
constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
constexpr auto kJsonKeyMultiGraph = "multi_graph";
constexpr auto kJsonKeyGraphDesc = "graph_desc";
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs, kernel::Processor processor);
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs, bool is_before_kernel_select);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs);
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix,
bool is_before_kernel_select);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, DumpOption dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, DumpOption dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
std::vector<AnfNodePtrList> *res_graphs);
std::unordered_set<PrimitivePtr> GetExpandOps();
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

View File

@ -0,0 +1,742 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include <algorithm>
#include <vector>
#include <string>
#include <unordered_set>
#include <utility>
#include <queue>
#include <map>
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
namespace {
void TraverseFuncGraphFromCNode(const CNodePtr &cnode, std::function<void(AnfNodePtr &)> callback) {
std::unordered_set<AnfNodePtr> visited;
std::queue<AnfNodePtr> que;
que.push(cnode);
visited.insert(cnode);
while (!que.empty()) {
auto ft_node = que.front();
que.pop();
callback(ft_node);
auto ft_cnode = ft_node->cast<CNodePtr>();
if (ft_cnode == nullptr) continue;
for (const auto &in_node : ft_cnode->inputs()) {
if (visited.count(in_node) == 0) {
que.push(in_node);
visited.insert(in_node);
}
}
}
}
// Visited each AnfNode once, use callback to do the job on AnfNode
inline void TraverseFuncGraph(const FuncGraphPtr &root, std::function<void(AnfNodePtr &)> callback) {
TraverseFuncGraphFromCNode(root->get_return(), callback);
}
class AreaGraph;
class Splitter;
class Area {
public:
explicit Area(const AnfNodePtrList &anf_arr) {
nodes_.insert(anf_arr.begin(), anf_arr.end());
for (auto &node : anf_arr) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) continue;
const auto &inputs = cnode->inputs();
if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
spy_cnodes_.push_back(node);
}
}
}
// Set the external inputs of spy as a Parameter.
void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) {
std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map;
for (auto node : this->spy_cnodes_) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
AnfNodePtr in_node = cnode->input(i);
if (!IsExternalCNode(in_node)) continue;
auto it = node_param_map.find(in_node);
if (it == node_param_map.end()) {
auto new_param = std::make_shared<Parameter>(func_graph);
new_param->set_abstract(in_node->abstract());
func_graph->add_parameter(new_param);
node_param_map.insert(std::make_pair(in_node, new_param));
cnode->set_input(i, new_param);
} else {
cnode->set_input(i, it->second);
}
}
}
this->spy_cnodes_.clear(); // spy list is not useful anymore
for (auto &&elem : node_param_map) {
param_node_map->insert(std::make_pair(elem.second, elem.first));
}
return;
}
// Make a return node for traitor nodes.
void CreateReturnNode(const FuncGraphPtr &func_graph, std::unordered_map<AnfNodePtr, size_t> *tuple_node_index) {
// If there's no traitor in the area, it means that this area is the last part
// of the original FuncGraph, it already contains the original Return node.
if (traitor_nodes_.empty()) {
for (auto &node : nodes_) {
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
func_graph->set_return(node->cast<CNodePtr>());
node->set_func_graph(func_graph);
return;
}
}
MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
return;
}
AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
if (traitor_nodes_.size() > 1) {
// The area has multiple output, it's necessary to make a tuple for them.
AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abstracts;
size_t i = 0;
for (auto &traitor : traitor_nodes_) {
tuple_node_index->insert(std::make_pair(traitor, i++));
maketuple_inputs.push_back(traitor);
abstracts.push_back(traitor->abstract());
}
auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
nodes_.insert(maketuple_node);
return_inputs.push_back(maketuple_node);
} else {
return_inputs.push_back(traitor_nodes_[0]);
}
auto return_node = func_graph->NewCNode(return_inputs);
return_node->set_abstract(return_inputs.back()->abstract());
func_graph->set_return(return_node);
nodes_.insert(return_node);
traitor_nodes_.clear(); // traitor list is not useful anymore
return;
}
void AddTraitor(const AnfNodePtr &node) {
if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
traitor_nodes_.push_back(node);
}
}
friend AreaGraph;
friend Splitter;
private:
// This is a CNode that does not belong to this area.
bool IsExternalCNode(const AnfNodePtr &node) { return node->isa<CNode>() && this->nodes_.count(node) == 0; }
// nodes in this area
std::unordered_set<AnfNodePtr> nodes_;
// if a node's output is used by other Area, it's a traitor
std::vector<AnfNodePtr> traitor_nodes_;
// if a node use other Area's output, it's a spy
std::vector<AnfNodePtr> spy_cnodes_;
};
class AreaGraph {
public:
using AreaGraphPtr = std::shared_ptr<AreaGraph>;
// Build an area graph to maintain the relation between areas.
// Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
AreaGraph *area_graph_ptr = new (std::nothrow) AreaGraph(node_groups);
if (!area_graph_ptr) return nullptr;
auto area_graph = AreaGraphPtr(area_graph_ptr);
if (!area_graph->TopoSort()) {
MS_LOG(WARNING) << "The groups have a cycle.";
return nullptr;
}
return area_graph;
}
// Split the graph to multiple areas, and reconnect the edges between the areas.
// The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
// The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
std::vector<size_t> *cnode_group_id, std::function<void(Area *)> expand_callback) {
main_cnodes->clear();
main_cnodes->resize(areas_.size(), nullptr);
for (auto &area : this->areas_) {
expand_callback(&area);
}
for (auto index : topo_order_) {
auto &current_area = areas_[index];
auto sub_func_graph = std::make_shared<FuncGraph>();
std::unordered_map<ParameterPtr, AnfNodePtr> param_node_map;
current_area.CreateParameters(sub_func_graph, &param_node_map);
current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
(*main_cnodes)[index] = new_main_cnode;
}
SortCNodes(main_cnodes);
cnode_group_id->swap(topo_order_); // The topo_order is not used anymore.
return;
}
private:
explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
for (size_t i = 0; i < node_groups.size(); ++i) {
areas_.emplace_back(node_groups[i]);
for (const auto &node : node_groups[i]) {
node_area_map_[node] = i;
}
}
for (auto &area : areas_) {
for (auto &spy : area.spy_cnodes_) {
auto cnode = spy->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t v = node_area_map_[spy];
for (auto &in_node : cnode->inputs()) {
if (!in_node->isa<CNode>()) continue;
// area edge u -> v
size_t u = node_area_map_[in_node];
if (u == v) continue;
areas_[u].AddTraitor(in_node);
if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
edge_prev_[v].push_back(u);
}
}
}
}
}
// Topological sort the areas.
bool TopoSort() {
std::vector<int> out_degree(edge_prev_.size(), 0);
std::queue<size_t> que;
for (auto &prev : edge_prev_) {
for (size_t i : prev) {
out_degree[i]++;
}
}
for (size_t i = 0; i < out_degree.size(); ++i) {
if (out_degree[i] == 0) que.push(i);
}
while (!que.empty()) {
size_t u = que.front();
que.pop();
topo_order_.push_back(u);
for (size_t i : edge_prev_[u]) {
if (--out_degree[i] == 0) que.push(i);
}
}
std::reverse(topo_order_.begin(), topo_order_.end());
return topo_order_.size() == areas_.size();
}
// Make a CNode in main graph to hold the sub_func_graph.
CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
const std::vector<CNodePtr> &main_cnodes,
const std::unordered_map<ParameterPtr, AnfNodePtr> &param_node_map) {
AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
for (const auto &param : sub_func_graph->parameters()) {
// assert the param exists.
const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
size_t input_area = node_area_map_[input_node];
// if the input node is in a tuple, then we need to create a GetItem fot it.
if (node_index_in_returned_tuple_.count(input_node) != 0) {
int idx_val = SizeToInt(node_index_in_returned_tuple_[input_node]);
auto idx = NewValueNode(idx_val);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
getitem_node->set_abstract(main_cnodes[input_area]->abstract());
main_cnode_inputs.push_back(getitem_node);
} else {
main_cnode_inputs.push_back(main_cnodes[input_area]);
}
}
auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
new_main_cnode->set_abstract(sub_func_graph->get_return()->abstract());
return new_main_cnode;
}
void SortCNodes(std::vector<CNodePtr> *main_cnodes) {
std::vector<CNodePtr> main_cnodes_sorted;
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
[main_cnodes](int index) { return main_cnodes->at(index); });
main_cnodes->swap(main_cnodes_sorted);
}
// Areas in this subgraph
std::vector<Area> areas_;
// Adjacency table of areas
std::vector<std::vector<size_t>> edge_prev_;
// Topological order of areas
std::vector<size_t> topo_order_;
// Map AnfNode to Area id
std::unordered_map<AnfNodePtr, size_t> node_area_map_;
// Map the nodes to their index if there are multiple value in an area
std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
};
class Splitter {
public:
class SplitSchemer {
public:
virtual bool Split(const FuncGraphPtr &func_graph) = 0;
virtual bool NeedInline(size_t group_id) const { return false; }
const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }
protected:
std::vector<AnfNodePtrList> split_plan_;
};
using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
using SplitterPtr = std::shared_ptr<Splitter>;
bool Split() {
GenParamMap();
auto ori_sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
if (!split_schemer_->Split(ori_sub_func_graph)) {
return false;
}
auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
if (area_graph == nullptr) {
return false;
}
// The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
std::vector<size_t> cnodes_group_id;
std::function<void(Area *)> expand_callback = std::bind(&Splitter::AreaExpand, this, std::placeholders::_1);
area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id, expand_callback);
RebuildGraph(cnodes_group_id);
return true;
}
static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer) {
MS_EXCEPTION_IF_NULL(main_cnode);
MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
MS_EXCEPTION_IF_NULL(split_schemer);
return SplitterPtr(new Splitter(main_cnode, split_schemer));
}
private:
Splitter(const CNodePtr &main_cnode, SplitSchemerPtr split_schemer)
: main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
// Maintain new subgraphs in main graph.
void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
BindFuncGraph();
RecoverParameter();
ConnectToMainGraph(cnodes_group_id);
UpdateSubGraphInfo();
}
// Rebind nodes to its new sub_func_graph
void BindFuncGraph() {
for (const auto &cnode : new_subgraph_cnodes_) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto callback = [&sub_func_graph, this](const AnfNodePtr &node) {
if (!node->isa<ValueNode>()) {
node->set_func_graph(sub_func_graph);
}
};
TraverseFuncGraph(sub_func_graph, callback);
}
}
// Recover the original subgraph's parameter if the new graph needs it
void RecoverParameter() {
for (const auto &cnode : new_subgraph_cnodes_) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
auto param = node->cast<ParameterPtr>();
if (param == nullptr) return;
auto it = this->param_to_main_graph_node_map_.find(param);
if (it != this->param_to_main_graph_node_map_.end()) {
cnode->add_input(it->second);
sub_func_graph->add_parameter(param);
// Avoid repeating parameters.
this->param_to_main_graph_node_map_.erase(it);
}
};
TraverseFuncGraph(sub_func_graph, callback);
}
}
CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(main_node);
const auto &inputs = main_node->inputs();
auto output = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output);
const auto &parameters = func_graph->parameters();
std::unordered_map<AnfNodePtr, AnfNodePtr> param_input;
for (size_t i = 0; i < parameters.size(); ++i) {
param_input[parameters[i]] = inputs[i + 1];
}
auto sub_nodes = TopoSort(func_graph->get_return());
for (auto node : sub_nodes) {
if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
cnode->set_func_graph(main_func_graph_);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto iter = param_input.find(cnode->input(i));
if (iter != param_input.end()) {
cnode->set_input(i, iter->second);
}
}
}
}
return output;
}
// Set the new sub_func_graph node as input of nodes original main graph.
void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
// For single output kernel, the last area contains the original output node (return node),
// to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
// For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
// a new FuncGraph, just inline the last MakeTuple node.
std::vector<CNodePtr> tmp_subgraph_cnodes;
std::unordered_map<AnfNodePtr, AnfNodePtr> replace_map;
for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
if (split_schemer_->NeedInline(cnodes_group_id[i])) {
// Connect the sub_graph's inner node to main_graph
auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
if (i + 1 == new_subgraph_cnodes_.size()) {
replace_map[this->old_subgraph_cnode_] = output;
} else {
replace_map[new_subgraph_cnodes_[i]] = output;
}
} else {
if (i + 1 == new_subgraph_cnodes_.size()) {
replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
}
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
}
}
new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) return;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input_node = cnode->input(i);
auto iter = replace_map.find(input_node);
if (iter != replace_map.end()) {
cnode->set_input(i, iter->second);
}
}
});
}
void UpdateSubGraphInfo() {
auto graph_manager = main_func_graph_->manager();
MS_EXCEPTION_IF_NULL(graph_manager);
for (auto cnode : new_subgraph_cnodes_) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
// add new sub_func_graph to manager
graph_manager->AddFuncGraph(sub_func_graph);
// set GraphKernel attr
auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
// set kernel info
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(sub_func_graph, &outputs);
SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs, AnfAlgo::GetProcessor(old_subgraph_cnode_));
}
}
// Copy all Parameter and ValueNode that the area used.
void AreaExpand(Area *area) {
std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
for (auto sub_node : area->nodes_) {
auto sub_cnode = sub_node->cast<CNodePtr>();
if (sub_cnode == nullptr) continue;
for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
auto in_node = sub_cnode->input(i);
if (in_node->isa<CNode>()) continue;
auto it = old_valuenode_and_param_map.find(in_node);
if (it != old_valuenode_and_param_map.end()) {
sub_cnode->set_input(i, it->second);
} else {
if (in_node->isa<Parameter>()) {
auto param = in_node->cast<ParameterPtr>();
auto cp_param = this->ParameterClone(param, in_node->func_graph());
old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
sub_cnode->set_input(i, cp_param);
}
}
}
}
}
void GenParamMap() {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
auto &param_arr = sub_func_graph->parameters();
for (size_t i = 0; i < param_arr.size(); ++i) {
auto param = param_arr[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
}
}
ParameterPtr ParameterClone(const ParameterPtr &param, const FuncGraphPtr &func) {
ParameterPtr param_c = std::make_shared<Parameter>(func);
param_c->set_name(param->name());
param_c->set_abstract(param->abstract());
param_to_main_graph_node_map_[param_c] = param_to_main_graph_node_map_[param];
return param_c;
}
FuncGraphPtr main_func_graph_;
CNodePtr old_subgraph_cnode_; // The cnode that holds the original sub_func_graph
std::vector<CNodePtr> new_subgraph_cnodes_; // The cnode list that hold the new sub_func_graph
SplitSchemerPtr split_schemer_;
std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
};
class CostModelSplitSchemer : public Splitter::SplitSchemer {
public:
bool Split(const FuncGraphPtr &func_graph) override {
if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
}
func_graph_ = func_graph;
this->Run();
return split_plan_.size() > 1;
}
bool NeedInline(size_t group_id) const override {
if (group_id >= need_inline_.size()) {
MS_LOG(EXCEPTION) << "The group_id " << group_id << " should be less than the group num " << need_inline_.size();
}
return need_inline_[group_id] != 0;
}
protected:
virtual bool SplitByCostModel() {
// Use an address map to record the anf node address when converting to json,
// it will recover the original node after split.
std::map<std::string, AnfNodePtr> address_node_map;
// convert anf-ir to json
nlohmann::json json_desc;
DumpOption dump_option;
dump_option.is_before_select_kernel = false;
dump_option.save_ptr_address = true;
if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
MS_LOG(ERROR) << "Collect json desc failed.";
return false;
}
// call costmodel split function.
auto json_desc_str = json_desc.dump();
MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json:\n" << json_desc_str;
auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str);
if (ret.is(py::none())) {
MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
<< json_desc_str;
return false;
}
std::string split_graphs_str = py::cast<std::string>(ret);
if (split_graphs_str.empty()) {
MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
<< json_desc_str;
return false;
}
// recover json to anf-ir.
split_plan_.clear();
if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
MS_LOG(ERROR) << "Failed to decode split graphs.";
return false;
}
// The info should be returned from costmodel.
need_inline_.assign(split_plan_.size(), 0);
return true;
}
virtual void Run() {
auto mng = func_graph_->manager();
if (mng == nullptr) {
mng = Manage(func_graph_, true);
func_graph_->set_manager(mng);
}
GetValidKernelNodes();
// call CostModel to get a split plan.
if (!SplitByCostModel() || split_plan_.size() <= 1) {
split_plan_.clear();
need_inline_.clear();
return;
} else {
MS_LOG(INFO) << "CostModel split successed. The kernel is split to " << split_plan_.size() << " parts.";
}
MapNodeGroup();
GroupReturnNode();
GroupVirtualNodes();
}
virtual bool IsValidKernelNode(const AnfNodePtr &node) const {
if (!node->isa<CNode>()) return false;
if (AnfAlgo::IsRealKernel(node)) return true;
return false;
}
virtual void GetValidKernelNodes() {
topo_all_nodes_ = TopoSort(func_graph_->get_return());
topo_valid_nodes_.clear();
std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
[this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
}
void MapNodeGroup() {
node_group_.clear();
for (size_t i = 0; i < split_plan_.size(); ++i) {
for (const auto &node : split_plan_[i]) {
node_group_[node] = i;
}
}
}
// group the return node and last MakeTuple node (if exists).
virtual void GroupReturnNode() {
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
auto ret_node = func_graph_->get_return();
auto output = func_graph_->output();
MS_EXCEPTION_IF_NULL(output);
if (IsValidKernelNode(output)) {
auto group_id = node_group_[ret_node] = node_group_[output];
split_plan_[group_id].push_back(ret_node);
return;
}
// assign the make_tuple node to a new group.
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
auto group_id = split_plan_.size();
split_plan_.push_back({output, ret_node});
need_inline_.push_back(1);
node_group_[ret_node] = node_group_[output] = group_id;
return;
}
}
// assign virtual node to the same group of its input.
virtual void GroupVirtualNodes() {
for (const auto &node : topo_all_nodes_) {
if (node_group_.count(node)) continue;
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) continue;
bool found = false;
for (const auto &input : cnode->inputs()) {
auto iter = node_group_.find(input);
if (iter != node_group_.end()) {
node_group_[node] = iter->second;
split_plan_[iter->second].push_back(node);
found = true;
break;
}
}
if (!found) {
MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
}
}
}
std::shared_ptr<FuncGraph> func_graph_;
AnfNodePtrList topo_all_nodes_;
AnfNodePtrList topo_valid_nodes_;
std::unordered_map<AnfNodePtr, size_t> node_group_;
std::vector<int> need_inline_;
};
// Eliminate the redundant MakeTuple-GetItem operations.
void EliminateTupleGetItem(const FuncGraphPtr &func_graph) {
auto callback = [](const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) return;
for (size_t i = 1; i < cnode->size(); ++i) {
auto getitem = cnode->input(i);
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
auto getitem_cnode = getitem->cast<CNodePtr>();
auto maketuple = getitem_cnode->input(kRealInputNodeIndexInTupleGetItem);
if (!AnfAlgo::CheckPrimitiveType(maketuple, prim::kPrimMakeTuple)) continue;
auto maketuple_cnode = maketuple->cast<CNodePtr>();
int getitem_idx =
GetValue<int>(getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value());
cnode->set_input(i, maketuple_cnode->input(getitem_idx + 1));
}
};
TraverseFuncGraph(func_graph, callback);
}
bool TrySplit(const CNodePtr &sub_root_cnode) {
MS_LOG(INFO) << "Split process node: " << sub_root_cnode->fullname_with_scope();
auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
MS_EXCEPTION_IF_NULL(splitter);
bool result = splitter->Split();
MS_LOG(INFO) << "Split node completed, result: " << result;
return result;
}
} // namespace
bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto todos = TopoSort(func_graph->get_return());
// Split subgraphs in reversed topo order,
// since the nodes behind the processing node may be modified when spliting.
bool changed = false;
for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>();
if (node != nullptr && AnfAlgo::IsGraphKernel(node)) {
changed = TrySplit(node) || changed;
}
}
if (changed) {
EliminateTupleGetItem(func_graph);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_
#include <memory>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"
namespace mindspore {
namespace opt {
class GraphKernelSplitter : public Pass {
public:
GraphKernelSplitter() : Pass("graph_kernel_splitter") {}
~GraphKernelSplitter() override = default;
bool Run(const FuncGraphPtr &func_graph);
};
using GraphKernelSplitterPtr = std::shared_ptr<GraphKernelSplitter>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_SPLITTER_H_

View File

@ -1,560 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/pass/fuse_graph_kernel.h"
#include <memory>
#include <string>
#include <algorithm>
#include <map>
#include <set>
#include <queue>
#include <vector>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
namespace mindspore {
namespace opt {
std::vector<PrimitivePtr> get_fusable_basic_ops(bool is_before_kernel_select) {
std::vector<PrimitivePtr> fusable_basic_ops = {
prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum,
prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt,
prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual};
if (!is_before_kernel_select) {
fusable_basic_ops.push_back(prim::kPrimCast);
}
return fusable_basic_ops;
}
std::vector<PrimitivePtr> get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) {
std::vector<PrimitivePtr> fusable_basic_ops_with_reduce;
if (!is_before_kernel_select) {
fusable_basic_ops_with_reduce.push_back(prim::kPrimCast);
}
return fusable_basic_ops_with_reduce;
}
std::vector<PrimitivePtr> get_reduce_ops() {
std::vector<PrimitivePtr> reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin,
prim::kPrimReduceMax, prim::kPrimReduceAll};
return reduce_ops;
}
void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) {
MS_EXCEPTION_IF_NULL(fg);
auto reduce_ops = get_reduce_ops();
const auto &nodes = fg->nodes();
info->op_type = ELEWISE;
info->cal_step = -1;
info->reduce_op_num = 0;
for (auto node : nodes) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
info->cal_step++;
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim != nullptr) {
bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) {
return op->hash() == prim->hash() && op->name() == prim->name();
});
if (is_reudce) {
info->op_type = REDUCE;
info->reduce_op_num++;
}
}
}
}
bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) {
auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select);
auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select);
bool is_fusable = false;
if (info.op_type == REDUCE &&
(info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) {
is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
} else {
is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
return is_fusable;
}
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (!IsPrimitiveCNode(node)) {
return EXCLUDE;
}
bool is_fusable = IsFuse(info, node);
return is_fusable ? FOLLOW : EXCLUDE;
}
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info,
const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (AnfAlgo::IsGraphKernel(node)) {
auto cnode = node->cast<CNodePtr>();
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
MS_EXCEPTION_IF_NULL(fg_attr_val);
auto fg_attr = GetValue<std::string>(fg_attr_val);
if (fg_attr == kApplyMomentumOpName) {
return FOLLOW;
}
return EXCLUDE;
}
if (!IsPrimitiveCNode(node)) {
return EXCLUDE;
}
bool is_fusable = IsFuse(info, node);
return is_fusable ? FOLLOW : EXCLUDE;
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set) {
if (!check_node->isa<CNode>() || AnfAlgo::IsGraphKernel(check_node)) {
return false;
}
auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
// there is a input not in fused_op_set, but the input depends on the fused_op_set
bool has_circle = false;
for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) {
std::set<AnfNodePtr> done;
std::vector<AnfNodePtr> todos = {input};
while (!todos.empty()) {
auto node = todos.back();
todos.pop_back();
if (done.count(node) || cached_unconnected_set->count(node)) {
continue;
}
done.insert(node);
if (fused_op_set.count(node)) {
has_circle = true;
break;
}
if (node->isa<CNode>()) {
auto cnode_ptr = node->cast<CNodePtr>();
for (auto it : cnode_ptr->inputs()) {
if (it->isa<CNode>()) {
todos.push_back(it);
}
}
}
}
if (has_circle) {
return true;
}
cached_unconnected_set->insert(done.begin(), done.end());
}
}
return false;
}
bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
auto &inputs = out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
real_outs->push_back(inputs[i]);
}
return true;
}
if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) {
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out);
auto fg_out = fg->output();
if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
auto inputs = fg_out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
real_outs->push_back(inputs[i]);
}
return true;
}
}
return false;
}
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
std::set<AnfNodePtr> cached_unconnected_set;
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto include = [&fused_op_set](const AnfNodePtr &node) {
if (fused_op_set.count(node)) {
return FOLLOW;
}
return EXCLUDE;
};
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes;
if (is_backward) {
erase_nodes = DeepUsersSearch(*iter, include, mng);
} else {
erase_nodes = DeepLinkedGraphSearch(*iter, include);
}
for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node);
}
}
}
std::vector<AnfNodePtr> res;
for (auto node : fused_op) {
if (fused_op_set.count(node)) {
res.push_back(node);
}
}
return res;
}
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
if (lst->size() < 2) {
return;
}
std::vector<AnfNodePtr> res;
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
std::map<AnfNodePtr, std::set<AnfNodePtr>> ins;
std::map<AnfNodePtr, std::set<AnfNodePtr>> outs;
std::queue<AnfNodePtr> q;
for (auto node : *lst) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto input : cnode->inputs()) {
if (!node_sets.count(input)) {
continue;
}
// out_degree
outs[input].insert(node);
// in_degree
ins[node].insert(input);
}
if (!ins.count(node)) {
ins[node] = {};
}
}
for (auto p : ins) {
if (p.second.size() == 0) {
q.push(p.first);
}
}
while (!q.empty()) {
auto node = q.front();
q.pop();
res.push_back(node);
if (!outs.count(node)) {
continue;
}
for (auto out : outs[node]) {
if (!ins.count(out)) {
continue;
}
ins[out].erase(node);
if (ins[out].size() == 0) {
q.push(out);
}
}
}
lst->assign(res.begin(), res.end());
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) {
auto func_graph = cnode->func_graph();
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(cnode->input(0));
GraphKernelInfo info;
info.is_before_kernel_select = is_before_kernel_select;
GetGraphKernelInfo(graph_kernel_g, &info);
auto mng = func_graph->manager();
// Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
std::reverse(used_nodes.begin(), used_nodes.end());
// Search fusable nodes according output direction.
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1);
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes);
}
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
auto out_spec = node->abstract();
if (out_spec->isa<abstract::AbstractTuple>()) {
return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
}
return out_spec;
}
AnfNodePtr CreateNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const FuncGraphPtr &fg,
const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
bool is_before_kernel_select) {
auto func_node = NewValueNode(fg);
std::vector<AnfNodePtr> fn_inputs;
fn_inputs.push_back(func_node);
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
auto fuse_cnode = kernel_graph->NewCNode(fn_inputs);
// Set output abstract
if (outputs.size() > 1) {
std::vector<AbstractBasePtr> out_specs;
for (size_t i = 0; i < outputs.size(); ++i) {
out_specs.push_back(outputs[i]->abstract());
}
auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
fuse_cnode->set_abstract(out_spec);
} else {
fuse_cnode->set_abstract(outputs[0]->abstract());
}
// Set parameter abstract.
for (size_t i = 0; i < inputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
fg->parameters()[i]->set_abstract(input_abs);
if (is_before_kernel_select) {
fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
}
}
// Set kernel info.
if (!is_before_kernel_select) {
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
for (size_t i = 0; i < inputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
graph_input_format.push_back(input_format);
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
graph_input_type.push_back(input_type);
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
fg->parameters()[i]->set_abstract(input_abs);
}
auto new_outputs = outputs;
if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
std::vector<AnfNodePtr> real_outs;
if (IsMakeTupleOut(outputs[0], &real_outs)) {
new_outputs = real_outs;
}
}
for (size_t i = 0; i < new_outputs.size(); ++i) {
auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
graph_output_format.push_back(output_format);
graph_output_type.push_back(output_type);
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get());
}
return fuse_cnode;
}
void ReplaceNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
// single out
if (outputs.size() == 1) {
mng->Replace(outputs[0], new_fuse_cnode);
return;
}
std::vector<AnfNodePtr> fn_inputs;
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
AnfNodePtrList real_outs;
// not make tuple out, replace
if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
fn_inputs.clear();
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
fn_inputs.push_back(new_fuse_cnode);
fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx))));
auto new_out = kernel_graph->NewCNode(fn_inputs);
new_out->set_abstract(outputs[out_idx]->abstract());
mng->Replace(outputs[out_idx], new_out);
continue;
}
// the out is make tuple , modify the get_item node's value
auto users = mng->node_users()[outputs[out_idx]];
for (auto &user : users) {
auto use_node = user.first;
if (use_node->isa<CNode>() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) {
auto get_item_cnode = use_node->cast<CNodePtr>();
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(value_input);
auto value_node = value_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
int new_item_idx = SizeToInt(out_idx) + item_idx;
fn_inputs.clear();
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
fn_inputs.push_back(new_fuse_cnode);
fn_inputs.push_back(NewValueNode(new_item_idx));
auto new_out = kernel_graph->NewCNode(fn_inputs);
new_out->set_abstract(get_item_cnode->abstract());
mng->Replace(get_item_cnode, new_out);
}
}
}
}
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) {
AnfNodePtrList outs;
auto out_node = (*fg)->output();
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> output_args;
auto out_cnode = out_node->cast<CNodePtr>();
for (auto out : out_cnode->inputs()) {
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
auto inputs = out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
output_args.push_back(inputs[i]);
}
} else {
output_args.push_back(out);
}
}
if (output_args.size() != out_cnode->inputs().size()) {
auto new_out = (*fg)->NewCNode(output_args);
(*mng)->Replace(out_node, new_out);
}
for (size_t i = 1; i < output_args.size(); ++i) {
outs.push_back(output_args[i]);
}
return outs;
}
outs.push_back(out_node);
return outs;
}
AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) {
AnfNodePtrList res;
if (outs.size() <= 1) {
return outs;
}
for (auto out : outs) {
AnfNodePtrList real_outs;
if (IsMakeTupleOut(out, &real_outs)) {
res.insert(res.end(), real_outs.begin(), real_outs.end());
continue;
}
res.push_back(out);
}
return res;
}
void FuseGraphKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
auto &todos = kernel_graph->execution_order();
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = *iter;
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
continue;
}
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_attr != nullptr) {
auto fg_name = GetValue<std::string>(fg_attr);
if (graph_kernel_black_list.count(fg_name) != 0) {
continue;
}
}
auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select);
if (fuse_nodes.size() <= 1) {
continue;
}
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
// Remove nest make tuple in outs
auto expand_out = GetExpandOuts(outputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select);
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
// Inline origin graphkernel
auto cnodes = fg->GetOrderedCnodes();
for (const auto &n : cnodes) {
if (!AnfAlgo::IsGraphKernel(n)) {
continue;
}
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
AnfNodePtrList ins;
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
mng->Replace(n, out);
}
EliminateMakeTuple(&fg, &mng);
// Set graphkernel flag
auto ori_fg = GetValueNode<FuncGraphPtr>(node->input(kAnfPrimitiveIndex));
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
}
}
} // namespace opt
} // namespace mindspore

View File

@ -41,6 +41,7 @@
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "utils/base_ref_extends.h" #include "utils/base_ref_extends.h"
#include "debug/tensor_load.h" #include "debug/tensor_load.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {

View File

@ -39,6 +39,10 @@
#include "backend/optimizer/gpu/insert_format_transform_op.h" #include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h" #include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "backend/optimizer/gpu/remove_redundant_format_transform.h" #include "backend/optimizer/gpu/remove_redundant_format_transform.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "common/trans.h" #include "common/trans.h"
@ -104,6 +108,22 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
} }
void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
return;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) { void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
device::gpu::AssignGpuStream(kernel_graph); device::gpu::AssignGpuStream(kernel_graph);
@ -218,6 +238,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
SelectKernel(graph); SelectKernel(graph);
// Graph optimization relevant to device data format // Graph optimization relevant to device data format
HardwareOptimize(graph); HardwareOptimize(graph);
// Graph kernel fusion optimization
GraphKernelOptimize(graph);
// Dump .pb graph after graph optimization // Dump .pb graph after graph optimization
if (save_graphs) { if (save_graphs) {
DumpIRProto(graph, "after_opt_" + std::to_string(graph_id)); DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));

View File

@ -51,6 +51,8 @@ class GPUSession : public SessionBasic {
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph); void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

View File

@ -18,6 +18,7 @@
#include <stdlib.h> #include <stdlib.h>
#endif #endif
#include <fstream> #include <fstream>
#include <iomanip>
#include <map> #include <map>
#include <memory> #include <memory>
#include "ir/primitive.h" #include "ir/primitive.h"
@ -446,13 +447,30 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
} }
} }
std::string AddGlobalId(const std::string &filename) {
static size_t g_id = 0;
std::ostringstream s;
auto i = filename.rfind('/');
if (i == string::npos) {
s << std::setfill('0') << std::setw(4) << g_id << "_";
s << filename;
} else {
s << filename.substr(0, i + 1);
s << std::setfill('0') << std::setw(4) << g_id << "_";
s << filename.substr(i + 1);
}
++g_id;
return s.str();
}
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name) { void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name) {
if (graph == nullptr) { if (graph == nullptr) {
return; return;
} }
if (filename.size() > PATH_MAX) { auto real_filename = AddGlobalId(filename);
MS_LOG(ERROR) << "File path " << filename << " is too long."; if (real_filename.size() > PATH_MAX) {
MS_LOG(ERROR) << "File path " << real_filename << " is too long.";
return; return;
} }
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
@ -461,8 +479,8 @@ void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_fu
MS_LOG(DEBUG) << "dir " << filename << " does not exit."; MS_LOG(DEBUG) << "dir " << filename << " does not exit.";
} }
#else #else
if (nullptr == realpath(filename.c_str(), real_path)) { if (nullptr == realpath(real_filename.c_str(), real_path)) {
MS_LOG(DEBUG) << "Dir " << filename << " does not exit."; MS_LOG(DEBUG) << "Dir " << real_filename << " does not exit.";
} }
#endif #endif

View File

@ -16,9 +16,9 @@
#include "runtime/device/gpu/gpu_kernel_build.h" #include "runtime/device/gpu/gpu_kernel_build.h"
#include <string> #include <string>
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" #include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/common_utils.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h" #include "backend/session/kernel_build_client.h"
@ -56,16 +56,16 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
} }
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
if (!gpu_kernel_ptr) { if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel->fullname_with_scope() << "] failed";
} }
session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get()); session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get());
} else { } else {
auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel); auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel);
if (!gpu_kernel_ptr) { if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed"; MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel->fullname_with_scope() << "] failed";
} }
if (!gpu_kernel_ptr->Init(kernel)) { if (!gpu_kernel_ptr->Init(kernel)) {
MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed."; MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel->fullname_with_scope() << "] failed.";
} }
session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get()); session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get());
} }

View File

@ -392,9 +392,13 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
mem_manager_->FreeDeviceMemory(); mem_manager_->FreeDeviceMemory();
} }
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(bin_map); MS_EXCEPTION_IF_NULL(context_ptr);
bin_map->RemoveKernelCache(); if (!(context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG))) {
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
MS_EXCEPTION_IF_NULL(bin_map);
bin_map->RemoveKernelCache();
}
} }
void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,

View File

@ -234,6 +234,67 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
*origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format"); *origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format");
} }
} }
void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list;
std::vector<AnfNodePtr> input_list;
std::vector<AnfNodePtr> output_list;
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
// set graph kernel inputs kernel info.
for (size_t i = 0; i < input_list.size(); ++i) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
std::vector<std::string> outputs_format = {kOpFormat_DEFAULT};
std::vector<TypeId> outputs_device_type = {AnfAlgo::GetOutputInferDataType(input_list[i], 0)};
graph_input_format.push_back(kOpFormat_DEFAULT);
graph_input_type.push_back(AnfAlgo::GetOutputInferDataType(input_list[i], 0));
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
}
// set graph kernel innner nodes kernel info.
for (size_t i = 0; i < node_list.size(); ++i) {
const auto &anf_node = node_list[i];
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
SetKernelInfo(cnode, KernelType::AKG_KERNEL);
}
// set graph kernel node kernel info.
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
for (size_t i = 0; i < output_index.size(); ++i) {
auto const &output = output_index[i];
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(output.first, output.second));
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::Processor::CUDA);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
MS_EXCEPTION_IF_NULL(graph_selected_info);
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
SetTensorDeviceInfo(*graph_selected_info, kernel_node);
}
} // namespace } // namespace
void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
@ -266,7 +327,14 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s
format_transform_ = false; format_transform_ = false;
} }
void SetKernelInfo(const CNodePtr &kernel_node) { void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
if (AnfAlgo::IsGraphKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
MS_EXCEPTION_IF_NULL(func_graph);
SetGraphKernelInfo(kernel_node, func_graph);
return;
}
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type; std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
@ -291,13 +359,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
builder->SetOutputsFormat(outputs_format); builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type); builder->SetOutputsDeviceType(outputs_type);
bool result = bool result = false;
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); KernelType res_kernel_type = UNKNOWN_KERNEL_TYPE;
KernelType kernel_type = UNKNOWN_KERNEL_TYPE; if (kernel_type == UNKNOWN_KERNEL_TYPE) {
result =
kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
if (!result) { if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build());
res_kernel_type = AKG_KERNEL;
}
} else if (kernel_type == AKG_KERNEL) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AKG_KERNEL; res_kernel_type = AKG_KERNEL;
} }
if (!result) { if (!result) {
@ -314,7 +388,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
<< "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists
<< ", but get " << build_type; << ", but get " << build_type;
} }
builder->SetKernelType(kernel_type); builder->SetKernelType(res_kernel_type);
builder->SetProcessor(kernel::Processor::CUDA); builder->SetProcessor(kernel::Processor::CUDA);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
SetTensorDeviceInfo(*(builder->Build()), kernel_node); SetTensorDeviceInfo(*(builder->Build()), kernel_node);

View File

@ -26,6 +26,7 @@
#include "ir/dtype.h" #include "ir/dtype.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
namespace mindspore { namespace mindspore {
@ -59,7 +60,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
{prim::kPrimAddN->name(), {{}, {0}}}, {prim::kPrimAddN->name(), {{}, {0}}},
}; };
void SetKernelInfo(const CNodePtr &kernel_node); void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
class FormatTransformChecker { class FormatTransformChecker {
public: public:

View File

@ -194,6 +194,26 @@ constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool"; constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu"; constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kTensorAddOpName = "TensorAdd"; constexpr auto kTensorAddOpName = "TensorAdd";
constexpr auto kCastOpName = "Cast";
constexpr auto kGreaterEqualOpName = "GreaterEqual";
constexpr auto kAbsOpName = "Abs";
constexpr auto kExpOpName = "Exp";
constexpr auto kNegOpName = "Neg";
constexpr auto kMinimumOpName = "Minimum";
constexpr auto kMaximumOpName = "Maximum";
constexpr auto kMulOpName = "Mul";
constexpr auto kSubOpName = "Sub";
constexpr auto kLogOpName = "Log";
constexpr auto kPowOpName = "Pow";
constexpr auto kReciprocalOpName = "Reciprocal";
constexpr auto kEqualOpName = "Equal";
constexpr auto kLessOpName = "Less";
constexpr auto kLessEqualOpName = "LessEqual";
constexpr auto kSquareOpName = "Square";
constexpr auto kSelectOpName = "Select";
constexpr auto kReduceSumOpName = "ReduceSum";
constexpr auto kReduceMinOpName = "ReduceMin";
constexpr auto kReduceMaxOpName = "ReduceMax";
constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum"; constexpr auto kFusedWeightScaleApplyMomentum = "FusedWeightScaleApplyMomentum";
constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum"; constexpr auto kFusedScaleApplyMomentum = "FusedScaleApplyMomentum";

View File

@ -206,6 +206,11 @@ inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims"); inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
inline const PrimitivePtr kPrimAbs = std::make_shared<Primitive>("Abs");
inline const PrimitivePtr kPrimRound = std::make_shared<Primitive>("Round");
inline const PrimitivePtr kPrimExp = std::make_shared<Primitive>("Exp");
inline const PrimitivePtr kPrimLog = std::make_shared<Primitive>("Log");
inline const PrimitivePtr kPrimRsqrt = std::make_shared<Primitive>("Rsqrt");
// Statements // Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return"); inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");

View File

@ -290,7 +290,7 @@ class Parameter : public ANode {
std::string DebugString(int recursive_level = 1) const override; std::string DebugString(int recursive_level = 1) const override;
std::string name() const { return name_; } std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
std::string fullname_with_scope() override { return name(); }; std::string fullname_with_scope() override { return name(); }
bool has_default() const { return has_default_; } bool has_default() const { return has_default_; }
void set_default_param(ValuePtr param) { void set_default_param(ValuePtr param) {

View File

@ -273,7 +273,8 @@ class Lamb(Optimizer):
self.global_step = Parameter(initializer(0, [1]), name='global_step') self.global_step = Parameter(initializer(0, [1]), name='global_step')
self.assignadd = P.AssignAdd() self.assignadd = P.AssignAdd()
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.enable_graph_kernel = context.get_context("enable_graph_kernel") self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
context.get_context("device_target") == "Ascend"
def construct(self, gradients): def construct(self, gradients):
lr = self.get_lr() lr = self.get_lr()

View File

@ -13,24 +13,44 @@
# limitations under the License. # limitations under the License.
"""__init__""" """__init__"""
from .abs import _abs_akg
from .add import _add_akg
from .add_n import _addn_akg
from .cast import _cast_akg from .cast import _cast_akg
from .equal import _equal_akg from .equal import _equal_akg
from .mean import _simple_mean_akg from .exp import _exp_akg
from .mean_grad import _simple_mean_grad_akg from .expand_dims import _expand_dims_akg
from .mul import _mul_akg from .greater_equal import _greater_equal_akg
from .relu6 import _relu6_akg
from .relu6_grad import _relu6_grad_akg
from .squeeze import _squeeze_akg
from .squeeze_grad import _squeeze_grad_akg
from .tile import _tile_akg
from .hsigmoid import _hsigmoid_akg from .hsigmoid import _hsigmoid_akg
from .hsigmoid_grad import _hsigmoid_grad_akg from .hsigmoid_grad import _hsigmoid_grad_akg
from .hswish import _hswish_akg from .hswish import _hswish_akg
from .hswish_grad import _hswish_grad_akg from .hswish_grad import _hswish_grad_akg
from .sub import _sub_akg from .lessequal import _lessequal_akg
from .log import _log_akg
from .logical_and import _logical_and_akg from .logical_and import _logical_and_akg
from .logical_not import _logical_not_akg from .logical_not import _logical_not_akg
from .logical_or import _logical_or_akg from .logical_or import _logical_or_akg
from .lessequal import _lessequal_akg from .maximum import _maximum_akg
from .mean import _simple_mean_akg
from .mean_grad import _simple_mean_grad_akg
from .minimum import _minimum_akg
from .mul import _mul_akg
from .neg import _neg_akg
from .notequal import _notequal_akg from .notequal import _notequal_akg
from .greater_equal import _greater_equal_akg from .pow import _pow_akg
from .real_div import _real_div_akg
from .reciprocal import _reciprocal_akg
from .reduce_max import _reduce_max_akg
from .reduce_sum import _reduce_sum_akg
from .relu6 import _relu6_akg
from .relu6_grad import _relu6_grad_akg
from .reshape import _reshape_akg
from .round import _round_akg
from .rsqrt import _rsqrt_akg
from .sqrt import _sqrt_akg
from .squeeze import _squeeze_akg
from .squeeze_grad import _squeeze_grad_akg
from .sub import _sub_akg
from .tile import _tile_akg
# Please insert op register in lexicographical order of the filename.

View File

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Abs op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Abs") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _abs_akg():
"""Abs Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorAdd op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("TensorAdd") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _add_akg():
"""TensorAdd Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AddN op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("AddN") \
.fusion_type("ELEMWISE") \
.input(0, "inputs", "dynamic") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _addn_akg():
"""AddN Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Exp op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Exp") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _exp_akg():
"""Exp Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("ExpandDims") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "output") \
.attr("axis", "required", "int") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _expand_dims_akg():
"""ExpandDims Akg register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Log") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _log_akg():
"""Log Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Maximum op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Maximum") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _maximum_akg():
"""Maximum Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Minimum op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Minimum") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _minimum_akg():
"""Minimum Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Neg op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Neg") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _neg_akg():
"""Neg Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pow op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Pow") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _pow_akg():
"""Pow Akg register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RealDiv op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("RealDiv") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _real_div_akg():
"""RealDiv Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reciprocal op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Reciprocal") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _reciprocal_akg():
"""Reciprocal Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMax op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("ReduceMax") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.output(0, "output") \
.attr("axis", "required", "listInt") \
.attr("keep_dims", "required", "bool") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _reduce_max_akg():
"""ReduceMax Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMin op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("ReduceMin") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.output(0, "output") \
.attr("axis", "required", "listInt") \
.attr("keep_dims", "required", "bool") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _reduce_min_akg():
"""ReduceMin Akg register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceSum op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("ReduceSum") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.output(0, "output") \
.attr("axis", "required", "listInt") \
.attr("keep_dims", "required", "bool") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _reduce_sum_akg():
"""ReduceSum Akg register"""
return

View File

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Reshape") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "y") \
.attr("shape", "required", "listInt") \
.dtype_format(DT.BOOL_Default, DT.BOOL_Default) \
.dtype_format(DT.I8_Default, DT.I8_Default) \
.dtype_format(DT.I16_Default, DT.I16_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.I64_Default, DT.I64_Default) \
.dtype_format(DT.U8_Default, DT.U8_Default) \
.dtype_format(DT.U16_Default, DT.U16_Default) \
.dtype_format(DT.U32_Default, DT.U32_Default) \
.dtype_format(DT.U64_Default, DT.U64_Default) \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.F64_Default, DT.F64_Default) \
.get_op_info()
@op_info_register(op_info)
def _reshape_akg():
"""Reshape Akg register"""
return

View File

@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Round op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Round") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _round_akg():
"""Round Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rsqrt op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Rsqrt") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _rsqrt_akg():
"""Rsqrt Akg register"""
return

View File

@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sqrt op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType as DT
op_info = AkgGpuRegOp("Sqrt") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.get_op_info()
@op_info_register(op_info)
def _sqrt_akg():
"""Sqrt Akg register"""
return

View File

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
from mindspore.nn.graph_kernels import ReLU
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.relu = ReLU()
def construct(self, x, y):
sub_res = self.sub(x, y)
mul_res = self.mul(sub_res, x)
relu_res = self.relu(mul_res)
square_res = P.Square()(relu_res)
add_res = self.add(relu_res, square_res)
add1_res = self.add(add_res, add_res)
return self.add(add1_res, add1_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
input_y = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
sub_res = input_x - input_y
mul_res = sub_res * input_x
relu_res = np.maximum(mul_res, 0)
square_res = np.square(relu_res)
add_res = relu_res + square_res
add1_res = add_res + add_res
expect = add1_res + add1_res
net = Net()
result = net(Tensor(input_x), Tensor(input_y))
res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res

View File

@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations as P
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
class Net(Cell):
def __init__(self):
super(Net, self).__init__()
self.layernorm = P.LayerNorm(1, 1)
def construct(self, x, y, z):
return self.layernorm(x, y, z)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic():
input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
shape_x = [2, 3, 4, 3]
begin_norm_axis = 1
in_rank = len(shape_x)
if begin_norm_axis < 0:
norm_axis = begin_norm_axis + in_rank
else:
norm_axis = begin_norm_axis
norm_axes = tuple(range(norm_axis, in_rank))
mean = np.mean(input_x, axis=norm_axes, keepdims=True)
mean_b = np.broadcast_to(mean, shape_x)
diff = input_x - mean_b
square = np.square(diff)
smean = np.mean(square, axis=norm_axes, keepdims=True)
smean_b = np.broadcast_to(smean, shape_x)
meps = smean_b + 1e-5
logs = np.log(meps)
mul = logs * (-0.5)
rsqrt = np.exp(mul)
out = diff * rsqrt
bn = out * gamma + beta
expect = (bn, mean, smean)
net = Net()
net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
if isinstance(net_result, tuple) and len(net_result) == 3:
result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
assert res0
res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res1
res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
assert res2
else:
assert False

View File

@ -115,6 +115,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc" "../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc" "../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc"
"../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc" "../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc"
"../../../mindspore/ccsrc/backend/optimizer/graph_kernel/*.cc"
"../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc" "../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc"
"../../../mindspore/ccsrc/backend/session/ascend_session.cc" "../../../mindspore/ccsrc/backend/session/ascend_session.cc"
"../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc" "../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc"