From 7e51edd277787432b550b6f56deaa8967550eb88 Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Wed, 25 May 2022 03:27:53 +0000 Subject: [PATCH] fix random order in OrderEnforceAction by replace HashSet which is not fixed order with CompactSet --- .../jit/static_analysis/order_enforce.cc | 28 ++++++++++--------- mindspore/core/utils/compact_set.h | 7 +++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc index 62b9cc690e4..e5656849e92 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc @@ -24,6 +24,7 @@ #include #include "utils/hash_map.h" #include "utils/hash_set.h" +#include "utils/compact_set.h" #include "include/common/utils/utils.h" #include "mindspore/core/ops/core_ops.h" @@ -239,7 +240,7 @@ class OrderEnforcer { for (auto &load : loads) { // Find user nodes of the Load. auto load_users = FindLoadUsers(load); - mindspore::HashSet real_users; + mindspore::CompactSet real_users; for (auto &load_user : load_users) { // Check the special operator, only one level of user is considered for now. if (IsSpecialPrimitive(load_user)) { @@ -249,7 +250,7 @@ class OrderEnforcer { auto parallel__users = FindParallelNodeUsers(load_user); real_users.insert(parallel__users.begin(), parallel__users.end()); } else { - (void)real_users.insert(load_user); + real_users.insert(load_user); } } AddInputEdges(update_state, real_users); @@ -279,7 +280,7 @@ class OrderEnforcer { } // Add load users as input edges of the update_state node. - void AddInputEdges(const CNodePtr &update_state, const mindspore::HashSet &load_users) { + void AddInputEdges(const CNodePtr &update_state, const mindspore::CompactSet &load_users) { auto sorted_load_users = SortLoadUsers(load_users); for (auto &load_user : sorted_load_users) { if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { @@ -295,7 +296,7 @@ class OrderEnforcer { } // Sort load users by their topo sort order. - std::vector SortLoadUsers(const mindspore::HashSet &load_users) { + std::vector SortLoadUsers(const mindspore::CompactSet &load_users) { std::vector vec{load_users.begin(), load_users.end()}; std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); }); return vec; @@ -348,36 +349,36 @@ class OrderEnforcer { using PredFunc = std::function; // Find user nodes for the given node. - mindspore::HashSet FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) { + mindspore::CompactSet FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) { auto &node_users = manager_->node_users(); auto iter = node_users.find(node); if (iter == node_users.end()) { return {}; } - mindspore::HashSet users; + mindspore::CompactSet users; for (auto &user : iter->second) { auto &user_node = user.first; if (pred == nullptr || pred(user_node)) { - (void)users.emplace(user_node); + users.insert(user_node); } } return users; } // Find real user nodes for the given parallel nodes. - mindspore::HashSet FindParallelNodeUsers(const AnfNodePtr &node) { + mindspore::CompactSet FindParallelNodeUsers(const AnfNodePtr &node) { auto &node_users = manager_->node_users(); auto iter = node_users.find(node); if (iter == node_users.end()) { return {}; } - mindspore::HashSet users; + mindspore::CompactSet users; for (auto &user : iter->second) { auto &user_node = user.first; if (!IsSpecialParallelPrimitive(user_node)) { - (void)users.emplace(user_node); + users.insert(user_node); } else { - mindspore::HashSet real_users; + mindspore::CompactSet real_users; real_users = FindParallelNodeUsers(user_node); users.insert(real_users.begin(), real_users.end()); } @@ -386,7 +387,7 @@ class OrderEnforcer { } // Find Load or parameter users as the candidate nodes to enforce order of execution. - mindspore::HashSet FindLoadUsers(const AnfNodePtr &load_or_param) { + mindspore::CompactSet FindLoadUsers(const AnfNodePtr &load_or_param) { return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) { // Skip processed nodes. return processed_nodes_.find(user_node) == processed_nodes_.end(); @@ -394,7 +395,7 @@ class OrderEnforcer { } // Find Load nodes for a parameter. - mindspore::HashSet FindLoadNodes(const AnfNodePtr ¶m) { + mindspore::CompactSet FindLoadNodes(const AnfNodePtr ¶m) { return FindNodeUsers(param, [this](const AnfNodePtr &user_node) { // Search for Load nodes only. return IsPrimitiveCNode(user_node, prim::kPrimLoad); @@ -548,6 +549,7 @@ class OrderEnforcer { const FuncGraphPtr &func_graph_; FuncGraphManagerPtr manager_; mindspore::HashMap topo_sort_map_; + // As of now it's no requirement for insertion order, so use the unordered set. mindspore::HashSet processed_nodes_; }; } // namespace diff --git a/mindspore/core/utils/compact_set.h b/mindspore/core/utils/compact_set.h index c123d9473c9..270ebf04848 100644 --- a/mindspore/core/utils/compact_set.h +++ b/mindspore/core/utils/compact_set.h @@ -46,6 +46,13 @@ class CompactSet { } } + template + void insert(InputIt first, InputIt last) { + for (; first != last; ++first) { + insert(*first); + } + } + iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); } const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }