forked from mindspore-Ecosystem/mindspore
consider controldepend edges in checkcircle
This commit is contained in:
parent
534cc9bbe9
commit
e6a5fc0739
|
@ -47,12 +47,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
|
||||||
return is_fusable ? FOLLOW : EXCLUDE;
|
return is_fusable ? FOLLOW : EXCLUDE;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||||
|
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||||
// Search fusable nodes according input direction.
|
// Search fusable nodes according input direction.
|
||||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||||
if (used_nodes.size() > 1) {
|
if (used_nodes.size() > 1) {
|
||||||
used_nodes = RemoveCircle(used_nodes, false);
|
used_nodes = RemoveCircle(used_nodes, dep_pri, false);
|
||||||
}
|
}
|
||||||
TopoSortForNodeList(&used_nodes);
|
TopoSortForNodeList(&used_nodes);
|
||||||
return used_nodes;
|
return used_nodes;
|
||||||
|
@ -78,7 +79,8 @@ void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &u
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
|
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) {
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||||
AnfNodeSet outputs_set;
|
AnfNodeSet outputs_set;
|
||||||
for (auto out : *outputs) {
|
for (auto out : *outputs) {
|
||||||
outputs_set.insert(out);
|
outputs_set.insert(out);
|
||||||
|
@ -112,6 +114,7 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
|
||||||
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
|
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
|
||||||
mng->Replace(control_depend_node, new_control_depend);
|
mng->Replace(control_depend_node, new_control_depend);
|
||||||
has_erase_outs = true;
|
has_erase_outs = true;
|
||||||
|
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
it++;
|
it++;
|
||||||
|
@ -120,7 +123,8 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
|
||||||
return has_erase_outs;
|
return has_erase_outs;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
|
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||||
AnfNodePtrList vir_outputs;
|
AnfNodePtrList vir_outputs;
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
||||||
auto fg_outputs = fg->output();
|
auto fg_outputs = fg->output();
|
||||||
|
@ -137,7 +141,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
|
||||||
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) {
|
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,6 +163,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
||||||
std::unordered_set<AnfNodePtr> *fused_ops) {
|
std::unordered_set<AnfNodePtr> *fused_ops) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto mng = kernel_graph->manager();
|
auto mng = kernel_graph->manager();
|
||||||
|
|
||||||
|
// depend_prior[depend] = pair(prior, controlDependNode)
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||||
|
InitDependPrior(todos, &depend_prior);
|
||||||
|
|
||||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||||
auto node = (*iter)->cast<CNodePtr>();
|
auto node = (*iter)->cast<CNodePtr>();
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
|
@ -172,7 +181,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fuse_nodes = FindFuseCNodes(node);
|
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
|
||||||
if (fuse_nodes.size() <= 1) {
|
if (fuse_nodes.size() <= 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -182,11 +191,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
||||||
AnfNodePtrList inputs;
|
AnfNodePtrList inputs;
|
||||||
AnfNodePtrList outputs;
|
AnfNodePtrList outputs;
|
||||||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||||
RemoveControlDependOut(fg, &outputs, mng);
|
RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
|
||||||
ConvertNonscalarTensorToParameter(fg, &inputs);
|
ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||||
|
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
|
||||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||||
|
|
||||||
// Set graph kernel attr
|
// Set graph kernel attr
|
||||||
|
|
|
@ -15,13 +15,14 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <unordered_set>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <memory>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <string>
|
||||||
|
#include <set>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
|
@ -97,15 +98,29 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNod
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
||||||
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
|
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
|
||||||
|
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
|
||||||
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
circle_nodes->clear();
|
circle_nodes->clear();
|
||||||
|
|
||||||
|
auto InputEdges = [&depend_prior](CNodePtr cnode) {
|
||||||
|
std::set<AnfNodePtr> edges;
|
||||||
|
auto range = depend_prior.equal_range(cnode);
|
||||||
|
for (auto iter = range.first; iter != range.second; ++iter) {
|
||||||
|
edges.insert(iter->second.first);
|
||||||
|
}
|
||||||
|
auto inputs = cnode->inputs();
|
||||||
|
for (auto input : inputs) {
|
||||||
|
edges.insert(input);
|
||||||
|
}
|
||||||
|
return edges;
|
||||||
|
};
|
||||||
|
|
||||||
std::set<AnfNodePtr> cached_done_set;
|
std::set<AnfNodePtr> cached_done_set;
|
||||||
auto cnode = check_node->cast<CNodePtr>();
|
auto cnode = check_node->cast<CNodePtr>();
|
||||||
const auto &inputs = cnode->inputs();
|
const auto &inputs = InputEdges(cnode);
|
||||||
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
||||||
|
@ -128,7 +143,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
||||||
|
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
auto cnode_ptr = node->cast<CNodePtr>();
|
auto cnode_ptr = node->cast<CNodePtr>();
|
||||||
for (auto it : cnode_ptr->inputs()) {
|
for (auto it : InputEdges(cnode_ptr)) {
|
||||||
if (it->isa<CNode>()) {
|
if (it->isa<CNode>()) {
|
||||||
todos.push_back(it);
|
todos.push_back(it);
|
||||||
}
|
}
|
||||||
|
@ -148,7 +163,9 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
||||||
return !circle_nodes->empty();
|
return !circle_nodes->empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
|
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||||
|
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||||
|
bool is_backward) {
|
||||||
std::set<AnfNodePtr> cached_unconnected_set;
|
std::set<AnfNodePtr> cached_unconnected_set;
|
||||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||||
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
||||||
|
@ -161,7 +178,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
||||||
std::vector<AnfNodePtr> circle_nodes;
|
std::vector<AnfNodePtr> circle_nodes;
|
||||||
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
||||||
circle_nodes.clear();
|
circle_nodes.clear();
|
||||||
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
|
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
|
||||||
// delete the circle node and the node which depend on the circle node in fused op
|
// delete the circle node and the node which depend on the circle node in fused op
|
||||||
if (has_circle) {
|
if (has_circle) {
|
||||||
auto mng = (*iter)->func_graph()->manager();
|
auto mng = (*iter)->func_graph()->manager();
|
||||||
|
@ -294,7 +311,8 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
||||||
lst->assign(res.begin(), res.end());
|
lst->assign(res.begin(), res.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||||
|
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||||
auto func_graph = cnode->func_graph();
|
auto func_graph = cnode->func_graph();
|
||||||
auto mng = func_graph->manager();
|
auto mng = func_graph->manager();
|
||||||
// Search fusable nodes according input direction.
|
// Search fusable nodes according input direction.
|
||||||
|
@ -307,7 +325,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||||
|
|
||||||
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
||||||
if (used_nodes.size() > 1) {
|
if (used_nodes.size() > 1) {
|
||||||
used_nodes = RemoveCircle(used_nodes);
|
used_nodes = RemoveCircle(used_nodes, dep_pri);
|
||||||
}
|
}
|
||||||
used_nodes = RemoveWildGetitem(used_nodes);
|
used_nodes = RemoveWildGetitem(used_nodes);
|
||||||
TopoSortForNodeList(&used_nodes);
|
TopoSortForNodeList(&used_nodes);
|
||||||
|
@ -316,8 +334,18 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||||
|
|
||||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
auto mng = kernel_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(kernel_graph, true);
|
||||||
|
kernel_graph->set_manager(mng);
|
||||||
|
}
|
||||||
|
auto todos = TopoSort(kernel_graph->get_return());
|
||||||
|
std::reverse(todos.begin(), todos.end());
|
||||||
|
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||||
|
InitDependPrior(todos, &depend_prior);
|
||||||
|
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
auto &todos = kernel_graph->execution_order();
|
|
||||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||||
auto node = *iter;
|
auto node = *iter;
|
||||||
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
||||||
|
@ -333,13 +361,16 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fuse_nodes = FindFuseCNodes(node);
|
auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
|
||||||
if (fuse_nodes.size() <= 1) {
|
if (fuse_nodes.size() <= 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
changed = true;
|
changed = true;
|
||||||
|
|
||||||
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
AnfNodePtr fused_new_node;
|
||||||
|
AnfNodePtrList old_outputs;
|
||||||
|
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
||||||
|
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
|
||||||
}
|
}
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,11 +16,13 @@
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
|
||||||
#include <limits>
|
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
|
|
||||||
|
@ -29,7 +31,9 @@ namespace opt {
|
||||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||||
"LambNextMV", "LambUpdateWithLR"};
|
"LambNextMV", "LambUpdateWithLR"};
|
||||||
|
|
||||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward = true);
|
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||||
|
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||||
|
bool is_backward = true);
|
||||||
|
|
||||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||||
|
|
||||||
|
|
|
@ -14,20 +14,24 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include "pipeline/jit/parse/python_adapter.h"
|
#include <utility>
|
||||||
#include "pipeline/jit/action.h"
|
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
|
||||||
#include "vm/segment_runner.h"
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||||
#include "ir/func_graph_cloner.h"
|
#include "ir/func_graph_cloner.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
#include "pipeline/jit/parse/python_adapter.h"
|
||||||
|
#include "pipeline/jit/action.h"
|
||||||
|
#include "vm/segment_runner.h"
|
||||||
#if ENABLE_GPU
|
#if ENABLE_GPU
|
||||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
|
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||||
if (fuse_nodes.empty()) {
|
const std::string &postfix) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto mng = kernel_graph->manager();
|
auto mng = kernel_graph->manager();
|
||||||
if (mng == nullptr) {
|
if (mng == nullptr) {
|
||||||
mng = Manage(kernel_graph, true);
|
mng = Manage(kernel_graph, true);
|
||||||
|
@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||||
}
|
}
|
||||||
fuse_op_name += postfix;
|
fuse_op_name += postfix;
|
||||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||||
|
|
||||||
|
return std::make_tuple(fuse_new_node, src_outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||||
|
@ -737,7 +740,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||||
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
||||||
prim::kPrimTranspose, prim::kPrimCast};
|
prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv};
|
||||||
#elif ENABLE_GPU
|
#elif ENABLE_GPU
|
||||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||||
|
@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
||||||
device::gpu::SetKernelInfo(cnode, kernel_type);
|
device::gpu::SetKernelInfo(cnode, kernel_type);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||||
|
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||||
|
auto cnode = (*iter)->cast<CNodePtr>();
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||||
|
auto depend_node = cnode->input(kControlDependBehindIndex);
|
||||||
|
MS_EXCEPTION_IF_NULL(prior_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_node);
|
||||||
|
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
||||||
|
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
||||||
|
|
||||||
|
int depend_mode = 0;
|
||||||
|
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||||
|
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> {
|
||||||
|
std::vector<AnfNodePtr> out_nodes;
|
||||||
|
auto user_set = param->func_graph()->manager()->node_users()[param];
|
||||||
|
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
|
||||||
|
if (iter->first != cnode) {
|
||||||
|
out_nodes.push_back(iter->first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out_nodes;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||||
|
prior_nodes = GetOutputNodes(prior_node);
|
||||||
|
}
|
||||||
|
if (depend_node->isa<Parameter>()) {
|
||||||
|
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> real_prior_nodes;
|
||||||
|
std::set<AnfNodePtr> prior_visited;
|
||||||
|
for (const auto &tmp : prior_nodes) {
|
||||||
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||||
|
}
|
||||||
|
prior_visited.clear();
|
||||||
|
std::vector<AnfNodePtr> real_depend_nodes;
|
||||||
|
std::set<AnfNodePtr> depend_visited;
|
||||||
|
for (const auto &tmp : depend_nodes) {
|
||||||
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||||
|
}
|
||||||
|
depend_visited.clear();
|
||||||
|
|
||||||
|
for (auto &prior : real_prior_nodes) {
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (auto &depend : real_depend_nodes) {
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
depend_prior->insert({depend, std::make_pair(prior, cnode)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
real_prior_nodes.clear();
|
||||||
|
real_depend_nodes.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||||
|
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
|
||||||
|
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||||
|
if (iter->second.second == control_depend_node) {
|
||||||
|
iter = depend_prior->erase(iter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
++iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
|
||||||
|
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
|
||||||
|
for (auto item : new_depend_prior) {
|
||||||
|
depend_prior->insert(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||||
|
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
||||||
|
|
||||||
|
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
|
||||||
|
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
|
||||||
|
MS_LOG(ERROR) << "Need real outputs of makeTuple";
|
||||||
|
}
|
||||||
|
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||||
|
if (iter->first == outputs[out_idx]) {
|
||||||
|
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
|
||||||
|
iter = depend_prior->erase(iter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (iter->second.first == outputs[out_idx]) {
|
||||||
|
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
|
||||||
|
iter = depend_prior->erase(iter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
++iter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto item : new_fuse_cnode_dep_pri) {
|
||||||
|
depend_prior->insert(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,17 +15,20 @@
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <nlohmann/json.hpp>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "backend/session/kernel_graph.h"
|
#include "backend/session/kernel_graph.h"
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -48,8 +51,9 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
|
||||||
const AnfNodePtrList &outputs);
|
const AnfNodePtrList &outputs);
|
||||||
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||||
const AnfNodePtrList &outputs);
|
const AnfNodePtrList &outputs);
|
||||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix);
|
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||||
|
const std::string &postfix);
|
||||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
|
||||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||||
std::map<std::string, AnfNodePtr> *address_node_map);
|
std::map<std::string, AnfNodePtr> *address_node_map);
|
||||||
|
@ -60,6 +64,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
||||||
std::vector<PrimitivePtr> GetFusibleOpList();
|
std::vector<PrimitivePtr> GetFusibleOpList();
|
||||||
bool IsBasicFuseOp(const AnfNodePtr &node);
|
bool IsBasicFuseOp(const AnfNodePtr &node);
|
||||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||||
|
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||||
|
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
|
||||||
|
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||||
|
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
|
||||||
|
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||||
|
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||||
|
|
|
@ -1417,5 +1417,46 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
|
||||||
}
|
}
|
||||||
return device_shape;
|
return device_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||||
|
std::set<AnfNodePtr> *visited) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(result);
|
||||||
|
MS_EXCEPTION_IF_NULL(visited);
|
||||||
|
if (visited->find(anf_node) != visited->end()) {
|
||||||
|
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
visited->insert(anf_node);
|
||||||
|
if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||||
|
result->emplace_back(anf_node);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!anf_node->isa<CNode>()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (cnode->inputs().empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
||||||
|
}
|
||||||
|
auto input0 = cnode->input(0);
|
||||||
|
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||||
|
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||||
|
GetAllFatherRealNode(cnode->input(i), result, visited);
|
||||||
|
}
|
||||||
|
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||||
|
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
||||||
|
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||||
|
}
|
||||||
|
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
||||||
|
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
||||||
|
if (cnode->inputs().size() != kDependInputSize) {
|
||||||
|
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
||||||
|
}
|
||||||
|
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
||||||
|
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -232,6 +232,9 @@ class AnfRuntimeAlgorithm {
|
||||||
static bool IsNodeDynamicShape(const AnfNodePtr &node);
|
static bool IsNodeDynamicShape(const AnfNodePtr &node);
|
||||||
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||||
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||||
|
// Find control_depend real input nodes.
|
||||||
|
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||||
|
std::set<AnfNodePtr> *visited);
|
||||||
};
|
};
|
||||||
} // namespace session
|
} // namespace session
|
||||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||||
|
|
|
@ -725,47 +725,6 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
||||||
return output_nodes;
|
return output_nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find control_depend real input nodes.
|
|
||||||
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(result);
|
|
||||||
MS_EXCEPTION_IF_NULL(visited);
|
|
||||||
if (visited->find(anf_node) != visited->end()) {
|
|
||||||
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
visited->insert(anf_node);
|
|
||||||
if (AnfAlgo::IsRealKernel(anf_node)) {
|
|
||||||
result->emplace_back(anf_node);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!anf_node->isa<CNode>()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
if (cnode->inputs().empty()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
|
||||||
}
|
|
||||||
auto input0 = cnode->input(0);
|
|
||||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
|
||||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
|
||||||
GetAllFatherRealNode(cnode->input(i), result, visited);
|
|
||||||
}
|
|
||||||
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
|
||||||
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
|
||||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
|
||||||
}
|
|
||||||
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
|
||||||
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
|
||||||
if (cnode->inputs().size() != kDependInputSize) {
|
|
||||||
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
|
||||||
}
|
|
||||||
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
|
||||||
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update the depend relations of control depend
|
// update the depend relations of control depend
|
||||||
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
||||||
for (const auto &node : depends) {
|
for (const auto &node : depends) {
|
||||||
|
@ -800,12 +759,12 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
||||||
std::vector<AnfNodePtr> real_prior_nodes;
|
std::vector<AnfNodePtr> real_prior_nodes;
|
||||||
std::set<AnfNodePtr> prior_visited;
|
std::set<AnfNodePtr> prior_visited;
|
||||||
for (const auto &tmp : prior_nodes) {
|
for (const auto &tmp : prior_nodes) {
|
||||||
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> real_depend_nodes;
|
std::vector<AnfNodePtr> real_depend_nodes;
|
||||||
std::set<AnfNodePtr> depend_visited;
|
std::set<AnfNodePtr> depend_visited;
|
||||||
for (const auto &tmp : depend_nodes) {
|
for (const auto &tmp : depend_nodes) {
|
||||||
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||||
}
|
}
|
||||||
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
|
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue