!9111 【GraphKernel】Refactor the BasicOpsFusion and CompositeOpsFusion

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2020-12-04 14:33:28 +08:00 committed by Gitee
commit f02541b8ed
12 changed files with 489 additions and 370 deletions

View File

@ -61,11 +61,11 @@ def expand_fusedadam(expand_info):
next_para = graph_builder.emit('Sub', [param, update_with_lr])
param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
# set graph output.
graph_scope.set_output(param_result, m_result, v_result)
graph_scope.set_output(param_result)
graph = graph_builder.get()[0]
return graph

View File

@ -66,11 +66,11 @@ def expand_fusedadamweightdecay(expand_info):
next_para = graph_builder.emit('Sub', [param, update_with_lr])
para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
# set graph output.
graph_scope.set_output(para_result, m_result, v_result)
graph_scope.set_output(para_result)
graph = graph_builder.get()[0]
return graph

View File

@ -117,8 +117,6 @@
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
#include "utils/ms_context.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"

View File

@ -31,20 +31,98 @@
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
namespace opt {
namespace {
bool IsFusibleOp(const AnfNodePtr &node) {
#if ENABLE_D
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
"LambNextMV", "LambUpdateWithLR"};
if (AnfAlgo::IsGraphKernel(node)) {
auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_attr != nullptr) {
return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
}
}
#endif
return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
}
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (!IsPrimitiveCNode(node)) {
return EXCLUDE;
if (IsFusibleOp(node)) {
return FOLLOW;
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (AnfAlgo::IsGraphKernel(prev_node)) {
return FOLLOW;
}
}
return EXCLUDE;
}
// The GetItem node should be fused with its real input and users.
// If its real input is not in the fuse_list, the GetItem should be excluded.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
if (fused_op.empty()) return AnfNodePtrList();
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
auto mng = fused_op[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = true;
while (changed) {
changed = false;
AnfNodePtrList remove_list;
for (auto getitem : fused_op_set) {
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
// GetItem should be fused with its real input.
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (check_include(prev_node) == EXCLUDE) {
remove_list.push_back(getitem);
break;
}
// GetItem should be fused with its all users.
const auto &users = mng->node_users()[getitem];
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
return check_include(user.first) == EXCLUDE;
})) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
// To fix the issue of getitem-index, only support to fuse the previous node with its all users.
const auto &brothers = mng->node_users()[prev_node];
if (std::any_of(brothers.begin(), brothers.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
return check_include(user.first) == EXCLUDE;
})) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
}
if (!remove_list.empty()) {
for (auto node : remove_list) {
fused_op_set.erase(node);
}
changed = true;
}
}
bool is_fusable = IsBasicFuseOp(node);
return is_fusable ? FOLLOW : EXCLUDE;
// keep the original order of fused_op.
AnfNodePtrList result;
for (auto node : fused_op) {
if (fused_op_set.count(node)) {
result.push_back(node);
}
}
return result;
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
@ -53,112 +131,13 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes, dep_pri, false);
used_nodes = RemoveCircle(used_nodes, dep_pri);
}
used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users,
std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index,
bool *is_only_control_depend_use, AnfNodePtr *use_out) {
for (auto &user : users) {
auto use_node = user.first;
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
*is_only_control_depend_use = false;
continue;
}
if (outputs_set.count(use_node) != 0) {
*use_out = use_node;
}
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
control_depend_nodes->push_back(use_node->cast<CNodePtr>());
control_depend_use_index->push_back(user.second);
}
}
}
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodeSet outputs_set;
for (auto out : *outputs) {
outputs_set.insert(out);
}
bool has_erase_outs = false;
int index = -1;
for (auto it = outputs->begin(); it != outputs->end();) {
index++;
auto out = *it;
(*eqv)[out] = vir_outputs[IntToSize(index)];
auto users = mng->node_users()[out];
bool is_only_control_depend_use = true;
std::vector<size_t> control_depend_use_index;
std::vector<CNodePtr> control_depend_nodes;
AnfNodePtr use_out = nullptr;
SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index,
&is_only_control_depend_use, &use_out);
if (is_only_control_depend_use && !control_depend_nodes.empty()) {
MS_EXCEPTION_IF_NULL(use_out);
it = outputs->erase(it);
for (size_t i = 0; i < control_depend_nodes.size(); ++i) {
auto control_depend_node = control_depend_nodes[i];
std::vector<AnfNodePtr> new_control_depend_inputs;
for (size_t j = 0; j < control_depend_node->size(); ++j) {
if (j == control_depend_use_index[i]) {
new_control_depend_inputs.push_back(use_out);
} else {
new_control_depend_inputs.push_back(control_depend_node->input(j));
}
}
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
mng->Replace(control_depend_node, new_control_depend);
has_erase_outs = true;
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
}
} else {
it++;
}
}
return has_erase_outs;
}
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodePtrList vir_outputs;
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
auto fg_outputs = fg->output();
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
auto cnode = fg_outputs->cast<CNodePtr>();
for (size_t i = 1; i < cnode->size(); ++i) {
vir_outputs.push_back(cnode->input(i));
}
} else {
vir_outputs.push_back(fg_outputs);
}
if (vir_outputs.size() != outputs->size()) {
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
}
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
return;
}
AnfNodePtr fg_new_output;
if (outputs->size() > 1) {
std::vector<AnfNodePtr> output_args;
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
(void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args),
[&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
// Set output for AnfGraph
fg_new_output = fg->NewCNode(output_args);
} else {
fg_new_output = eqv[(*outputs)[0]];
}
fg->set_output(fg_new_output, true);
}
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
std::unordered_set<AnfNodePtr> *fused_ops) {
bool changed = false;
@ -170,14 +149,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr) {
if (node == nullptr || fused_ops->count(node)) {
continue;
}
if (fused_ops->count(node)) {
continue;
}
bool is_basic_op = IsBasicFuseOp(node);
if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
bool is_fusible_op = IsFusibleOp(node);
if (!is_fusible_op || !kernel_graph->nodes().contains(node)) {
continue;
}
@ -185,26 +161,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
if (fuse_nodes.size() <= 1) {
continue;
}
changed = true;
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
ConvertNonscalarTensorToParameter(fg, &inputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
// Set graph kernel attr
std::string fuse_op_name = "";
for (auto &fuse_node : fuse_nodes) {
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
}
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
AnfNodePtr fused_new_node;
AnfNodePtrList old_outputs;
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion");
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
}
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
return changed;
@ -224,6 +186,22 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph) {
return FuseBasicOps(kernel_graph, todos, &fused_ops);
}
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); }
void EliminateGetitem(const FuncGraphPtr &func_graph) {
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (AnfAlgo::IsGraphKernel(node)) {
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
}
}
}
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) {
bool changed = FuseBasicOps(func_graph);
if (changed) {
EliminateGetitem(func_graph);
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -50,53 +50,8 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roo
}
return inputs;
}
std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include,
const FuncGraphManagerPtr &mng) {
std::vector<AnfNodePtr> users;
for (auto &root : roots) {
auto tmp = DeepUsersSearch(root, include, mng);
users.insert(users.end(), tmp.begin(), tmp.end());
}
return users;
}
} // namespace
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) {
return FOLLOW;
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (AnfAlgo::IsGraphKernel(prev_node)) {
return FOLLOW;
}
}
return EXCLUDE;
}
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
if (AnfAlgo::IsGraphKernel(node)) {
auto cnode = node->cast<CNodePtr>();
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
MS_EXCEPTION_IF_NULL(fg_attr_val);
auto fg_attr = GetValue<std::string>(fg_attr_val);
if (fg_attr == kApplyMomentumOpName) {
return FOLLOW;
}
return EXCLUDE;
}
bool is_fusable = IsBasicFuseOp(node);
return is_fusable ? FOLLOW : EXCLUDE;
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
@ -163,9 +118,8 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
return !circle_nodes->empty();
}
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward) {
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
std::set<AnfNodePtr> cached_unconnected_set;
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto include = [&fused_op_set](const AnfNodePtr &node) {
@ -181,13 +135,8 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
std::vector<AnfNodePtr> erase_nodes;
if (is_backward) {
erase_nodes = DeepUsersSearch(circle_nodes, include, mng);
} else {
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
}
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
for (auto erase_node : erase_nodes) {
fused_op_set.erase(erase_node);
}
@ -203,60 +152,6 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
return res;
}
// The GetItem node should be fused with its real input and users.
// If its real input is not in the fuse_list, the GetItem should be excluded.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
if (fused_op.empty()) return AnfNodePtrList();
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
auto mng = fused_op[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = true;
while (changed) {
changed = false;
AnfNodePtrList remove_list;
for (auto node : fused_op_set) {
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue;
// GetItem should be fused with its real input.
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (check_include(prev_node) == EXCLUDE) {
remove_list.push_back(node);
break;
}
// GetItem should be fused with its all users.
auto &users = mng->node_users()[node];
bool outside_user_found = false;
for (auto iter = users.begin(); iter != users.end(); ++iter) {
if (check_include(iter->first) == EXCLUDE) {
outside_user_found = true;
break;
}
}
if (outside_user_found) {
remove_list = DeepUsersSearch(node, check_include, mng);
break;
}
}
if (!remove_list.empty()) {
for (auto node : remove_list) {
fused_op_set.erase(node);
}
changed = true;
}
}
// keep the original order of fused_op.
AnfNodePtrList result;
for (auto node : fused_op) {
if (fused_op_set.count(node)) {
result.push_back(node);
}
}
return result;
}
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
if (lst->size() < 2) {
return;
@ -310,87 +205,5 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
lst->assign(res.begin(), res.end());
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
auto func_graph = cnode->func_graph();
auto mng = func_graph->manager();
// Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
std::reverse(used_nodes.begin(), used_nodes.end());
// Search fusable nodes according output direction.
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1);
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes, dep_pri);
}
used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
auto todos = TopoSort(kernel_graph->get_return());
std::reverse(todos.begin(), todos.end());
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior);
bool changed = false;
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = *iter;
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
continue;
}
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_attr != nullptr) {
auto fg_name = GetValue<std::string>(fg_attr);
if (graph_kernel_black_list.count(fg_name) != 0) {
continue;
}
}
auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
if (fuse_nodes.size() <= 1) {
continue;
}
changed = true;
AnfNodePtr fused_new_node;
AnfNodePtrList old_outputs;
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
}
return changed;
}
void EliminateGetItem(const FuncGraphPtr &func_graph) {
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (AnfAlgo::IsGraphKernel(node)) {
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
}
}
}
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
if (changed) {
EliminateGetItem(func_graph);
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -28,24 +28,10 @@
namespace mindspore {
namespace opt {
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
"LambNextMV", "LambUpdateWithLR"};
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward = true);
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior);
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph);
class CompositeOpsFusion : public Pass {
public:
CompositeOpsFusion() : Pass("composite_ops_fusion") {}
~CompositeOpsFusion() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
using FuseGraphKernelPassPtr = std::shared_ptr<CompositeOpsFusion>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_

View File

@ -0,0 +1,327 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include <memory>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
namespace {
inline size_t GetIndex(const AnfNodePtr &getitem_node) {
MS_EXCEPTION_IF_NULL(getitem_node);
if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
}
return LongToSize(GetValue<int64_t>(
getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
}
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(getitem_list);
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto output = func_graph->output();
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
}
auto output_num = output->cast<CNodePtr>()->size() - 1;
getitem_list->clear();
getitem_list->resize(output_num, nullptr);
const auto &users = mng->node_users()[node];
bool changed = false;
AnfNodePtrList user_nodes;
std::transform(users.begin(), users.end(), std::back_inserter(user_nodes),
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
for (const auto &getitem : user_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
auto idx = GetIndex(getitem);
if (idx >= output_num) {
MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();
}
if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
mng->Replace(getitem, (*getitem_list)[idx]);
changed = true;
} else {
(*getitem_list)[idx] = getitem;
}
}
return changed;
}
AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
auto todos = TopoSort(func_graph->get_return());
AnfNodePtrList result;
std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
return AnfAlgo::IsGraphKernel(node) &&
IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple);
});
return result;
}
/* Merge the get_item nodes that have same index.
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 0)
* %4 = tuple_getitem(%1, 1)
* %5 = user_x(%2)
* %6 = user_y(%3)
* %7 = user_z(%4)
* --->
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = user_x(%2)
* %5 = user_y(%2)
* %6 = user_z(%3)
*/
class MergeRepeatedGetitem : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
bool changed = false;
for (auto node : todos) {
AnfNodePtrList getitem_list;
changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
}
return changed;
}
};
/* Merge the get_item nodes that have same index.
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %2)
* %5 = other_user(%3)
* --->
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %3)
* %5 = other_user(%3)
*
* Then the output 0 can be eliminate in the later pass.
*/
class EliminateGetitemForControlDepend : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
for (const auto &node : todos) {
getitems_.clear();
GetGraphKernelGetitemList(mng, node, &getitems_, false);
if (getitems_.empty()) continue;
indexes_.clear();
GetIndexesToControlDepend(mng);
FilterRedundantOutputs(node);
if (indexes_.empty()) continue;
size_t index = GetFinalIndex(node);
changed = ReplaceGetitems(mng, index) || changed;
}
return changed;
}
private:
AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs.
std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated.
bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) {
MS_EXCEPTION_IF_NULL(getitems_[index]);
bool changed = false;
for (auto i : indexes_) {
if (i != index) {
MS_EXCEPTION_IF_NULL(getitems_[i]);
mng->Replace(getitems_[i], getitems_[index]);
changed = true;
}
}
return changed;
}
// Find the redundant output index.
// the real output should have multiple users.
void FilterRedundantOutputs(const AnfNodePtr &node) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto &users = mng->node_users();
auto maketuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
std::vector<size_t> result;
for (auto i : indexes_) {
auto real_output = maketuple->input(i);
if (users[real_output].size() > 1) {
result.push_back(i);
}
}
indexes_ = std::move(result);
}
// Get the nodes that only have ControlDepend users.
void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) {
for (size_t i = 0; i < getitems_.size(); ++i) {
const AnfNodePtr &getitem = getitems_[i];
if (getitem == nullptr) {
continue;
}
const auto &getitem_user = mng->node_users()[getitem];
if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimControlDepend);
})) {
indexes_.push_back(i);
}
}
}
size_t GetFinalIndex(const AnfNodePtr &node) {
auto is_redundant_index = [this](size_t i) {
return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end();
};
for (size_t i = 0; i < getitems_.size(); ++i) {
if (getitems_[i] != nullptr && !is_redundant_index(i)) {
return i;
}
}
return indexes_[0];
}
};
} // namespace
// Remove the output without user or with virtual user (like ControlDepend)
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph);
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
changed = Process(func_graph) || changed;
return changed;
}
void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) {
if (offset == 0) return;
MS_EXCEPTION_IF_NULL(getitem);
int64_t index = SizeToLong(GetIndex(getitem));
if (offset > index) {
MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
}
index -= offset;
auto idx_node = NewValueNode(MakeValue<int64_t>(index));
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
idx_node->set_abstract(abstract);
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
}
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(old_maketuple);
AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abstract_list;
int64_t offset = 0;
for (size_t i = 0; i < getitems.size(); ++i) {
if (getitems[i] == nullptr) {
offset++;
} else {
new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
}
}
if (offset == 0) return nullptr;
if (new_maketuple_inputs.size() == 1) {
MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
}
if (new_maketuple_inputs.size() == 2) {
func_graph->set_output(new_maketuple_inputs.back());
} else {
auto make_tuple = func_graph->NewCNode(new_maketuple_inputs);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
make_tuple->set_kernel_info(std::make_shared<device::KernelInfo>());
func_graph->set_output(make_tuple);
}
auto old_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(old_cnode);
AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
return graph_kernel_node;
}
bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
bool changed = false;
for (auto node : todos) {
AnfNodePtrList getitems;
GetGraphKernelGetitemList(mng, node, &getitems, false);
auto new_node = ReplaceMakeTuple(node, getitems);
if (new_node != nullptr) {
if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) {
// only one output, remove the getitem.
auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
if (i != getitems.end()) {
mng->Replace(*i, new_node);
}
} else {
mng->Replace(node, new_node);
}
changed = true;
}
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class EliminateRedundantOutput : public Pass {
public:
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
~EliminateRedundantOutput() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
void UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset);
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_

View File

@ -531,7 +531,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
}
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph,
const FuncGraphPtr &kernel_graph,
const std::string &postfix) {
auto mng = kernel_graph->manager();
if (mng == nullptr) {
@ -861,23 +861,6 @@ void InitDependPrior(const std::vector<AnfNodePtr> &todos,
}
}
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->second.second == control_depend_node) {
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
for (auto item : new_depend_prior) {
depend_prior->insert(item);
}
}
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;

View File

@ -65,8 +65,8 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs);
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::string &postfix);
const FuncGraphPtr &kernel_graph,
const std::string &postfix = "");
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map);
@ -79,8 +79,6 @@ bool IsBasicFuseOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);

View File

@ -42,7 +42,7 @@
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
@ -822,7 +822,7 @@ void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kern
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());

View File

@ -38,7 +38,7 @@
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
@ -171,7 +171,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));