consider controldepend edges in checkcircle

This commit is contained in:
lingyunli63 2020-11-19 20:22:56 +08:00
parent 534cc9bbe9
commit e6a5fc0739
8 changed files with 264 additions and 86 deletions

View File

@ -47,12 +47,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
return is_fusable ? FOLLOW : EXCLUDE; return is_fusable ? FOLLOW : EXCLUDE;
} }
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) { std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
// Search fusable nodes according input direction. // Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1); auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
if (used_nodes.size() > 1) { if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes, false); used_nodes = RemoveCircle(used_nodes, dep_pri, false);
} }
TopoSortForNodeList(&used_nodes); TopoSortForNodeList(&used_nodes);
return used_nodes; return used_nodes;
@ -78,7 +79,8 @@ void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &u
} }
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng, bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) { std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodeSet outputs_set; AnfNodeSet outputs_set;
for (auto out : *outputs) { for (auto out : *outputs) {
outputs_set.insert(out); outputs_set.insert(out);
@ -112,6 +114,7 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
mng->Replace(control_depend_node, new_control_depend); mng->Replace(control_depend_node, new_control_depend);
has_erase_outs = true; has_erase_outs = true;
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
} }
} else { } else {
it++; it++;
@ -120,7 +123,8 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
return has_erase_outs; return has_erase_outs;
} }
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodePtrList vir_outputs; AnfNodePtrList vir_outputs;
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv; std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
auto fg_outputs = fg->output(); auto fg_outputs = fg->output();
@ -137,7 +141,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
} }
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) { if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
return; return;
} }
@ -159,6 +163,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
std::unordered_set<AnfNodePtr> *fused_ops) { std::unordered_set<AnfNodePtr> *fused_ops) {
bool changed = false; bool changed = false;
auto mng = kernel_graph->manager(); auto mng = kernel_graph->manager();
// depend_prior[depend] = pair(prior, controlDependNode)
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior);
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>(); auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr) { if (node == nullptr) {
@ -172,7 +181,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
continue; continue;
} }
auto fuse_nodes = FindFuseCNodes(node); auto fuse_nodes = FindFuseCNodes(node, depend_prior);
if (fuse_nodes.size() <= 1) { if (fuse_nodes.size() <= 1) {
continue; continue;
} }
@ -182,11 +191,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
AnfNodePtrList inputs; AnfNodePtrList inputs;
AnfNodePtrList outputs; AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
RemoveControlDependOut(fg, &outputs, mng); RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
ConvertNonscalarTensorToParameter(fg, &inputs); ConvertNonscalarTensorToParameter(fg, &inputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs); auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0])); SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
// Set graph kernel attr // Set graph kernel attr

View File

@ -15,13 +15,14 @@
*/ */
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h" #include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include <memory>
#include <string>
#include <algorithm> #include <algorithm>
#include <unordered_set>
#include <map> #include <map>
#include <set> #include <memory>
#include <queue> #include <queue>
#include <string>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
@ -97,15 +98,29 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNod
} }
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node, bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) { std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) { if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false; return false;
} }
circle_nodes->clear(); circle_nodes->clear();
auto InputEdges = [&depend_prior](CNodePtr cnode) {
std::set<AnfNodePtr> edges;
auto range = depend_prior.equal_range(cnode);
for (auto iter = range.first; iter != range.second; ++iter) {
edges.insert(iter->second.first);
}
auto inputs = cnode->inputs();
for (auto input : inputs) {
edges.insert(input);
}
return edges;
};
std::set<AnfNodePtr> cached_done_set; std::set<AnfNodePtr> cached_done_set;
auto cnode = check_node->cast<CNodePtr>(); auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs(); const auto &inputs = InputEdges(cnode);
// there is a input not in fused_op_set, but the input depends on the fused_op_set // there is a input not in fused_op_set, but the input depends on the fused_op_set
for (auto input : inputs) { for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) { if (input->isa<CNode>() && !fused_op_set.count(input)) {
@ -128,7 +143,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode_ptr = node->cast<CNodePtr>(); auto cnode_ptr = node->cast<CNodePtr>();
for (auto it : cnode_ptr->inputs()) { for (auto it : InputEdges(cnode_ptr)) {
if (it->isa<CNode>()) { if (it->isa<CNode>()) {
todos.push_back(it); todos.push_back(it);
} }
@ -148,7 +163,9 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
return !circle_nodes->empty(); return !circle_nodes->empty();
} }
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) { std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward) {
std::set<AnfNodePtr> cached_unconnected_set; std::set<AnfNodePtr> cached_unconnected_set;
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end()); std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto include = [&fused_op_set](const AnfNodePtr &node) { auto include = [&fused_op_set](const AnfNodePtr &node) {
@ -161,7 +178,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
std::vector<AnfNodePtr> circle_nodes; std::vector<AnfNodePtr> circle_nodes;
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
circle_nodes.clear(); circle_nodes.clear();
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
// delete the circle node and the node which depend on the circle node in fused op // delete the circle node and the node which depend on the circle node in fused op
if (has_circle) { if (has_circle) {
auto mng = (*iter)->func_graph()->manager(); auto mng = (*iter)->func_graph()->manager();
@ -294,7 +311,8 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
lst->assign(res.begin(), res.end()); lst->assign(res.begin(), res.end());
} }
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) { std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
auto func_graph = cnode->func_graph(); auto func_graph = cnode->func_graph();
auto mng = func_graph->manager(); auto mng = func_graph->manager();
// Search fusable nodes according input direction. // Search fusable nodes according input direction.
@ -307,7 +325,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
if (used_nodes.size() > 1) { if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes); used_nodes = RemoveCircle(used_nodes, dep_pri);
} }
used_nodes = RemoveWildGetitem(used_nodes); used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes); TopoSortForNodeList(&used_nodes);
@ -316,8 +334,18 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) { bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
auto todos = TopoSort(kernel_graph->get_return());
std::reverse(todos.begin(), todos.end());
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior);
bool changed = false; bool changed = false;
auto &todos = kernel_graph->execution_order();
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = *iter; auto node = *iter;
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
@ -333,13 +361,16 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
} }
} }
auto fuse_nodes = FindFuseCNodes(node); auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
if (fuse_nodes.size() <= 1) { if (fuse_nodes.size() <= 1) {
continue; continue;
} }
changed = true; changed = true;
FuseNodesToSubGraph(fuse_nodes, kernel_graph, ""); AnfNodePtr fused_new_node;
AnfNodePtrList old_outputs;
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
} }
return changed; return changed;
} }

