GraphKernel supports multi-output kernels

dayschan 2020-09-29 11:35:37 +08:00
parent 5dbbcacadd
commit 7599686a72
13 changed files with 704 additions and 231 deletions

View File

@@ -14,138 +14,221 @@
# ===========================================================================
"""Cost model splitter"""
from .model import PrimLib, Graph
from .model import PrimLib, Graph, Tensor
class GraphSplitByPattern:
"""Graph split by pattern"""
"""Graph splitter"""
class Area:
"""Area"""
MODE_BASIC = 1
MODE_COMPOSITE = 2
def __init__(self, init_op):
self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = self.MODE_BASIC
def __str__(self):
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
def __repr__(self):
return str(self)
def link_input(self, area_map):
"""Link inputs"""
def get_relation(op, i):
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for r in elem_relation:
if r is not None and r > relation:
relation = r
return relation
for i, t in enumerate(self.ops[0].inputs):
if t.op is not None:
area, relation = area_map[t.op], get_relation(self.ops[0], i)
self.in_relations[area] = relation
def link_output(self):
"""Link outputs"""
for input_area, r in self.in_relations.items():
input_area.out_relations[self] = r
def fuse(self, area):
"""Fuse `area` to `self`"""
def _update_relation(relations, a, r):
relations[a] = max(r, relations[a]) if a in relations else r
def _update_pattern():
self.pattern = max(self.pattern, area.pattern, self.in_relations[area])
def _fuse_relation(self_relations, new_relations):
for a, r in new_relations.items():
if a != self:
_update_relation(self_relations, a, r)
if area in self_relations:
self_relations.pop(area)
def _redirect_relation(rels):
"""Replace `area` with `self` in relations"""
if area in rels:
r = rels.pop(area)
_update_relation(rels, self, r)
self.ops.extend(area.ops)
_update_pattern()
_fuse_relation(self.in_relations, area.in_relations)
_fuse_relation(self.out_relations, area.out_relations)
for a, _ in area.in_relations.items():
_redirect_relation(a.out_relations)
for a, _ in area.out_relations.items():
_redirect_relation(a.in_relations)
self.mode = self.MODE_COMPOSITE
def check_circle(self, to):
"""Check circle. It returns false if circle exists"""
def _reached(area, to):
for out, _ in area.out_relations.items():
if out == to or _reached(out, to):
return True
return False
for out, _ in self.out_relations.items():
if out != to and _reached(out, to):
return False
return True
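
For intuition, a minimal standalone sketch of the reachability test behind check_circle, using a toy node type (Node and can_fuse are hypothetical names; like the method above, it assumes the area graph is a DAG):

class Node:
    """Toy stand-in for Area, keeping only the successor map."""
    def __init__(self, name):
        self.name = name
        self.out_relations = {}  # {successor: relation}

def can_fuse(frm, to):
    """False if `to` stays reachable from `frm` via another successor,
    i.e. fusing the two areas would close a cycle."""
    def _reached(area):
        return any(out is to or _reached(out) for out in area.out_relations)
    return not any(_reached(out) for out in frm.out_relations if out is not to)

# Diamond a -> {b, c} -> d: a may fuse b, but fusing d into a is rejected
# because a still reaches d through b and c.
a, b, c, d = Node('a'), Node('b'), Node('c'), Node('d')
a.out_relations = {b: 1, c: 1}
b.out_relations = {d: 1}
c.out_relations = {d: 1}
assert can_fuse(a, b) and not can_fuse(a, d)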
    BROADCAST_FUSE_DEPTH = 3
REDUCE_FUSE_DEPTH = 3
def __init__(self, graph):
self.graph = graph
self.groups = []
self.op_group = {}
for op in self.graph.ops:
g = [op]
self.groups.append(g)
self.op_group[op] = g
self.ids = {}
for i, op in enumerate(graph.ops):
self.ids[op] = i
self.doms = self.post_dom(graph.ops)
_, outputs = graph.deduce_parameters()
self.outputs = set(outputs)
self.areas = []
area_map = {}
for op in graph.ops:
a = self.Area(op)
self.areas.append(a)
area_map[op] = a
for a in self.areas:
a.link_input(area_map)
for a in self.areas:
a.link_output()
def post_dom(self, ops):
"""Post dom"""
doms, i_doms = {}, {}
for i in range(len(ops) - 1, -1, -1):
op = ops[i]
doms[op] = {op}
i_dom = None
if op.output.to_ops:
suc_dom = set(doms[op.output.to_ops[0]])
for to in op.output.to_ops[1:]:
suc_dom.intersection_update(doms[to])
doms[op].update(suc_dom)
for dom in suc_dom:
if i_dom is None or self.ids[dom] < self.ids[i_dom]:
i_dom = dom
i_doms[op] = i_dom
return i_doms
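
The loop above computes immediate post-dominators: walking ops in reverse topological order, each op's dominator set is itself plus the intersection of its consumers' sets, and the immediate post-dominator is the earliest member of that intersection. A self-contained restatement on a plain successor map (post_dom and succs here are hypothetical stand-ins for the method and op.output.to_ops):

def post_dom(order, succs):
    """`order` is a topological order; succs[n] lists n's consumers."""
    ids = {n: i for i, n in enumerate(order)}
    doms, i_doms = {}, {}
    for n in reversed(order):
        doms[n] = {n}
        i_dom = None
        if succs.get(n):
            common = set(doms[succs[n][0]])
            for s in succs[n][1:]:
                common.intersection_update(doms[s])
            doms[n].update(common)
            # the immediate post-dominator is the earliest common dominator
            i_dom = min(common, key=ids.get) if common else None
        i_doms[n] = i_dom
    return i_doms

# Diamond a -> {b, c} -> d: d post-dominates everything above it.
assert post_dom(list('abcd'), {'a': ['b', 'c'], 'b': ['d'], 'c': ['d']}) == \
    {'a': 'd', 'b': 'd', 'c': 'd', 'd': None}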
def get_pattern(self, op, i):
"""Get pattern"""
pattern = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for pat in elem_relation:
if pat and pat > pattern:
pattern = pat
return pattern
def fuse(self, check_fun):
"""Fuse ops"""
def _get_path(op, dom):
path_ops, visited = [], set()
def _get_path_depth(p):
visited.add(p)
if self.op_group[p][0] == p:
path_ops.append(p)
for to in p.output.to_ops:
if to != dom and to not in visited:
_get_path_depth(to)
_get_path_depth(op)
return path_ops
changed = True
while changed:
for group in self.groups:
op = group[0]
dom = self.doms[op]
if dom is None or op.output in self.outputs:
continue
ops = _get_path(op, dom)
if check_fun(op, dom, ops):
dom_group = self.op_group[dom]
fused = []
for fop in ops:
f_group = self.op_group[fop]
for p in f_group:
self.op_group[p] = dom_group
fused.append(f_group)
dom_group += f_group
for g in fused:
self.groups.remove(g)
def fuse(self, selector):
"""Fuse areas"""
changed = False
while True:
for dominant in self.areas:
fuse_areas = selector(dominant)
if fuse_areas:
for area in fuse_areas:
changed = True
dominant.fuse(area)
self.areas.remove(area)
break
            else:
return changed
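
The for/else control flow above is easy to misread, so here is the same fixpoint loop in isolation (fuse_to_fixpoint and apply_fuse are hypothetical names): the selector proposes fusion partners for one dominant area, the loop applies them and rescans from the start, and the else branch returns only after a full scan produces no fusion:

def fuse_to_fixpoint(areas, selector, apply_fuse):
    changed = False
    while True:
        for dominant in areas:
            partners = selector(dominant)  # list of areas to fuse, or None
            if partners:
                for p in partners:
                    apply_fuse(dominant, p)
                    areas.remove(p)
                changed = True
                break  # rescan: `areas` was just mutated
        else:
            return changed  # clean pass: fixpoint reached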
def to_subgraphs(self):
"""Transform op groups to subgraphs"""
ids = {}
for i, op in enumerate(self.graph.ops):
ids[op] = i
subgraphs = []
for i, group in enumerate(self.groups):
group.sort(key=lambda op: self.ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), group))
return subgraphs
graphmodes = []
for i, area in enumerate(self.areas):
area.ops.sort(key=lambda op: ids[op])
subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops))
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
return subgraphs, graphmodes
def split(self):
"""Split graph"""
def _buddy(op, dom, path_ops):
"""Fuse buddy together"""
group = self.op_group[op]
for p in group:
# p is buddy
if p.output.buddy is not None and p.output.buddy.members[0].op not in group:
"""Split graph by pattern"""
def _elemwise_depth(dom):
if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
return None
a, r = list(dom.in_relations.items())[0]
            if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE:
return None
return [a]
def _elemwise_width(dom):
if dom.pattern > PrimLib.BROADCAST:
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom):
fused.append(a)
return fused
def _broadcast_depth(dom):
if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
return None
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
                    r != PrimLib.BROADCAST or len(a.ops) > self.BROADCAST_FUSE_DEPTH:
return None
return [a]
def _broadcast_width(dom):
if dom.pattern > PrimLib.BROADCAST:
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \
                    a.check_circle(dom) and len(a.ops) <= self.BROADCAST_FUSE_DEPTH:
fused.append(a)
return fused
def _check_reduce_exclude(dom):
# exclude large all-reduce
if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
dom.ops[0].inputs[0].get_size() > 10000:
return True
# exclude multi output
for a in dom.in_relations.keys():
if len(a.out_relations) > 1:
return True
if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
return True
# p's output is buddy
for to in p.output.to_ops:
if to.output.buddy is not None and to not in group:
return True
return False
def _injective(pattern, limit):
def _checker(op, dom, path_ops):
for p in op.output.to_ops:
if p not in self.op_group[dom]:
return False
if PrimLib.iter_type(op) in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
for i, t in enumerate(dom.inputs):
if t == op.output:
return self.get_pattern(dom, i) == pattern and len(self.op_group[op]) < limit
return False
return _checker
def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
if _check_reduce_exclude(dom):
return None
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH:
return None
return [a]
def _diamond(op, dom, path_ops):
if PrimLib.iter_type(op) not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
PrimLib.iter_type(dom) in (PrimLib.UNKNOWN, PrimLib.TRANSFORM):
return False
return len(path_ops) == 1 and op.output not in dom.inputs
self.fuse(_buddy)
self.fuse(_injective(PrimLib.ELEMWISE, 100))
self.fuse(_injective(PrimLib.BROADCAST, 6))
self.fuse(_injective(PrimLib.REDUCE, 6))
self.fuse(_diamond)
return self.to_subgraphs()
def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if _check_reduce_exclude(dom):
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \
a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH:
fused.append(a)
return fused
changed = True
while changed:
changed = self.fuse(_elemwise_depth)
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
subgraphs, graphmodes = self.to_subgraphs()
return subgraphs, graphmodes
def split(graph):
"""Split graph"""
return GraphSplitByPattern(graph).split()
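
A usage sketch, mirroring the test file removed later in this commit and the split_with_json caller below (it assumes model re-exports GraphBuilder and split, as those call sites do):

import model

gb = model.GraphBuilder()
with gb.graph_scope("main"):
    a = gb.tensor([1024, 16], "float32", name="a")
    b = gb.emit("Abs", a, 'b')
    c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
    gb.emit("Sqrt", c, 'd')
graph = gb.get()[0]

subgraphs, graph_modes = model.split(graph)  # split() now returns two lists
for g, mode in zip(subgraphs, graph_modes):
    print(mode)  # "basic" or "composite"
    print(g)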

View File

@@ -196,8 +196,7 @@ class CompositeGraph:
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
cur_fusion = None
for op in desc['op_desc']:
inputs = [self.tensors[d[0]['tensor_name']]
for d in op['input_desc'] if 'value' not in d[0]]
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc']
name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
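
The reworked comprehension above flattens one extra nesting level: each entry of input_desc is itself a list of tensor descriptors, and const inputs (those carrying a 'value') are still skipped. A toy illustration of the shape it now accepts (field values illustrative):

op = {'input_desc': [
    [{'tensor_name': 'x'}],                    # plain tensor input
    [{'tensor_name': 'axis', 'value': (1,)}],  # const input, filtered out
]}
names = [d['tensor_name'] for x in op['input_desc'] for d in x if 'value' not in d]
assert names == ['x']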
@@ -263,7 +262,7 @@ class CompositeGraph:
self.tensors[y], True)
inplace_desc = copy.deepcopy(d)
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
z_desc, out_desc = inplace_desc['input_desc'][2][0].inplace_desc['output_desc'][0]
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
z_desc['shape'] = z.shape
z_desc['data_type'] = z.dtype
z_desc['tensor_name'] = z.name

View File

@@ -26,10 +26,12 @@ def split_with_json(json_str: str):
try:
graph_desc = json.loads(json_str)
comp = model.load_composite(graph_desc)
graph_split = model.split(comp.graph)
graph_split, graph_mode = model.split(comp.graph)
is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split))
result = {"multi_graph": is_multi_graph, "graph_desc": graph_list}
result = {"multi_graph": is_multi_graph,
"graph_desc": graph_list,
"graph_mode": graph_mode}
return json.dumps(result)
except jd.JSONDecodeError:
logger.error(traceback.format_exc())
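
For reference, the JSON contract after this change looks like the following sketch (values illustrative); graph_mode is parallel to graph_desc and is read on the C++ side under kJsonKeyGraphMode:

import json

result = {
    "multi_graph": True,
    "graph_desc": ["<subgraph 0 desc>", "<subgraph 1 desc>"],  # one per subgraph
    "graph_mode": ["basic", "composite"],                      # parallel list
}
payload = json.dumps(result)
assert len(json.loads(payload)["graph_mode"]) == len(result["graph_desc"])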

View File

@@ -1,53 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""test split"""
import model
def graph_1():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", c, 'd')
gb.emit("TensorAdd", [b, d], "e")
return gb.get()[0]
def graph_2():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([1024, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("ReduceSum", c, 'd', attrs={'reduce_axis': (1,)})
gb.emit("Sqrt", d, 'e')
return gb.get()[0]
def test_split_by_pattern():
def _test(graph):
print("***************** main graph ***************")
print(graph)
subgraphs = model.split(graph)
for i, g in enumerate(subgraphs):
print('------------- subgraph {} --------------'.format(i))
print(g)
_test(graph_2())
if __name__ == '__main__':
test_split_by_pattern()

View File

@@ -485,7 +485,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
(*kernel_json)[kJsonKeyPlatform] = "AKG";
(*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
(*kernel_json)[kJsonKeyComposite] = true;
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
if (!GetIOSize(*kernel_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";

View File

@@ -37,22 +37,17 @@ namespace opt {
namespace {
bool IsBasicOp(const AnfNodePtr &node, bool is_before_kernel_select) {
#if ENABLE_D
std::vector<PrimitivePtr> fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
std::vector<PrimitivePtr> fusible_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub,
prim::kPrimExpandDims};
if (!is_before_kernel_select) {
fusable_basic_ops.push_back(prim::kPrimCast);
fusible_basic_ops.push_back(prim::kPrimCast);
}
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusable_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
prim::kPrimGreater, prim::kPrimAssign};
std::vector<PrimitivePtr> fusible_basic_ops = GetFusibleOpList();
#else
std::vector<PrimitivePtr> fusable_basic_ops;
std::vector<PrimitivePtr> fusible_basic_ops;
#endif
return std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(),
return std::any_of(fusible_basic_ops.begin(), fusible_basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

View File

@@ -49,12 +49,7 @@ bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) {
basic_ops.push_back(prim::kPrimCast);
}
#elif ENABLE_GPU
std::vector<PrimitivePtr> basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
prim::kPrimGreater, prim::kPrimAssign};
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
#else
std::vector<PrimitivePtr> basic_ops;
#endif

View File

@@ -26,8 +26,8 @@
#include "ir/func_graph_cloner.h"
#include "ir/func_graph.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#ifdef ENABLE_D
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h"
#endif
namespace mindspore {
@@ -612,36 +612,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
return new_fg;
}
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
std::vector<AnfNodePtrList> *res_graphs) {
MS_EXCEPTION_IF_NULL(res_graphs);
auto kernel_json = nlohmann::json::parse(json_desc);
if (kernel_json.find(kJsonKeyMultiGraph) == kernel_json.end() || kernel_json[kJsonKeyMultiGraph].is_null()) {
// not multi graphs.
MS_LOG(ERROR) << "Input json is not multi graph, " << json_desc;
return false;
}
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
if (graph_descs.empty()) {
MS_LOG(ERROR) << "No sub graph found, " << json_desc;
return false;
}
for (size_t i = 0; i < graph_descs.size(); ++i) {
const auto &graph_desc = graph_descs[i];
AnfNodePtrList res_graph;
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
}
res_graphs->push_back(res_graph);
}
return true;
}
std::unordered_set<PrimitivePtr> GetExpandOps() {
std::unordered_set<PrimitivePtr> expand_ops = {
prim::kPrimSquare,
@@ -664,5 +634,23 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
}
return name.str();
}
std::vector<PrimitivePtr> GetFusibleOpList() {
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
return fusible_basic_ops;
}
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
#if ENABLE_GPU
device::gpu::SetKernelInfo(cnode, kernel_type);
#endif
}
} // namespace opt
} // namespace mindspore

View File

@@ -35,6 +35,7 @@ constexpr auto kGraphKernelSplitFunc = "split_with_json";
constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
constexpr auto kJsonKeyMultiGraph = "multi_graph";
constexpr auto kJsonKeyGraphDesc = "graph_desc";
constexpr auto kJsonKeyGraphMode = "graph_mode";
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs, kernel::Processor processor);
@@ -50,10 +51,10 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
std::map<std::string, AnfNodePtr> *address_node_map = nullptr);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
bool JsonDescToAnf(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map,
std::vector<AnfNodePtrList> *res_graphs);
std::unordered_set<PrimitivePtr> GetExpandOps();
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
std::vector<PrimitivePtr> GetFusibleOpList();
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

View File

@@ -26,6 +26,7 @@
#include "pipeline/jit/parse/python_adapter.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "debug/anf_ir_dump.h"
@@ -203,7 +204,7 @@ class AreaGraph {
}
SortCNodes(main_cnodes);
cnode_group_id->swap(topo_order_); // The topo_order is not used anymore.
*cnode_group_id = std::move(topo_order_); // The topo_order is not used anymore.
return;
}
@@ -291,7 +292,7 @@ class AreaGraph {
std::vector<CNodePtr> main_cnodes_sorted;
std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
[main_cnodes](int index) { return main_cnodes->at(index); });
main_cnodes->swap(main_cnodes_sorted);
*main_cnodes = std::move(main_cnodes_sorted);
}
// Areas in this subgraph
@@ -415,6 +416,9 @@ class Splitter {
cnode->set_input(i, iter->second);
}
}
if (AnfAlgo::IsRealKernel(node)) {
ResetKernelInfo(node);
}
}
}
return output;
@@ -445,7 +449,7 @@ class Splitter {
tmp_subgraph_cnodes.push_back(new_subgraph_cnodes_[i]);
}
}
new_subgraph_cnodes_.swap(tmp_subgraph_cnodes);
new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
@@ -580,15 +584,38 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
return false;
}
// recover json to anf-ir.
split_plan_.clear();
if (!JsonDescToAnf(split_graphs_str, address_node_map, &split_plan_)) {
MS_LOG(ERROR) << "Failed to decode split graphs.";
if (!DecodeJson(split_graphs_str, address_node_map)) {
MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
return false;
}
return true;
}
virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
auto kernel_json = nlohmann::json::parse(json_desc);
kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
if (graph_modes.size() != graph_descs.size()) {
MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
return false;
}
// The info should be returned from costmodel.
need_inline_.assign(split_plan_.size(), 0);
// recover json to anfnode.
split_plan_.clear();
for (const auto &graph_desc : graph_descs) {
AnfNodePtrList res_graph;
if (!akg_kernel_json_decoder.DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
return false;
}
split_plan_.push_back(std::move(res_graph));
}
// ops to be inlined.
need_inline_.clear();
std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
[](const std::string &mode) { return mode == "basic" ? 1 : 0; });
return true;
}
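
Restated in Python for clarity (illustrative, not part of the patch): the decoder turns each subgraph's mode into an inline flag, so "basic" areas are inlined back into the main graph while "composite" areas stay as fused GraphKernel nodes:

graph_modes = ["basic", "composite", "basic"]  # as produced by the splitter
need_inline = [1 if mode == "basic" else 0 for mode in graph_modes]
assert need_inline == [1, 0, 1]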

View File

@@ -13,5 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
PYTHONPATH="$(pwd)/..:${PYTHONPATH}"
PYTHONPATH="$(pwd)/../../../../mindspore/_extends/graph_kernel:${PYTHONPATH}"
export PYTHONPATH

View File

@@ -0,0 +1,436 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Test split"""
import model
from model import model as estimate
from model import graph_split as split
def get_nodes(sp, ops):
"""Get nodes"""
if isinstance(ops[0], str):
new_ops = []
for t in ops:
for op in sp.graph.ops:
if op.output.name == t:
new_ops.append(op)
break
else:
print("ERROR: not found op: ", t)
ops = new_ops
return [sp.nodes[sp.graph.ops.index(op)] for op in ops]
def first_connected(sp, space):
for cand in space:
nodes = [sp.nodes[i] for i in cand[0]]
graphs = sp.resolve_connnected_graphs(nodes)
if len(graphs) != 1:
print("connect check faied: ", nodes)
return False
return True
def split_format(sp, cand):
names = []
for ids in cand:
ops = []
for i in ids:
ops.append(sp.graph.ops[i].output.name)
names.append(','.join(ops))
return '|'.join(names)
def graph_1():
''' ring, no succ_dep, no prev '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a = gb.tensor([10240, 16], "float32", name="a")
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", c, 'd')
gb.emit('TensorAdd', [b, d], 'e')
return gb.get()[0]
def graph_2():
''' ring, succ_dep, no prev '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([10240, 16], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", a, 'c')
d = gb.emit("Abs", b, 'd')
e = gb.emit('TensorAdd', [c, d], 'e')
gb.emit("Abs", e, 'f')
return gb.get()[0]
def graph_3():
''' no ring, 1 sibling node '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([10240, 16], "float32", name="a0")
a1 = gb.tensor([10240, 16], "float32", name="a1")
b = gb.emit("Abs", a0, 'b')
c = gb.emit("Abs", a1, 'c')
d = gb.emit("Abs", b, 'd')
e = gb.emit('TensorAdd', [c, d], 'e')
gb.emit("Abs", e, 'f')
return gb.get()[0]
def graph_4():
''' no ring, 2 sibling nodes in 1 step '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([10240, 16], "float32", name="a0")
a1 = gb.tensor([10240, 16], "float32", name="a1")
b = gb.emit("Abs", a0, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", a1, 'd')
e = gb.emit("Abs", d, 'e')
f = gb.emit('TensorAdd', [c, e], 'f')
gb.emit('Abs', f, 'g')
h = gb.emit("Abs", d, 'h')
i = gb.emit('TensorAdd', [c, h], 'i')
gb.emit("Abs", i, 'j')
return gb.get()[0]
def graph_5():
    ''' no ring, 2 sibling steps '''
gb = model.GraphBuilder()
with gb.graph_scope("main") as g:
a0 = gb.tensor([10240, 16], "float32", name="a0")
a1 = gb.tensor([10240, 16], "float32", name="a1")
a2 = gb.tensor([10240, 16], "float32", name="a2")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a1, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit('TensorAdd', [a, c], 'd')
gb.emit("Abs", d, 'e')
f = gb.emit("Abs", a2, 'f')
g = gb.emit('TensorAdd', [c, f], 'g')
gb.emit("Abs", g, 'h')
return gb.get()[0]
def graph_6():
''' no ring, tree down '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([10240, 16], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
gb.emit("Abs", b, 'd')
gb.emit("Abs", b, 'e')
c = gb.emit("Abs", a, 'c')
gb.emit("Abs", c, 'f')
gb.emit("Abs", c, 'g')
return gb.get()[0]
def graph_pat_1():
''' split by reduce '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
d = gb.emit("Sqrt", c, 'd')
gb.emit("Sqrt", d, 'f')
return gb.get()[0]
def graph_pat_2():
''' multi output '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
gb.emit("ReduceSum", b, 'e', attrs={'reduce_axis': (1,)})
return gb.get()[0]
def graph_pat_3():
''' two reduce '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
d = gb.emit("Abs", c, 'd')
gb.emit("ReduceSum", d, 'e', attrs={'reduce_axis': (1,)})
return gb.get()[0]
def graph_pat_4():
    ''' elemwise + broadcast '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1, 1024], "float32", name="a0")
a2 = gb.tensor([1014, 1024], "float32", name="a2")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
d = gb.emit("Abs", c, 'd')
e = gb.emit("Abs", d, 'e')
f = gb.emit("Abs", e, 'f')
g0 = gb.emit("Abs", a2, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
# g0 = gb.emit("Abs", g0, 'g0')
g0 = gb.emit("Abs", g0, 'g0')
g1 = gb.emit('TensorAdd', [f, g0], 'g1')
g2 = gb.emit("Abs", g1, 'g2')
g3 = gb.emit("Abs", g2, 'g3')
g4 = gb.emit("Abs", g3, 'g4')
gb.emit("Abs", g4, 'g5')
return gb.get()[0]
def graph_pat_5():
''' reduce + reshape '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
d = gb.emit("Abs", c, 'd')
e = gb.tensor([512, 2048], "float32", name="e")
gb.op("Reshape", e, [d])
return gb.get()[0]
def graph_pat_6():
    ''' diamond '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", a, 'c')
gb.emit("TensorAdd", [b, c], 'd')
gb.emit("Abs", c, 'f') # broke dimond
return gb.get()[0]
def graph_pat_7():
''' buddy of control op '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a1 = gb.tensor([1024, 1024], "float32", name="a1")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a1, 'b')
c = gb.emit("make_tuple", [a, b], 'c')
d = gb.tensor([1024, 1024], "float32", name="d")
gb.op("AddN", d, [c])
gb.emit("Abs", d, 'f')
graph = gb.get()[0]
estimate.AddControlBuddy().visit_graph(graph)
return graph
def graph_pat_8():
''' reduce + reshape '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
#c = gb.emit("Abs", b, 'b')
c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
gb.emit("TensorAdd", [b, c], 'd')
return gb.get()[0]
def graph_pat_9():
''' scalar '''
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a1 = gb.tensor([1], "float32", name="a1")
a = gb.emit("Maximum", a1, 'a')
b = gb.emit("Mul", [a, a1], 'b')
gb.emit('Mul', [b, a0], 'c')
return gb.get()[0]
def graph_mo_1():
gb = model.GraphBuilder()
with gb.graph_scope("main"):
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
gb.emit("Abs", a, 'b')
gb.emit("Abs", a, 'c')
return gb.get()[0]
def graph_mo_2():
gb = model.GraphBuilder()
with gb.graph_scope("main") as g:
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("Abs", b, 'c')
g.set_output(b, c)
return gb.get()[0]
def graph_mo_3():
''' two reduce '''
gb = model.GraphBuilder()
with gb.graph_scope("main") as g:
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("ReduceSum", b, 'c', attrs={'reduce_axis': (1,)})
g.set_output(b, c)
return gb.get()[0]
def graph_mo_4():
''' two reduce '''
gb = model.GraphBuilder()
with gb.graph_scope("main") as g:
a0 = gb.tensor([1024, 1024], "float32", name="a0")
a = gb.emit("Abs", a0, 'a')
b = gb.emit("Abs", a, 'b')
c = gb.emit("ReduceSum", a, 'c', attrs={'reduce_axis': (1,)})
g.set_output(b, c)
return gb.get()[0]
def test_binary_split():
"""Test binary split"""
def _test(graph, expected_space_size):
print("********* test on graph : {} *************".format(graph.name))
sp = split.GraphSpliter(graph)
nodes = get_nodes(sp, graph.ops)
space = sp.binary_split(nodes)
for i, s in enumerate(space):
print('{}: {}'.format(i, split_format(sp, s)))
assert len(space) == expected_space_size
assert first_connected(sp, space)
_test(graph_1(), 3)
_test(graph_2(), 7)
_test(graph_3(), 4)
_test(graph_4(), 17)
_test(graph_5(), 11)
_test(graph_6(), 24)
def test_resolve_connnected_graphs():
"""Test resolve connected graphs"""
graph = graph_5()
sp = split.GraphSpliter(graph)
n1 = get_nodes(sp, ['a', 'd', 'b', 'c'])
graphs = sp.resolve_connnected_graphs(n1)
print(graphs)
assert len(graphs) == 1
n2 = get_nodes(sp, ['a', 'd', 'e', 'f', 'g'])
graphs = sp.resolve_connnected_graphs(n2)
print(graphs)
assert len(graphs) == 2
n3 = get_nodes(sp, ['a', 'b', 'f'])
graphs = sp.resolve_connnected_graphs(n3)
print(graphs)
assert len(graphs) == 3
def test_split():
"""Test split"""
def _print_cost(name, c):
print("%s\tdma_ratio=%f, saturation=%f, mix_saturation=%f, type=%s" %
(name, c.dma_ratio(), c.saturation(), c.mix_saturation(), c.cost_type()))
def _test(graph):
print("********* test on graph : {} *************".format(graph.name))
sp = split.GraphSpliter(graph)
subgraphs = sp.split(False)
print('----- main graph -------')
print(graph)
for i, g in enumerate(subgraphs):
print(' -------- subgraph {} -------'.format(i))
print(g)
print("--------- cost ------------")
cost, _ = model.estimate(graph)
_print_cost("main graph", cost)
fc, sub_costs = model.estimate(subgraphs)
_print_cost("Subgraphs:", fc)
for i, cost in enumerate(sub_costs):
_print_cost(" |_%d:\t" % (i), cost)
_test(graph_5())
# _test(graph_4())
def test_estimate():
"""Test estimate"""
graph = graph_5()
e = estimate.Estimator(graph)
e.estimate()
print(e.iter_space)
def test_pattern_split():
"""Test pattern split"""
def _test(graph, expect_n=0):
print("************* main graph **************")
print(graph)
subgraphs = split.GraphSplitByPatternV2(graph).split()
for i, g in enumerate(subgraphs):
print(' -------- subgraph {} -------'.format(i))
print(g)
if expect_n > 0:
assert len(subgraphs) == expect_n
# _test(graph_1(), 1)
# _test(graph_pat_1(), 2)
# _test(graph_pat_2())
# _test(graph_pat_3())
# _test(graph_pat_4())
# _test(graph_pat_5())
# _test(graph_pat_6())
# _test(graph_pat_7())
# _test(graph_pat_8())
# _test(graph_pat_9())
# _test(graph_mo_1())
# _test(graph_mo_2())
# _test(graph_mo_3())
_test(graph_mo_4())
def main():
# test_binary_split()
# test_resolve_connnected_graphs()
# test_split()
# test_estimate()
test_pattern_split()
if __name__ == '__main__':
main()