"solve repeat search in FindDestOps's DFS"

This commit is contained in:
xiao_yao1994 2022-06-27 19:41:55 +08:00
parent a7b5f49151
commit 186e6c4c43
3 changed files with 27 additions and 11 deletions

View File

@ -18,6 +18,7 @@
"mindspore/mindspore/ccsrc/utils/convert_utils_py.cc" "whitespace/indent"
"mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"

View File

@ -911,8 +911,8 @@ void DfGraphConvertor::BuildWhileSubGraph() {
continue;
}
SetNodeInput(it);
SetSubgraph(it);
SetOpControlInput(it);
SetSubgraph(it);
UpdateOpDesc(it);
}
MS_LOG(DEBUG) << "trace output";
@ -1783,14 +1783,16 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
auto user_node = iter.first;
auto name = GetCNodeTargetFuncName(user_node->cast<CNodePtr>());
if (name == prim::kPrimUpdateState->name()) {
FindDestOps(user_node, dst_node_list, false);
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
FindDestOps(user_node, dst_node_list, false, &flag_map);
continue;
}
if (IsControlEdgeNode(user_node)) {
src_node_list->push_back(user_node);
continue;
}
FindDestOps(user_node, src_node_list, false);
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
FindDestOps(user_node, src_node_list, false, &flag_map);
}
// add to cache
@ -1802,8 +1804,9 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
}
void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
bool top) {
bool top, mindspore::HashMap<AnfNodePtr, DfsVisitFlag> *flag_map) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(flag_map);
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
@ -1814,20 +1817,29 @@ void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
(*flag_map)[node] = DfsVisitFlag::kVisiting;
auto users = manager->node_users()[node];
for (const auto &iter : users) {
auto user_node = iter.first;
if (IsSubGraph() && user_node == call_node_in_while_body_) {
continue;
}
if (IsControlEdgeNode(user_node)) {
if (!top) {
node_list->push_back(user_node);
}
if ((*flag_map)[user_node] == DfsVisitFlag::kVisiting) {
MS_LOG(INFO) << "there exists a loop in graph.";
break;
} else if ((*flag_map)[user_node] == DfsVisitFlag::kVisited) {
continue;
} else {
FindDestOps(user_node, node_list, false);
if (IsControlEdgeNode(user_node)) {
if (!top) {
node_list->push_back(user_node);
}
} else {
FindDestOps(user_node, node_list, false, flag_map);
}
}
}
(*flag_map)[node] = DfsVisitFlag::kVisited;
}
void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
@ -1844,7 +1856,8 @@ void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
if (src_ops != nullptr) {
// Find dest ops list
std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
FindDestOps(node, dst_node_list, true);
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
FindDestOps(node, dst_node_list, true, &flag_map);
for (auto &dest : *dst_node_list) {
AddEdgeToCache(node, dest);
}

View File

@ -54,6 +54,7 @@ using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
using ParamIndexMap = std::map<std::size_t, std::size_t>;
enum class GraphType { kNormal, kCond, kBody, kAfter, kBranch };
enum class DfsVisitFlag { kUnVisited, kVisiting, kVisited };
class DfGraphConvertor {
public:
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
@ -199,7 +200,8 @@ class DfGraphConvertor {
bool IsControlEdgeNode(const AnfNodePtr &node);
void AddEdgeForLoad(const AnfNodePtr &node);
void AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest);
void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top);
void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top,
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> *flag_map);
AnfNodePtr ParseLoadInput(const CNodePtr &cnode);
void AutoMonadSetControlInput(const AnfNodePtr &node);
void AutoMonadCollectInput(const AnfNodePtr &node);