"solve repeat search in FindDestOps's DFS"
This commit is contained in:
parent
a7b5f49151
commit
186e6c4c43
|
@ -18,6 +18,7 @@
|
|||
"mindspore/mindspore/ccsrc/utils/convert_utils_py.cc" "whitespace/indent"
|
||||
"mindspore/mindspore/core/utils/log_adapter.cc" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
|
||||
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
|
||||
|
||||
# Modelzoo
|
||||
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
|
||||
|
|
|
@ -911,8 +911,8 @@ void DfGraphConvertor::BuildWhileSubGraph() {
|
|||
continue;
|
||||
}
|
||||
SetNodeInput(it);
|
||||
SetSubgraph(it);
|
||||
SetOpControlInput(it);
|
||||
SetSubgraph(it);
|
||||
UpdateOpDesc(it);
|
||||
}
|
||||
MS_LOG(DEBUG) << "trace output";
|
||||
|
@ -1783,14 +1783,16 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
|
|||
auto user_node = iter.first;
|
||||
auto name = GetCNodeTargetFuncName(user_node->cast<CNodePtr>());
|
||||
if (name == prim::kPrimUpdateState->name()) {
|
||||
FindDestOps(user_node, dst_node_list, false);
|
||||
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
|
||||
FindDestOps(user_node, dst_node_list, false, &flag_map);
|
||||
continue;
|
||||
}
|
||||
if (IsControlEdgeNode(user_node)) {
|
||||
src_node_list->push_back(user_node);
|
||||
continue;
|
||||
}
|
||||
FindDestOps(user_node, src_node_list, false);
|
||||
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
|
||||
FindDestOps(user_node, src_node_list, false, &flag_map);
|
||||
}
|
||||
|
||||
// add to cache
|
||||
|
@ -1802,8 +1804,9 @@ void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
|
||||
bool top) {
|
||||
bool top, mindspore::HashMap<AnfNodePtr, DfsVisitFlag> *flag_map) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(flag_map);
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
|
@ -1814,20 +1817,29 @@ void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr
|
|||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
(*flag_map)[node] = DfsVisitFlag::kVisiting;
|
||||
auto users = manager->node_users()[node];
|
||||
for (const auto &iter : users) {
|
||||
auto user_node = iter.first;
|
||||
if (IsSubGraph() && user_node == call_node_in_while_body_) {
|
||||
continue;
|
||||
}
|
||||
if (IsControlEdgeNode(user_node)) {
|
||||
if (!top) {
|
||||
node_list->push_back(user_node);
|
||||
}
|
||||
if ((*flag_map)[user_node] == DfsVisitFlag::kVisiting) {
|
||||
MS_LOG(INFO) << "there exists a loop in graph.";
|
||||
break;
|
||||
} else if ((*flag_map)[user_node] == DfsVisitFlag::kVisited) {
|
||||
continue;
|
||||
} else {
|
||||
FindDestOps(user_node, node_list, false);
|
||||
if (IsControlEdgeNode(user_node)) {
|
||||
if (!top) {
|
||||
node_list->push_back(user_node);
|
||||
}
|
||||
} else {
|
||||
FindDestOps(user_node, node_list, false, flag_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
(*flag_map)[node] = DfsVisitFlag::kVisited;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
|
||||
|
@ -1844,7 +1856,8 @@ void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
|
|||
if (src_ops != nullptr) {
|
||||
// Find dest ops list
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
|
||||
FindDestOps(node, dst_node_list, true);
|
||||
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> flag_map;
|
||||
FindDestOps(node, dst_node_list, true, &flag_map);
|
||||
for (auto &dest : *dst_node_list) {
|
||||
AddEdgeToCache(node, dest);
|
||||
}
|
||||
|
|
|
@ -54,6 +54,7 @@ using OpAdapterPtr = std::shared_ptr<BaseOpAdapter>;
|
|||
|
||||
using ParamIndexMap = std::map<std::size_t, std::size_t>;
|
||||
enum class GraphType { kNormal, kCond, kBody, kAfter, kBranch };
|
||||
enum class DfsVisitFlag { kUnVisited, kVisiting, kVisited };
|
||||
class DfGraphConvertor {
|
||||
public:
|
||||
explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) {
|
||||
|
@ -199,7 +200,8 @@ class DfGraphConvertor {
|
|||
bool IsControlEdgeNode(const AnfNodePtr &node);
|
||||
void AddEdgeForLoad(const AnfNodePtr &node);
|
||||
void AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest);
|
||||
void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top);
|
||||
void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top,
|
||||
mindspore::HashMap<AnfNodePtr, DfsVisitFlag> *flag_map);
|
||||
AnfNodePtr ParseLoadInput(const CNodePtr &cnode);
|
||||
void AutoMonadSetControlInput(const AnfNodePtr &node);
|
||||
void AutoMonadCollectInput(const AnfNodePtr &node);
|
||||
|
|
Loading…
Reference in New Issue