forked from mindspore-Ecosystem/mindspore
!9111 【GraphKernel】Refactor the BasicOpsFusion and CompositeOpsFusion
From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
f02541b8ed
|
@ -61,11 +61,11 @@ def expand_fusedadam(expand_info):
|
|||
next_para = graph_builder.emit('Sub', [param, update_with_lr])
|
||||
|
||||
param_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
|
||||
v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
|
||||
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
|
||||
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(param_result, m_result, v_result)
|
||||
graph_scope.set_output(param_result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
|
|
|
@ -66,11 +66,11 @@ def expand_fusedadamweightdecay(expand_info):
|
|||
next_para = graph_builder.emit('Sub', [param, update_with_lr])
|
||||
|
||||
para_result = graph_builder.emit('InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
|
||||
m_result = graph_builder.emit('InplaceAssign', [m, next_m, next_m], attrs={'fake_output': True})
|
||||
v_result = graph_builder.emit('InplaceAssign', [v, next_v, next_v], attrs={'fake_output': True})
|
||||
para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
|
||||
para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
|
||||
|
||||
# set graph output.
|
||||
graph_scope.set_output(para_result, m_result, v_result)
|
||||
graph_scope.set_output(para_result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
||||
|
|
|
@ -117,8 +117,6 @@
|
|||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
|
|
|
@ -31,20 +31,98 @@
|
|||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool IsFusibleOp(const AnfNodePtr &node) {
|
||||
#if ENABLE_D
|
||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||
"LambNextMV", "LambUpdateWithLR"};
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_attr != nullptr) {
|
||||
return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (!IsPrimitiveCNode(node)) {
|
||||
return EXCLUDE;
|
||||
if (IsFusibleOp(node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (AnfAlgo::IsGraphKernel(prev_node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
}
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
||||
// The GetItem node should be fused with its real input and users.
|
||||
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
||||
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
||||
if (fused_op.empty()) return AnfNodePtrList();
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
||||
|
||||
auto mng = fused_op[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
AnfNodePtrList remove_list;
|
||||
for (auto getitem : fused_op_set) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
|
||||
|
||||
// GetItem should be fused with its real input.
|
||||
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (check_include(prev_node) == EXCLUDE) {
|
||||
remove_list.push_back(getitem);
|
||||
break;
|
||||
}
|
||||
|
||||
// GetItem should be fused with its all users.
|
||||
const auto &users = mng->node_users()[getitem];
|
||||
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
|
||||
return check_include(user.first) == EXCLUDE;
|
||||
})) {
|
||||
remove_list = DeepLinkedGraphSearch(getitem, check_include);
|
||||
break;
|
||||
}
|
||||
|
||||
// To fix the issue of getitem-index, only support to fuse the previous node with its all users.
|
||||
const auto &brothers = mng->node_users()[prev_node];
|
||||
if (std::any_of(brothers.begin(), brothers.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
|
||||
return check_include(user.first) == EXCLUDE;
|
||||
})) {
|
||||
remove_list = DeepLinkedGraphSearch(getitem, check_include);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!remove_list.empty()) {
|
||||
for (auto node : remove_list) {
|
||||
fused_op_set.erase(node);
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_fusable = IsBasicFuseOp(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
// keep the original order of fused_op.
|
||||
AnfNodePtrList result;
|
||||
for (auto node : fused_op) {
|
||||
if (fused_op_set.count(node)) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||
|
@ -53,112 +131,13 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
|||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri, false);
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri);
|
||||
}
|
||||
used_nodes = RemoveWildGetitem(used_nodes);
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
return used_nodes;
|
||||
}
|
||||
|
||||
void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users,
|
||||
std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index,
|
||||
bool *is_only_control_depend_use, AnfNodePtr *use_out) {
|
||||
for (auto &user : users) {
|
||||
auto use_node = user.first;
|
||||
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
|
||||
*is_only_control_depend_use = false;
|
||||
continue;
|
||||
}
|
||||
if (outputs_set.count(use_node) != 0) {
|
||||
*use_out = use_node;
|
||||
}
|
||||
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
|
||||
control_depend_nodes->push_back(use_node->cast<CNodePtr>());
|
||||
control_depend_use_index->push_back(user.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
AnfNodeSet outputs_set;
|
||||
for (auto out : *outputs) {
|
||||
outputs_set.insert(out);
|
||||
}
|
||||
bool has_erase_outs = false;
|
||||
int index = -1;
|
||||
for (auto it = outputs->begin(); it != outputs->end();) {
|
||||
index++;
|
||||
auto out = *it;
|
||||
(*eqv)[out] = vir_outputs[IntToSize(index)];
|
||||
auto users = mng->node_users()[out];
|
||||
bool is_only_control_depend_use = true;
|
||||
std::vector<size_t> control_depend_use_index;
|
||||
std::vector<CNodePtr> control_depend_nodes;
|
||||
AnfNodePtr use_out = nullptr;
|
||||
SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index,
|
||||
&is_only_control_depend_use, &use_out);
|
||||
if (is_only_control_depend_use && !control_depend_nodes.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(use_out);
|
||||
it = outputs->erase(it);
|
||||
for (size_t i = 0; i < control_depend_nodes.size(); ++i) {
|
||||
auto control_depend_node = control_depend_nodes[i];
|
||||
std::vector<AnfNodePtr> new_control_depend_inputs;
|
||||
for (size_t j = 0; j < control_depend_node->size(); ++j) {
|
||||
if (j == control_depend_use_index[i]) {
|
||||
new_control_depend_inputs.push_back(use_out);
|
||||
} else {
|
||||
new_control_depend_inputs.push_back(control_depend_node->input(j));
|
||||
}
|
||||
}
|
||||
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
|
||||
mng->Replace(control_depend_node, new_control_depend);
|
||||
has_erase_outs = true;
|
||||
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
|
||||
}
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
return has_erase_outs;
|
||||
}
|
||||
|
||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
AnfNodePtrList vir_outputs;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
||||
auto fg_outputs = fg->output();
|
||||
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
|
||||
auto cnode = fg_outputs->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
vir_outputs.push_back(cnode->input(i));
|
||||
}
|
||||
} else {
|
||||
vir_outputs.push_back(fg_outputs);
|
||||
}
|
||||
|
||||
if (vir_outputs.size() != outputs->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
||||
}
|
||||
|
||||
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
|
||||
return;
|
||||
}
|
||||
|
||||
AnfNodePtr fg_new_output;
|
||||
if (outputs->size() > 1) {
|
||||
std::vector<AnfNodePtr> output_args;
|
||||
output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args),
|
||||
[&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; });
|
||||
// Set output for AnfGraph
|
||||
fg_new_output = fg->NewCNode(output_args);
|
||||
} else {
|
||||
fg_new_output = eqv[(*outputs)[0]];
|
||||
}
|
||||
fg->set_output(fg_new_output, true);
|
||||
}
|
||||
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
|
||||
std::unordered_set<AnfNodePtr> *fused_ops) {
|
||||
bool changed = false;
|
||||
|
@ -170,14 +149,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = (*iter)->cast<CNodePtr>();
|
||||
if (node == nullptr) {
|
||||
if (node == nullptr || fused_ops->count(node)) {
|
||||
continue;
|
||||
}
|
||||
if (fused_ops->count(node)) {
|
||||
continue;
|
||||
}
|
||||
bool is_basic_op = IsBasicFuseOp(node);
|
||||
if (!is_basic_op || !kernel_graph->nodes().contains(node)) {
|
||||
bool is_fusible_op = IsFusibleOp(node);
|
||||
if (!is_fusible_op || !kernel_graph->nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -185,26 +161,12 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
changed = true;
|
||||
FuncGraphPtr fg;
|
||||
AnfNodePtrList inputs;
|
||||
AnfNodePtrList outputs;
|
||||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||
RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
|
||||
ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
|
||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||
|
||||
// Set graph kernel attr
|
||||
std::string fuse_op_name = "";
|
||||
for (auto &fuse_node : fuse_nodes) {
|
||||
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
|
||||
}
|
||||
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
|
||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||
AnfNodePtr fused_new_node;
|
||||
AnfNodePtrList old_outputs;
|
||||
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion");
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
|
||||
}
|
||||
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
|
||||
return changed;
|
||||
|
@ -224,6 +186,22 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph) {
|
|||
return FuseBasicOps(kernel_graph, todos, &fused_ops);
|
||||
}
|
||||
|
||||
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { return FuseBasicOps(func_graph); }
|
||||
void EliminateGetitem(const FuncGraphPtr &func_graph) {
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
bool changed = FuseBasicOps(func_graph);
|
||||
if (changed) {
|
||||
EliminateGetitem(func_graph);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,53 +50,8 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roo
|
|||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
std::vector<AnfNodePtr> users;
|
||||
for (auto &root : roots) {
|
||||
auto tmp = DeepUsersSearch(root, include, mng);
|
||||
users.insert(users.end(), tmp.begin(), tmp.end());
|
||||
}
|
||||
return users;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (AnfAlgo::IsGraphKernel(prev_node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
}
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kAnfPrimitiveIndex));
|
||||
auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
MS_EXCEPTION_IF_NULL(fg_attr_val);
|
||||
auto fg_attr = GetValue<std::string>(fg_attr_val);
|
||||
if (fg_attr == kApplyMomentumOpName) {
|
||||
return FOLLOW;
|
||||
}
|
||||
return EXCLUDE;
|
||||
}
|
||||
bool is_fusable = IsBasicFuseOp(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
}
|
||||
|
||||
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
||||
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
|
||||
|
@ -163,9 +118,8 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
|||
return !circle_nodes->empty();
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||
bool is_backward) {
|
||||
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
|
||||
std::set<AnfNodePtr> cached_unconnected_set;
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
||||
|
@ -181,13 +135,8 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
|||
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
|
||||
// delete the circle node and the node which depend on the circle node in fused op
|
||||
if (has_circle) {
|
||||
auto mng = (*iter)->func_graph()->manager();
|
||||
std::vector<AnfNodePtr> erase_nodes;
|
||||
if (is_backward) {
|
||||
erase_nodes = DeepUsersSearch(circle_nodes, include, mng);
|
||||
} else {
|
||||
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
|
||||
}
|
||||
erase_nodes = DeepLinkedGraphSearch(circle_nodes, include);
|
||||
for (auto erase_node : erase_nodes) {
|
||||
fused_op_set.erase(erase_node);
|
||||
}
|
||||
|
@ -203,60 +152,6 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
|||
return res;
|
||||
}
|
||||
|
||||
// The GetItem node should be fused with its real input and users.
|
||||
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
||||
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
||||
if (fused_op.empty()) return AnfNodePtrList();
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
||||
|
||||
auto mng = fused_op[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
AnfNodePtrList remove_list;
|
||||
for (auto node : fused_op_set) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue;
|
||||
// GetItem should be fused with its real input.
|
||||
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (check_include(prev_node) == EXCLUDE) {
|
||||
remove_list.push_back(node);
|
||||
break;
|
||||
}
|
||||
|
||||
// GetItem should be fused with its all users.
|
||||
auto &users = mng->node_users()[node];
|
||||
bool outside_user_found = false;
|
||||
for (auto iter = users.begin(); iter != users.end(); ++iter) {
|
||||
if (check_include(iter->first) == EXCLUDE) {
|
||||
outside_user_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (outside_user_found) {
|
||||
remove_list = DeepUsersSearch(node, check_include, mng);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!remove_list.empty()) {
|
||||
for (auto node : remove_list) {
|
||||
fused_op_set.erase(node);
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
// keep the original order of fused_op.
|
||||
AnfNodePtrList result;
|
||||
for (auto node : fused_op) {
|
||||
if (fused_op_set.count(node)) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
||||
if (lst->size() < 2) {
|
||||
return;
|
||||
|
@ -310,87 +205,5 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
|||
|
||||
lst->assign(res.begin(), res.end());
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||
auto func_graph = cnode->func_graph();
|
||||
auto mng = func_graph->manager();
|
||||
// Search fusable nodes according input direction.
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
std::reverse(used_nodes.begin(), used_nodes.end());
|
||||
// Search fusable nodes according output direction.
|
||||
auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, std::placeholders::_1);
|
||||
auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng);
|
||||
|
||||
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri);
|
||||
}
|
||||
used_nodes = RemoveWildGetitem(used_nodes);
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
return used_nodes;
|
||||
}
|
||||
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(kernel_graph, true);
|
||||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
auto todos = TopoSort(kernel_graph->get_return());
|
||||
std::reverse(todos.begin(), todos.end());
|
||||
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||
InitDependPrior(todos, &depend_prior);
|
||||
|
||||
bool changed = false;
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = *iter;
|
||||
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_attr != nullptr) {
|
||||
auto fg_name = GetValue<std::string>(fg_attr);
|
||||
if (graph_kernel_black_list.count(fg_name) != 0) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
|
||||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
|
||||
AnfNodePtr fused_new_node;
|
||||
AnfNodePtrList old_outputs;
|
||||
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void EliminateGetItem(const FuncGraphPtr &func_graph) {
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
||||
if (changed) {
|
||||
EliminateGetItem(func_graph);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,24 +28,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||
"LambNextMV", "LambUpdateWithLR"};
|
||||
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||
bool is_backward = true);
|
||||
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior);
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
|
||||
class CompositeOpsFusion : public Pass {
|
||||
public:
|
||||
CompositeOpsFusion() : Pass("composite_ops_fusion") {}
|
||||
~CompositeOpsFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
using FuseGraphKernelPassPtr = std::shared_ptr<CompositeOpsFusion>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||
|
|
|
@ -0,0 +1,327 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
inline size_t GetIndex(const AnfNodePtr &getitem_node) {
|
||||
MS_EXCEPTION_IF_NULL(getitem_node);
|
||||
if (!IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
|
||||
MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem but got " << getitem_node->fullname_with_scope();
|
||||
}
|
||||
return LongToSize(GetValue<int64_t>(
|
||||
getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
|
||||
}
|
||||
|
||||
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
|
||||
bool merge_repeated_getitem = false) {
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
MS_EXCEPTION_IF_NULL(getitem_list);
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto output = func_graph->output();
|
||||
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "The output should be a MakeTuple, but got " << output->fullname_with_scope();
|
||||
}
|
||||
auto output_num = output->cast<CNodePtr>()->size() - 1;
|
||||
getitem_list->clear();
|
||||
getitem_list->resize(output_num, nullptr);
|
||||
const auto &users = mng->node_users()[node];
|
||||
bool changed = false;
|
||||
AnfNodePtrList user_nodes;
|
||||
std::transform(users.begin(), users.end(), std::back_inserter(user_nodes),
|
||||
[](const std::pair<AnfNodePtr, int> &user) { return user.first; });
|
||||
for (const auto &getitem : user_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
auto idx = GetIndex(getitem);
|
||||
if (idx >= output_num) {
|
||||
MS_LOG(EXCEPTION) << "Index of GetItem is out of range of MakeTuple. getitem node: " << getitem->DebugString();
|
||||
}
|
||||
if (merge_repeated_getitem && (*getitem_list)[idx] != nullptr) {
|
||||
mng->Replace(getitem, (*getitem_list)[idx]);
|
||||
changed = true;
|
||||
} else {
|
||||
(*getitem_list)[idx] = getitem;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
AnfNodePtrList result;
|
||||
std::copy_if(todos.begin(), todos.end(), std::back_inserter(result), [](const AnfNodePtr &node) {
|
||||
return AnfAlgo::IsGraphKernel(node) &&
|
||||
IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(node)->output(), prim::kPrimMakeTuple);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Merge the get_item nodes that have same index.
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0)
|
||||
* %3 = tuple_getitem(%1, 0)
|
||||
* %4 = tuple_getitem(%1, 1)
|
||||
* %5 = user_x(%2)
|
||||
* %6 = user_y(%3)
|
||||
* %7 = user_z(%4)
|
||||
* --->
|
||||
* %1 = call @graph_kernel(p1, p2)
|
||||
* %2 = tuple_getitem(%1, 0)
|
||||
* %3 = tuple_getitem(%1, 1)
|
||||
* %4 = user_x(%2)
|
||||
* %5 = user_y(%2)
|
||||
* %6 = user_z(%3)
|
||||
*/
|
||||
class MergeRepeatedGetitem : public Pass {
|
||||
public:
|
||||
bool Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
||||
bool changed = false;
|
||||
for (auto node : todos) {
|
||||
AnfNodePtrList getitem_list;
|
||||
changed = GetGraphKernelGetitemList(mng, node, &getitem_list, true) || changed;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
};
|
||||
|
||||
/* Merge the get_item nodes that have same index.
|
||||
* subgraph graph_kernel(%para1, %para2)
|
||||
* %1 = TensorAdd(%para1, %para2)
|
||||
* %2 = Neg(%1)
|
||||
* %3 = make_tuple(%1, %2)
|
||||
* return (%3)
|
||||
* %1 = call @graph_kernel(%p1, %p2)
|
||||
* %2 = tuple_getitem(%1, 0)
|
||||
* %3 = tuple_getitem(%1, 1)
|
||||
* %4 = ControlDepend(%0, %2)
|
||||
* %5 = other_user(%3)
|
||||
* --->
|
||||
* subgraph graph_kernel(%para1, %para2)
|
||||
* %1 = TensorAdd(%para1, %para2)
|
||||
* %2 = Neg(%1)
|
||||
* %3 = make_tuple(%1, %2)
|
||||
* return (%3)
|
||||
* %1 = call @graph_kernel(%p1, %p2)
|
||||
* %3 = tuple_getitem(%1, 1)
|
||||
* %4 = ControlDepend(%0, %3)
|
||||
* %5 = other_user(%3)
|
||||
*
|
||||
* Then the output 0 can be eliminate in the later pass.
|
||||
*/
|
||||
class EliminateGetitemForControlDepend : public Pass {
|
||||
public:
|
||||
bool Run(const FuncGraphPtr &func_graph) {
|
||||
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = false;
|
||||
for (const auto &node : todos) {
|
||||
getitems_.clear();
|
||||
GetGraphKernelGetitemList(mng, node, &getitems_, false);
|
||||
if (getitems_.empty()) continue;
|
||||
indexes_.clear();
|
||||
GetIndexesToControlDepend(mng);
|
||||
FilterRedundantOutputs(node);
|
||||
if (indexes_.empty()) continue;
|
||||
size_t index = GetFinalIndex(node);
|
||||
changed = ReplaceGetitems(mng, index) || changed;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs.
|
||||
std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated.
|
||||
|
||||
bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(getitems_[index]);
|
||||
bool changed = false;
|
||||
for (auto i : indexes_) {
|
||||
if (i != index) {
|
||||
MS_EXCEPTION_IF_NULL(getitems_[i]);
|
||||
mng->Replace(getitems_[i], getitems_[index]);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Find the redundant output index.
|
||||
// the real output should have multiple users.
|
||||
void FilterRedundantOutputs(const AnfNodePtr &node) {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
auto &users = mng->node_users();
|
||||
auto maketuple = func_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(maketuple);
|
||||
std::vector<size_t> result;
|
||||
for (auto i : indexes_) {
|
||||
auto real_output = maketuple->input(i);
|
||||
if (users[real_output].size() > 1) {
|
||||
result.push_back(i);
|
||||
}
|
||||
}
|
||||
indexes_ = std::move(result);
|
||||
}
|
||||
|
||||
// Get the nodes that only have ControlDepend users.
|
||||
void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) {
|
||||
for (size_t i = 0; i < getitems_.size(); ++i) {
|
||||
const AnfNodePtr &getitem = getitems_[i];
|
||||
if (getitem == nullptr) {
|
||||
continue;
|
||||
}
|
||||
const auto &getitem_user = mng->node_users()[getitem];
|
||||
if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) {
|
||||
return IsPrimitiveCNode(user.first, prim::kPrimControlDepend);
|
||||
})) {
|
||||
indexes_.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetFinalIndex(const AnfNodePtr &node) {
|
||||
auto is_redundant_index = [this](size_t i) {
|
||||
return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end();
|
||||
};
|
||||
for (size_t i = 0; i < getitems_.size(); ++i) {
|
||||
if (getitems_[i] != nullptr && !is_redundant_index(i)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return indexes_[0];
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Remove the output without user or with virtual user (like ControlDepend)
|
||||
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph);
|
||||
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
|
||||
changed = Process(func_graph) || changed;
|
||||
return changed;
|
||||
}
|
||||
|
||||
void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) {
|
||||
if (offset == 0) return;
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
int64_t index = SizeToLong(GetIndex(getitem));
|
||||
if (offset > index) {
|
||||
MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
|
||||
}
|
||||
index -= offset;
|
||||
auto idx_node = NewValueNode(MakeValue<int64_t>(index));
|
||||
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
idx_node->set_abstract(abstract);
|
||||
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
|
||||
}
|
||||
|
||||
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_maketuple);
|
||||
AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abstract_list;
|
||||
int64_t offset = 0;
|
||||
for (size_t i = 0; i < getitems.size(); ++i) {
|
||||
if (getitems[i] == nullptr) {
|
||||
offset++;
|
||||
} else {
|
||||
new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
|
||||
abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
|
||||
UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
|
||||
}
|
||||
}
|
||||
if (offset == 0) return nullptr;
|
||||
if (new_maketuple_inputs.size() == 1) {
|
||||
MS_LOG(EXCEPTION) << "Input of MakeTuple could not be empty";
|
||||
}
|
||||
if (new_maketuple_inputs.size() == 2) {
|
||||
func_graph->set_output(new_maketuple_inputs.back());
|
||||
} else {
|
||||
auto make_tuple = func_graph->NewCNode(new_maketuple_inputs);
|
||||
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
make_tuple->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
func_graph->set_output(make_tuple);
|
||||
}
|
||||
|
||||
auto old_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_cnode);
|
||||
AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
|
||||
AnfNodePtrList outputs;
|
||||
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
||||
auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
|
||||
SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs, AnfAlgo::GetProcessor(node));
|
||||
return graph_kernel_node;
|
||||
}
|
||||
|
||||
bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
||||
bool changed = false;
|
||||
for (auto node : todos) {
|
||||
AnfNodePtrList getitems;
|
||||
GetGraphKernelGetitemList(mng, node, &getitems, false);
|
||||
auto new_node = ReplaceMakeTuple(node, getitems);
|
||||
if (new_node != nullptr) {
|
||||
if (!IsPrimitiveCNode(AnfAlgo::GetCNodeFuncGraphPtr(new_node)->output(), prim::kPrimMakeTuple)) {
|
||||
// only one output, remove the getitem.
|
||||
auto i = std::find_if(getitems.begin(), getitems.end(), [](const AnfNodePtr &node) { return node != nullptr; });
|
||||
if (i != getitems.end()) {
|
||||
mng->Replace(*i, new_node);
|
||||
}
|
||||
} else {
|
||||
mng->Replace(node, new_node);
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class EliminateRedundantOutput : public Pass {
|
||||
public:
|
||||
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
|
||||
~EliminateRedundantOutput() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
bool Process(const FuncGraphPtr &func_graph);
|
||||
void UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset);
|
||||
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
|
|
@ -531,7 +531,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
|
|||
}
|
||||
|
||||
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
const FuncGraphPtr &kernel_graph,
|
||||
const std::string &postfix) {
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
|
@ -861,23 +861,6 @@ void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
|||
}
|
||||
}
|
||||
|
||||
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
|
||||
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||
if (iter->second.second == control_depend_node) {
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
|
||||
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
|
||||
for (auto item : new_depend_prior) {
|
||||
depend_prior->insert(item);
|
||||
}
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
||||
|
|
|
@ -65,8 +65,8 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
|
|||
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
const AnfNodePtrList &outputs);
|
||||
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
const std::string &postfix);
|
||||
const FuncGraphPtr &kernel_graph,
|
||||
const std::string &postfix = "");
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
std::map<std::string, AnfNodePtr> *address_node_map);
|
||||
|
@ -79,8 +79,6 @@ bool IsBasicFuseOp(const AnfNodePtr &node);
|
|||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
|
||||
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "debug/tensor_load.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||
|
@ -822,7 +822,7 @@ void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kern
|
|||
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
||||
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
||||
|
|
|
@ -38,7 +38,7 @@
|
|||
#include "backend/optimizer/graph_kernel/add_atomic_clean_gpu.h"
|
||||
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
|
||||
|
@ -171,7 +171,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
||||
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
|
||||
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
||||
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
|
||||
|
|
Loading…
Reference in New Issue