View File

@ -16,11 +16,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#include <limits>
#include <map>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include <memory>
#include <limits>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
@ -29,7 +31,9 @@ namespace opt {
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
"LambNextMV", "LambUpdateWithLR"}; "LambNextMV", "LambUpdateWithLR"};
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward = true); std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward = true);
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst); void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);

View File

@ -14,20 +14,24 @@
* limitations under the License. * limitations under the License.
*/ */
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include <map> #include <map>
#include <set>
#include <tuple> #include <tuple>
#include <unordered_set> #include <unordered_set>
#include "pipeline/jit/parse/python_adapter.h" #include <utility>
#include "pipeline/jit/action.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h" #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include "vm/segment_runner.h"
#if ENABLE_GPU #if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/gpu/kernel_info_setter.h"
#endif #endif
@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
} }
} }
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) { const std::shared_ptr<session::KernelGraph> &kernel_graph,
if (fuse_nodes.empty()) { const std::string &postfix) {
return;
}
auto mng = kernel_graph->manager(); auto mng = kernel_graph->manager();
if (mng == nullptr) { if (mng == nullptr) {
mng = Manage(kernel_graph, true); mng = Manage(kernel_graph, true);
@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
} }
fuse_op_name += postfix; fuse_op_name += postfix;
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
return std::make_tuple(fuse_new_node, src_outputs);
} }
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
@ -737,7 +740,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
prim::kPrimTranspose, prim::kPrimCast}; prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv};
#elif ENABLE_GPU #elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = { std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
device::gpu::SetKernelInfo(cnode, kernel_type); device::gpu::SetKernelInfo(cnode, kernel_type);
#endif #endif
} }
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto cnode = (*iter)->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
continue;
}
auto prior_node = cnode->input(kControlDependPriorIndex);
auto depend_node = cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
std::vector<AnfNodePtr> prior_nodes = {prior_node};
std::vector<AnfNodePtr> depend_nodes = {depend_node};
int depend_mode = 0;
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
}
auto GetOutputNodes = [cnode](const AnfNodePtr &param) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> out_nodes;
auto user_set = param->func_graph()->manager()->node_users()[param];
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
if (iter->first != cnode) {
out_nodes.push_back(iter->first);
}
}
return out_nodes;
};
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
if (depend_node->isa<Parameter>()) {
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
}
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
prior_visited.clear();
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
depend_visited.clear();
for (auto &prior : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
continue;
}
for (auto &depend : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
continue;
}
depend_prior->insert({depend, std::make_pair(prior, cnode)});
}
}
real_prior_nodes.clear();
real_depend_nodes.clear();
}
}
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->second.second == control_depend_node) {
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
for (auto item : new_depend_prior) {
depend_prior->insert(item);
}
}
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "Need real outputs of makeTuple";
}
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
continue;
}
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
iter = depend_prior->erase(iter);
continue;
}
if (iter->second.first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
}
for (auto item : new_fuse_cnode_dep_pri) {
depend_prior->insert(item);
}
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -15,17 +15,20 @@
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#include <string>
#include <vector>
#include <memory>
#include <map> #include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple> #include <tuple>
#include <unordered_set> #include <unordered_set>
#include <nlohmann/json.hpp> #include <utility>
#include <vector>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h" #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <nlohmann/json.hpp>
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -48,8 +51,9 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
const AnfNodePtrList &outputs); const AnfNodePtrList &outputs);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode, void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs); const AnfNodePtrList &outputs);
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes, std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix); const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::string &postfix);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc); bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc, bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map); std::map<std::string, AnfNodePtr> *address_node_map);
@ -60,6 +64,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList(); std::vector<PrimitivePtr> GetFusibleOpList();
bool IsBasicFuseOp(const AnfNodePtr &node); bool IsBasicFuseOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

View File

@ -1417,5 +1417,46 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
} }
return device_shape; return device_shape;
} }
void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore

View File

@ -232,6 +232,9 @@ class AnfRuntimeAlgorithm {
static bool IsNodeDynamicShape(const AnfNodePtr &node); static bool IsNodeDynamicShape(const AnfNodePtr &node);
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
// Find control_depend real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -725,47 +725,6 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes; return output_nodes;
} }
// Find control_depend real input nodes.
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(visited);
if (visited->find(anf_node) != visited->end()) {
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
if (!anf_node->isa<CNode>()) {
return;
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
}
auto input0 = cnode->input(0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
GetAllFatherRealNode(cnode->input(i), result, visited);
}
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
if (cnode->inputs().size() != kDependInputSize) {
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
}
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
}
}
// update the depend relations of control depend // update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) { void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) { for (const auto &node : depends) {
@ -800,12 +759,12 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
std::vector<AnfNodePtr> real_prior_nodes; std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited; std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) { for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
} }
std::vector<AnfNodePtr> real_depend_nodes; std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited; std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) { for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
} }
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes); UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
} }