forked from mindspore-Ecosystem/mindspore

GraphKernel supports multi-output kernels

parent 5dbbcacadd
commit 7599686a72
@@ -14,138 +14,221 @@
 # ===========================================================================
 """Cost model splitter"""
 
-from .model import PrimLib, Graph
+from .model import PrimLib, Graph, Tensor
 
 
 class GraphSplitByPattern:
-    """Graph split by pattern"""
+    """Graph splitter"""
+    class Area:
+        """Area"""
+        MODE_BASIC = 1
+        MODE_COMPOSITE = 2
+
+        def __init__(self, init_op):
+            self.pattern = PrimLib.iter_type(init_op)
+            self.ops = [init_op]
+            self.in_relations = dict()   # {area1: relation1, area2: relation2, ...}
+            self.out_relations = dict()  # {area1: relation1, area2: relation2, ...}
+            self.mode = self.MODE_BASIC
+
+        def __str__(self):
+            return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
+
+        def __repr__(self):
+            return str(self)
+
+        def link_input(self, area_map):
+            """Link inputs"""
+            def get_relation(op, i):
+                relation = PrimLib.UNKNOWN
+                _, elem_relation = PrimLib.input_relation(op, i)
+                for r in elem_relation:
+                    if r is not None and r > relation:
+                        relation = r
+                return relation
+            for i, t in enumerate(self.ops[0].inputs):
+                if t.op is not None:
+                    area, relation = area_map[t.op], get_relation(self.ops[0], i)
+                    self.in_relations[area] = relation
+
+        def link_output(self):
+            """Link outputs"""
+            for input_area, r in self.in_relations.items():
+                input_area.out_relations[self] = r
+
+        def fuse(self, area):
+            """Fuse `area` to `self`"""
+            def _update_relation(relations, a, r):
+                relations[a] = max(r, relations[a]) if a in relations else r
+
+            def _update_pattern():
+                self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
+
+            def _fuse_relation(self_relations, new_relations):
+                for a, r in new_relations.items():
+                    if a != self:
+                        _update_relation(self_relations, a, r)
+                if area in self_relations:
+                    self_relations.pop(area)
+
+            def _redirect_relation(rels):
+                """Replace `area` with `self` in relations"""
+                if area in rels:
+                    r = rels.pop(area)
+                    _update_relation(rels, self, r)
+
+            self.ops.extend(area.ops)
+            _update_pattern()
+            _fuse_relation(self.in_relations, area.in_relations)
+            _fuse_relation(self.out_relations, area.out_relations)
+            for a, _ in area.in_relations.items():
+                _redirect_relation(a.out_relations)
+            for a, _ in area.out_relations.items():
+                _redirect_relation(a.in_relations)
+            self.mode = self.MODE_COMPOSITE
+
+        def check_circle(self, to):
+            """Check circle. It returns False if a circle exists"""
+            def _reached(area, to):
+                for out, _ in area.out_relations.items():
+                    if out == to or _reached(out, to):
+                        return True
+                return False
+            for out, _ in self.out_relations.items():
+                if out != to and _reached(out, to):
+                    return False
+            return True
+
+    BORADCAST_FUSE_DEPTH = 3
+    REDUCE_FUSE_DEPTH = 3
 
     def __init__(self, graph):
         self.graph = graph
-        self.groups = []
-        self.op_group = {}
-        for op in self.graph.ops:
-            g = [op]
-            self.groups.append(g)
-            self.op_group[op] = g
-        self.ids = {}
-        for i, op in enumerate(graph.ops):
-            self.ids[op] = i
-        self.doms = self.post_dom(graph.ops)
-        _, outputs = graph.deduce_parameters()
-        self.outputs = set(outputs)
-
-    def post_dom(self, ops):
-        """Post dom"""
-        doms, i_doms = {}, {}
-        for i in range(len(ops) - 1, -1, -1):
-            op = ops[i]
-            doms[op] = {op}
-            i_dom = None
-            if op.output.to_ops:
-                suc_dom = set(doms[op.output.to_ops[0]])
-                for to in op.output.to_ops[1:]:
-                    suc_dom.intersection_update(doms[to])
-                doms[op].update(suc_dom)
-                for dom in suc_dom:
-                    if i_dom is None or self.ids[dom] < self.ids[i_dom]:
-                        i_dom = dom
-            i_doms[op] = i_dom
-        return i_doms
-
-    def get_pattern(self, op, i):
-        """Get pattern"""
-        pattern = PrimLib.UNKNOWN
-        _, elem_relation = PrimLib.input_relation(op, i)
-        for pat in elem_relation:
-            if pat and pat > pattern:
-                pattern = pat
-        return pattern
-
-    def fuse(self, check_fun):
-        """Fuse ops"""
-        def _get_path(op, dom):
-            path_ops, visited = [], set()
-
-            def _get_path_depth(p):
-                visited.add(p)
-                if self.op_group[p][0] == p:
-                    path_ops.append(p)
-                for to in p.output.to_ops:
-                    if to != dom and to not in visited:
-                        _get_path_depth(to)
-            _get_path_depth(op)
-            return path_ops
-        changed = True
-        while changed:
-            changed = False
-            for group in self.groups:
-                op = group[0]
-                dom = self.doms[op]
-                if dom is None or op.output in self.outputs:
-                    continue
-                ops = _get_path(op, dom)
-                if check_fun(op, dom, ops):
-                    changed = True
-                    dom_group = self.op_group[dom]
-                    fused = []
-                    for fop in ops:
-                        f_group = self.op_group[fop]
-                        for p in f_group:
-                            self.op_group[p] = dom_group
-                        fused.append(f_group)
-                        dom_group += f_group
-                    for g in fused:
-                        self.groups.remove(g)
+        self.areas = []
+        area_map = {}
+        for op in graph.ops:
+            a = self.Area(op)
+            self.areas.append(a)
+            area_map[op] = a
+        for a in self.areas:
+            a.link_input(area_map)
+        for a in self.areas:
+            a.link_output()
+
+    def fuse(self, selector):
+        """Fuse areas"""
+        changed = False
+        while True:
+            for dominant in self.areas:
+                fuse_areas = selector(dominant)
+                if fuse_areas:
+                    for area in fuse_areas:
+                        changed = True
+                        dominant.fuse(area)
+                        self.areas.remove(area)
+                    break
+            else:
+                return changed
 
     def to_subgraphs(self):
         """Transform op groups to subgraphs"""
+        ids = {}
+        for i, op in enumerate(self.graph.ops):
+            ids[op] = i
         subgraphs = []
-        for i, group in enumerate(self.groups):
-            group.sort(key=lambda op: self.ids[op])
-            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
-        return subgraphs
+        graphmodes = []
+        for i, area in enumerate(self.areas):
+            area.ops.sort(key=lambda op: ids[op])
+            subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
+            graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
+        return subgraphs, graphmodes
 
     def split(self):
-        """Split graph"""
-        def _buddy(op, dom, path_ops):
-            """Fuse buddy together"""
-            group = self.op_group[op]
-            for p in group:
-                # p is buddy
-                if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
-                    return True
-                # p's output is buddy
-                for to in p.output.to_ops:
-                    if to.output.buddy is not None and to not in group:
-                        return True
-            return False
-
-        def _injective(pattern, limit):
-            def _checker(op, dom, path_ops):
-                for p in op.output.to_ops:
-                    if p not in self.op_group[dom]:
-                        return False
-                if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-                    for i, t in enumerate(dom.inputs):
-                        if t == op.output:
-                            return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
-                return False
-            return _checker
-
-        def _diamond(op, dom, path_ops):
-            if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
-                    PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
-                return False
-            return len(path_ops) == 1 and op.output not in dom.inputs
-        self.fuse(_buddy)
-        self.fuse(_injective(PrimLib.ELEMWISE, 100))
-        self.fuse(_injective(PrimLib.BROADCAST, 6))
-        self.fuse(_injective(PrimLib.REDUCE, 6))
-        self.fuse(_diamond)
-        return self.to_subgraphs()
+        """Split graph by pattern"""
+        def _elemwise_depth(dom):
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE:
+                return None
+            return [a]
+
+        def _elemwise_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom):
+                    fused.append(a)
+            return fused
+
+        def _broadcast_depth(dom):
+            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH:
+                return None
+            return [a]
+
+        def _broadcast_width(dom):
+            if dom.pattern > PrimLib.BROADCAST:
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \
+                        a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+
+        def _check_reduce_exclude(dom):
+            # exclude large all-reduce
+            if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
+                    dom.ops[0].inputs[0].get_size() > 10000:
+                return True
+
+            # exclude multi output
+            for a in dom.in_relations.keys():
+                if len(a.out_relations) > 1:
+                    return True
+                if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
+                    return True
+            return False
+
+        def _reduce_depth(dom):
+            if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            a, r = list(dom.in_relations.items())[0]
+            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
+                    r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH:
+                return None
+            return [a]
+
+        def _reduce_width(dom):
+            if dom.pattern != PrimLib.REDUCE:
+                return None
+            if _check_reduce_exclude(dom):
+                return None
+            fused = []
+            for a, r in dom.in_relations.items():
+                if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \
+                        a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH:
+                    fused.append(a)
+            return fused
+
+        changed = True
+        while changed:
+            changed = self.fuse(_elemwise_depth)
+            changed = self.fuse(_elemwise_width) or changed
+            changed = self.fuse(_broadcast_depth) or changed
+            changed = self.fuse(_broadcast_width) or changed
+            changed = self.fuse(_reduce_depth) or changed
+            changed = self.fuse(_reduce_width) or changed
+        subgraphs, graphmodes = self.to_subgraphs()
+        return subgraphs, graphmodes
 
 
 def split(graph):
     """Split graph"""
     return GraphSplitByPattern(graph).split()
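The Area-based splitter replaces the old op-group bookkeeping: every op starts as its own Area, `link_input`/`link_output` wire the areas into a relation graph, and `Area.fuse` absorbs a neighbor while redirecting that neighbor's remaining relations to the merged area, so the relation graph stays consistent after every fusion step. Below is a minimal sketch of just that redirection logic on a three-area chain, with a plain integer standing in for the PrimLib relation values; it is hypothetical stand-in code, not part of this commit (the real `fuse` also merges ops, pattern, and mode):

    class MiniArea:
        """Stripped-down stand-in for GraphSplitByPattern.Area: relation bookkeeping only."""

        def __init__(self, name):
            self.name = name
            self.in_relations = {}   # {area: relation}
            self.out_relations = {}  # {area: relation}

        def fuse(self, area):
            # absorb `area`'s relations, keeping the strongest relation per neighbor
            for nbr, rel in area.in_relations.items():
                if nbr is not self:
                    self.in_relations[nbr] = max(rel, self.in_relations.get(nbr, rel))
            for nbr, rel in area.out_relations.items():
                if nbr is not self:
                    self.out_relations[nbr] = max(rel, self.out_relations.get(nbr, rel))
            self.in_relations.pop(area, None)
            self.out_relations.pop(area, None)
            # redirect every neighbor of `area` to point at the merged area instead
            for nbr in list(area.in_relations) + list(area.out_relations):
                for rels in (nbr.in_relations, nbr.out_relations):
                    if area in rels:
                        rels[self] = max(rels.pop(area), rels.get(self, 0))

    ELEMWISE = 1  # stand-in for PrimLib.ELEMWISE
    a, b, c = MiniArea('a'), MiniArea('b'), MiniArea('c')  # chain: a -> b -> c
    a.out_relations[b] = ELEMWISE
    b.in_relations[a] = ELEMWISE
    b.out_relations[c] = ELEMWISE
    c.in_relations[b] = ELEMWISE
    a.fuse(b)  # merge b into a; the chain becomes a -> c
    assert c in a.out_relations and a in c.in_relations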
@@ -196,8 +196,7 @@ class CompositeGraph:
                 shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
         cur_fusion = None
         for op in desc['op_desc']:
-            inputs = [self.tensors[d[0]['tensor_name']]
-                      for d in op['input_desc'] if 'value' not in d[0]]
+            inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
             out_desc = op['output_desc']
             name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                 0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
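The rewritten comprehension matters when a single `input_desc` entry carries several tensor descriptors: the old code read only `d[0]` and silently dropped the rest, which breaks inputs produced by multi-output kernels. A small illustration on a made-up op description (hypothetical data, tensor names not from this commit; the `self.tensors` lookup is replaced by the raw names):

    op = {'input_desc': [[{'tensor_name': 'x'}],
                         [{'tensor_name': 'y'}, {'tensor_name': 'z'}]]}

    # old: assumes one descriptor per entry, so 'z' is silently dropped
    old = [d[0]['tensor_name'] for d in op['input_desc'] if 'value' not in d[0]]
    # new: flattens every descriptor in every entry
    new = [d['tensor_name'] for x in op['input_desc'] for d in x if 'value' not in d]

    print(old)  # ['x', 'y']
    print(new)  # ['x', 'y', 'z']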
@@ -263,7 +262,7 @@ class CompositeGraph:
                         self.tensors[y], True)
                 inplace_desc = copy.deepcopy(d)
                 inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
-                z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0]
+                z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
                 z_desc['shape'] = z.shape
                 z_desc['data_type'] = z.dtype
                 z_desc['tensor_name'] = z.name
@@ -26,10 +26,12 @@ def split_with_json(json_str: str):
     try:
         graph_desc = json.loads(json_str)
        comp = model.load_composite(graph_desc)
-        graph_split = model.split(comp.graph)
+        graph_split, graph_mode = model.split(comp.graph)
         is_multi_graph = len(graph_split) > 1
         graph_list = list(map(comp.dump, graph_split))
-        result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
+        result = {"multi_graph": is_multi_graph,
+                  "graph_desc": graph_list,
+                  "graph_mode": graph_mode}
         return json.dumps(result)
     except jd.JSONDecodeError:
         logger.error(traceback.format_exc())
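With `graph_mode` added, the JSON handed back to the C++ caller carries one mode string per subgraph, parallel to `graph_desc`. A sketch of the expected shape (field values are illustrative only, not taken from a real dump):

    import json

    result = {
        "multi_graph": True,
        "graph_desc": [{"op_desc": "..."}, {"op_desc": "..."}],  # one entry per subgraph
        "graph_mode": ["basic", "composite"],                    # parallel to graph_desc
    }
    payload = json.dumps(result)
    assert len(json.loads(payload)["graph_mode"]) == len(result["graph_desc"])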
@@ -1,53 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ===========================================================================
-"""test split"""
-import model
-
-
-def graph_1():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("Abs", c, 'd')
-        gb.emit("TensorAdd", [b, d], "e")
-    return gb.get()[0]
-
-
-def graph_2():
-    gb = model.GraphBuilder()
-    with gb.graph_scope("main"):
-        a = gb.tensor([1024, 16], "float32", name="a")
-        b = gb.emit("Abs", a, 'b')
-        c = gb.emit("Abs", b, 'c')
-        d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
-        gb.emit("Sqrt", d, 'e')
-    return gb.get()[0]
-
-
-def test_split_by_pattern():
-    def _test(graph):
-        print("***************** main graph ***************")
-        print(graph)
-        subgraphs = model.split(graph)
-        for i, g in enumerate(subgraphs):
-            print('------------- subgraph {} --------------'.format(i))
-            print(g)
-    _test(graph_2())
-
-
-if __name__ == '__main__':
-    test_split_by_pattern()
@@ -485,7 +485,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
   (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
   (*kernel_json)[kJsonKeyComposite] = true;
-  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
+  (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
 
   if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
     MS_LOG(ERROR) << "Cal mem size failed.";
@@ -37,22 +37,17 @@ namespace opt {
 namespace {
 bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
 #if ENABLE_D
-  std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
+  std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
                                                  prim::kPrimExpandDims};
   if (!is_before_kernel_select) {
-    fusable_basic_ops.push_back(prim::kPrimCast);
+    fusible_basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> fusable_basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
 #else
-  std::vector<PrimitivePtr> fusable_basic_ops;
+  std::vector<PrimitivePtr> fusible_basic_ops;
 #endif
-  return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
+  return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
 
@@ -49,12 +49,7 @@ bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
     basic_ops.push_back(prim::kPrimCast);
   }
 #elif ENABLE_GPU
-  std::vector<PrimitivePtr> basic_ops = {
-    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
-    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
-    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign};
+  std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
 #else
   std::vector<PrimitivePtr> basic_ops;
 #endif
@@ -26,8 +26,8 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
 #include "backend/optimizer/pass/const_input_to_attr_registry.h"
-#ifdef ENABLE_D
-#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
+#if ENABLE_GPU
+#include "runtime/device/gpu/kernel_info_setter.h"
 #endif
 
 namespace mindspore {
@@ -612,36 +612,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
   return new_fg;
 }
 
-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs) {
-  MS_EXCEPTION_IF_NULL(res_graphs);
-  auto kernel_json = nlohmann::json::parse(json_desc);
-  if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
-    // not multi graphs.
-    MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
-    return false;
-  }
-
-  kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
-  std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
-  if (graph_descs.empty()) {
-    MS_LOG(ERROR) << "No sub graph found, " << json_desc;
-    return false;
-  }
-
-  for (size_t i = 0; i < graph_descs.size(); ++i) {
-    const auto &graph_desc = graph_descs[i];
-    AnfNodePtrList res_graph;
-    if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
-      MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
-      return false;
-    }
-    res_graphs->push_back(res_graph);
-  }
-
-  return true;
-}
-
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
     prim::kPrimSquare,
@@ -664,5 +634,23 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
   }
   return name.str();
 }
+
+std::vector<PrimitivePtr> GetFusibleOpList() {
+  std::vector<PrimitivePtr> fusible_basic_ops = {
+    prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
+    prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
+    prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
+    prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+  return fusible_basic_ops;
+}
+
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+#if ENABLE_GPU
+  device::gpu::SetKernelInfo(cnode, kernel_type);
+#endif
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -35,6 +35,7 @@ constexpr auto kGraphKernelSplitFunc = "split_with_json";
 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
 constexpr auto kJsonKeyMultiGraph = "multi_graph";
 constexpr auto kJsonKeyGraphDesc = "graph_desc";
+constexpr auto kJsonKeyGraphMode = "graph_mode";
 
 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                       const AnfNodePtrList &outputs, kernel::Processor processor);
@@ -50,10 +51,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
                    std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
 bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
 FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
-bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
-                   std::vector<AnfNodePtrList> *res_graphs);
 std::unordered_set<PrimitivePtr> GetExpandOps();
 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
+std::vector<PrimitivePtr> GetFusibleOpList();
+void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
@@ -26,6 +26,7 @@
 #include "pipeline/jit/parse/python_adapter.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
 #include "debug/anf_ir_dump.h"
 
@@ -203,7 +204,7 @@ class AreaGraph {
     }
 
     SortCNodes(main_cnodes);
-    cnode_group_id->swap(topo_order_);  // The topo_order is not used anymore.
+    *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
     return;
   }
 
@@ -291,7 +292,7 @@ class AreaGraph {
     std::vector<CNodePtr> main_cnodes_sorted;
     std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
                    [main_cnodes](int index) { return main_cnodes->at(index); });
-    main_cnodes->swap(main_cnodes_sorted);
+    *main_cnodes = std::move(main_cnodes_sorted);
   }
 
   // Areas in this subgraph
@@ -415,6 +416,9 @@ class Splitter {
           cnode->set_input(i, iter->second);
         }
       }
+      if (AnfAlgo::IsRealKernel(node)) {
+        ResetKernelInfo(node);
+      }
     }
   }
   return output;
@@ -445,7 +449,7 @@ class Splitter {
         tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
       }
     }
-    new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
+    new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
 
     TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
       auto cnode = node->cast<CNodePtr>();
@@ -580,15 +584,38 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
       return false;
     }
 
-    // recover json to anf-ir.
-    split_plan_.clear();
-    if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
-      MS_LOG(ERROR) << "Failed to decode split graphs.";
+    if (!DecodeJson(split_graphs_str, address_node_map)) {
+      MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
       return false;
     }
+    return true;
+  }
 
-    // The info should be returned from costmodel.
-    need_inline_.assign(split_plan_.size(), 0);
+  virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
+    auto kernel_json = nlohmann::json::parse(json_desc);
+    kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
+    std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
+    std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
+    if (graph_modes.size() != graph_descs.size()) {
+      MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
+      return false;
+    }
+
+    // recover json to anfnode.
+    split_plan_.clear();
+    for (const auto &graph_desc : graph_descs) {
+      AnfNodePtrList res_graph;
+      if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
+        MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
+        return false;
+      }
+      split_plan_.push_back(std::move(res_graph));
+    }
+
+    // ops to be inlined.
+    need_inline_.clear();
+    std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
+                   [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
     return true;
   }
 
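`DecodeJson` then derives `need_inline_` from the per-subgraph mode strings: a subgraph the cost model labels "basic" is inlined back into the main graph instead of being kept as a composite kernel. The same mapping expressed in Python for reference (a sketch, not the shipped C++):

    graph_modes = ["basic", "composite", "basic"]  # example "graph_mode" payload
    need_inline = [1 if mode == "basic" else 0 for mode in graph_modes]
    assert need_inline == [1, 0, 1]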
@@ -13,5 +13,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
+PYTHONPATH="$(pwd)/../../../../mindspore/_extends/graph_kernel:${PYTHONPATH}"
 export PYTHONPATH
@@ -0,0 +1,436 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Test split"""
+import model
+from model import model as estimate
+from model import graph_split as split
+
+
+def get_nodes(sp, ops):
+    """Get nodes"""
+    if isinstance(ops[0], str):
+        new_ops = []
+        for t in ops:
+            for op in sp.graph.ops:
+                if op.output.name == t:
+                    new_ops.append(op)
+                    break
+            else:
+                print("ERROR: not found op: ", t)
+        ops = new_ops
+    return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
+
+
+def first_connected(sp, space):
+    for cand in space:
+        nodes = [sp.nodes[i] for i in cand[0]]
+        graphs = sp.resolve_connnected_graphs(nodes)
+        if len(graphs) != 1:
+            print("connect check failed: ", nodes)
+            return False
+    return True
+
+
+def split_format(sp, cand):
+    names = []
+    for ids in cand:
+        ops = []
+        for i in ids:
+            ops.append(sp.graph.ops[i].output.name)
+        names.append(','.join(ops))
+    return '|'.join(names)
+
+
+def graph_1():
+    ''' ring, no succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a = gb.tensor([10240, 16], "float32", name="a")
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        gb.emit('TensorAdd', [b, d], 'e')
+    return gb.get()[0]
+
+
+def graph_2():
+    ''' ring, succ_dep, no prev '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_3():
+    ''' no ring, 1 sibling node '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", a1, 'c')
+        d = gb.emit("Abs", b, 'd')
+        e = gb.emit('TensorAdd', [c, d], 'e')
+        gb.emit("Abs", e, 'f')
+    return gb.get()[0]
+
+
+def graph_4():
+    ''' no ring, 2 sibling nodes in 1 step '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        b = gb.emit("Abs", a0, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", a1, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit('TensorAdd', [c, e], 'f')
+        gb.emit('Abs', f, 'g')
+        h = gb.emit("Abs", d, 'h')
+        i = gb.emit('TensorAdd', [c, h], 'i')
+        gb.emit("Abs", i, 'j')
+    return gb.get()[0]
+
+
+def graph_5():
+    ''' no ring, 2 sibling steps '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a1 = gb.tensor([10240, 16], "float32", name="a1")
+        a2 = gb.tensor([10240, 16], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit('TensorAdd', [a, c], 'd')
+        gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", a2, 'f')
+        g = gb.emit('TensorAdd', [c, f], 'g')
+        gb.emit("Abs", g, 'h')
+    return gb.get()[0]
+
+
+def graph_6():
+    ''' no ring, tree down '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([10240, 16], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("Abs", b, 'd')
+        gb.emit("Abs", b, 'e')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("Abs", c, 'f')
+        gb.emit("Abs", c, 'g')
+    return gb.get()[0]
+
+
+def graph_pat_1():
+    ''' split by reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Sqrt", c, 'd')
+        gb.emit("Sqrt", d, 'f')
+    return gb.get()[0]
+
+
+def graph_pat_2():
+    ''' multi output '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_3():
+    ''' two reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
+    return gb.get()[0]
+
+
+def graph_pat_4():
+    ''' elemwise + broadcast '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1, 1024], "float32", name="a0")
+        a2 = gb.tensor([1014, 1024], "float32", name="a2")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        d = gb.emit("Abs", c, 'd')
+        e = gb.emit("Abs", d, 'e')
+        f = gb.emit("Abs", e, 'f')
+        g0 = gb.emit("Abs", a2, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        # g0 = gb.emit("Abs", g0, 'g0')
+        g0 = gb.emit("Abs", g0, 'g0')
+        g1 = gb.emit('TensorAdd', [f, g0], 'g1')
+        g2 = gb.emit("Abs", g1, 'g2')
+        g3 = gb.emit("Abs", g2, 'g3')
+        g4 = gb.emit("Abs", g3, 'g4')
+        gb.emit("Abs", g4, 'g5')
+    return gb.get()[0]
+
+
+def graph_pat_5():
+    ''' reduce + reshape '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        d = gb.emit("Abs", c, 'd')
+        e = gb.tensor([512, 2048], "float32", name="e")
+        gb.op("Reshape", e, [d])
+    return gb.get()[0]
+
+
+def graph_pat_6():
+    ''' diamond '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", a, 'c')
+        gb.emit("TensorAdd", [b, c], 'd')
+        gb.emit("Abs", c, 'f')  # breaks the diamond
+    return gb.get()[0]
+
+
+def graph_pat_7():
+    ''' buddy of control op '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1024, 1024], "float32", name="a1")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a1, 'b')
+        c = gb.emit("make_tuple", [a, b], 'c')
+        d = gb.tensor([1024, 1024], "float32", name="d")
+        gb.op("AddN", d, [c])
+        gb.emit("Abs", d, 'f')
+    graph = gb.get()[0]
+    estimate.AddControlBuddy().visit_graph(graph)
+    return graph
+
+
+def graph_pat_8():
+    ''' reduce + elemwise '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        # c = gb.emit("Abs", b, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        gb.emit("TensorAdd", [b, c], 'd')
+    return gb.get()[0]
+
+
+def graph_pat_9():
+    ''' scalar '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a1 = gb.tensor([1], "float32", name="a1")
+        a = gb.emit("Maximum", a1, 'a')
+        b = gb.emit("Mul", [a, a1], 'b')
+        gb.emit('Mul', [b, a0], 'c')
+    return gb.get()[0]
+
+
+def graph_mo_1():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main"):
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        gb.emit("Abs", a, 'b')
+        gb.emit("Abs", a, 'c')
+    return gb.get()[0]
+
+
+def graph_mo_2():
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("Abs", b, 'c')
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_3():
+    ''' elemwise output feeding a reduce '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def graph_mo_4():
+    ''' parallel elemwise and reduce outputs '''
+    gb = model.GraphBuilder()
+    with gb.graph_scope("main") as g:
+        a0 = gb.tensor([1024, 1024], "float32", name="a0")
+        a = gb.emit("Abs", a0, 'a')
+        b = gb.emit("Abs", a, 'b')
+        c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
+        g.set_output(b, c)
+    return gb.get()[0]
+
+
+def test_binary_split():
+    """Test binary split"""
+    def _test(graph, expected_space_size):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        nodes = get_nodes(sp, graph.ops)
+        space = sp.binary_split(nodes)
+        for i, s in enumerate(space):
+            print('{}: {}'.format(i, split_format(sp, s)))
+        assert len(space) == expected_space_size
+        assert first_connected(sp, space)
+    _test(graph_1(), 3)
+    _test(graph_2(), 7)
+    _test(graph_3(), 4)
+    _test(graph_4(), 17)
+    _test(graph_5(), 11)
+    _test(graph_6(), 24)
+
+
+def test_resolve_connnected_graphs():
+    """Test resolve connected graphs"""
+    graph = graph_5()
+    sp = split.GraphSpliter(graph)
+    n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
+    graphs = sp.resolve_connnected_graphs(n1)
+    print(graphs)
+    assert len(graphs) == 1
+    n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
+    graphs = sp.resolve_connnected_graphs(n2)
+    print(graphs)
+    assert len(graphs) == 2
+    n3 = get_nodes(sp, ['a', 'b', 'f'])
+    graphs = sp.resolve_connnected_graphs(n3)
+    print(graphs)
+    assert len(graphs) == 3
+
+
+def test_split():
+    """Test split"""
+    def _print_cost(name, c):
+        print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
+              (name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
+
+    def _test(graph):
+        print("********* test on graph : {} *************".format(graph.name))
+        sp = split.GraphSpliter(graph)
+        subgraphs = sp.split(False)
+        print('----- main graph -------')
+        print(graph)
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        print("--------- cost ------------")
+        cost, _ = model.estimate(graph)
+        _print_cost("main graph", cost)
+        fc, sub_costs = model.estimate(subgraphs)
+        _print_cost("Subgraphs:", fc)
+        for i, cost in enumerate(sub_costs):
+            _print_cost(" |_%d:\t" % (i), cost)
+    _test(graph_5())
+    # _test(graph_4())
+
+
+def test_estimate():
+    """Test estimate"""
+    graph = graph_5()
+    e = estimate.Estimator(graph)
+    e.estimate()
+    print(e.iter_space)
+
+
+def test_pattern_split():
+    """Test pattern split"""
+    def _test(graph, expect_n=0):
+        print("************* main graph **************")
+        print(graph)
+        subgraphs = split.GraphSplitByPatternV2(graph).split()
+        for i, g in enumerate(subgraphs):
+            print(' -------- subgraph {} -------'.format(i))
+            print(g)
+        if expect_n > 0:
+            assert len(subgraphs) == expect_n
+
+    # _test(graph_1(), 1)
+    # _test(graph_pat_1(), 2)
+    # _test(graph_pat_2())
+    # _test(graph_pat_3())
+    # _test(graph_pat_4())
+    # _test(graph_pat_5())
+    # _test(graph_pat_6())
+    # _test(graph_pat_7())
+    # _test(graph_pat_8())
+    # _test(graph_pat_9())
+
+    # _test(graph_mo_1())
+    # _test(graph_mo_2())
+    # _test(graph_mo_3())
+    _test(graph_mo_4())
+
+
+def main():
+    # test_binary_split()
+    # test_resolve_connnected_graphs()
+    # test_split()
+    # test_estimate()
+    test_pattern_split()
+
+
+if __name__ == '__main__':
+    main()
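The `graph_mo_*` cases above are the new multi-output scenarios: `g.set_output(b, c)` marks two tensors as graph outputs, and the splitter must keep the producers of every output visible. A usage sketch in the same style, assuming the `model` package from this commit is on PYTHONPATH (the helper name is ours, not part of the test file):

    def run_multi_output_split():
        """Split a multi-output graph and show the per-subgraph modes."""
        graph = graph_mo_4()
        subgraphs, graph_modes = model.split(graph)  # split now returns modes too
        for i, (g, mode) in enumerate(zip(subgraphs, graph_modes)):
            print('------- subgraph {} ({}) -------'.format(i, mode))
            print(g)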