fix random order in OrderEnforceAction by replace HashSet which is not fixed order with CompactSet

This commit is contained in:
zhousiyi 2022-05-25 03:27:53 +00:00
parent 0ee52f9f30
commit 7e51edd277
2 changed files with 22 additions and 13 deletions

View File

@ -24,6 +24,7 @@
#include <memory> #include <memory>
#include "utils/hash_map.h" #include "utils/hash_map.h"
#include "utils/hash_set.h" #include "utils/hash_set.h"
#include "utils/compact_set.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "mindspore/core/ops/core_ops.h" #include "mindspore/core/ops/core_ops.h"
@ -239,7 +240,7 @@ class OrderEnforcer {
for (auto &load : loads) { for (auto &load : loads) {
// Find user nodes of the Load. // Find user nodes of the Load.
auto load_users = FindLoadUsers(load); auto load_users = FindLoadUsers(load);
mindspore::HashSet<AnfNodePtr> real_users; mindspore::CompactSet<AnfNodePtr> real_users;
for (auto &load_user : load_users) { for (auto &load_user : load_users) {
// Check the special operator, only one level of user is considered for now. // Check the special operator, only one level of user is considered for now.
if (IsSpecialPrimitive(load_user)) { if (IsSpecialPrimitive(load_user)) {
@ -249,7 +250,7 @@ class OrderEnforcer {
auto parallel__users = FindParallelNodeUsers(load_user); auto parallel__users = FindParallelNodeUsers(load_user);
real_users.insert(parallel__users.begin(), parallel__users.end()); real_users.insert(parallel__users.begin(), parallel__users.end());
} else { } else {
(void)real_users.insert(load_user); real_users.insert(load_user);
} }
} }
AddInputEdges(update_state, real_users); AddInputEdges(update_state, real_users);
@ -279,7 +280,7 @@ class OrderEnforcer {
} }
// Add load users as input edges of the update_state node. // Add load users as input edges of the update_state node.
void AddInputEdges(const CNodePtr &update_state, const mindspore::HashSet<AnfNodePtr> &load_users) { void AddInputEdges(const CNodePtr &update_state, const mindspore::CompactSet<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users); auto sorted_load_users = SortLoadUsers(load_users);
for (auto &load_user : sorted_load_users) { for (auto &load_user : sorted_load_users) {
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
@ -295,7 +296,7 @@ class OrderEnforcer {
} }
// Sort load users by their topo sort order. // Sort load users by their topo sort order.
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::HashSet<AnfNodePtr> &load_users) { std::vector<AnfNodePtr> SortLoadUsers(const mindspore::CompactSet<AnfNodePtr> &load_users) {
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()}; std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); }); std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
return vec; return vec;
@ -348,36 +349,36 @@ class OrderEnforcer {
using PredFunc = std::function<bool(const AnfNodePtr &)>; using PredFunc = std::function<bool(const AnfNodePtr &)>;
// Find user nodes for the given node. // Find user nodes for the given node.
mindspore::HashSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) { mindspore::CompactSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
auto &node_users = manager_->node_users(); auto &node_users = manager_->node_users();
auto iter = node_users.find(node); auto iter = node_users.find(node);
if (iter == node_users.end()) { if (iter == node_users.end()) {
return {}; return {};
} }
mindspore::HashSet<AnfNodePtr> users; mindspore::CompactSet<AnfNodePtr> users;
for (auto &user : iter->second) { for (auto &user : iter->second) {
auto &user_node = user.first; auto &user_node = user.first;
if (pred == nullptr || pred(user_node)) { if (pred == nullptr || pred(user_node)) {
(void)users.emplace(user_node); users.insert(user_node);
} }
} }
return users; return users;
} }
// Find real user nodes for the given parallel nodes. // Find real user nodes for the given parallel nodes.
mindspore::HashSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) { mindspore::CompactSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
auto &node_users = manager_->node_users(); auto &node_users = manager_->node_users();
auto iter = node_users.find(node); auto iter = node_users.find(node);
if (iter == node_users.end()) { if (iter == node_users.end()) {
return {}; return {};
} }
mindspore::HashSet<AnfNodePtr> users; mindspore::CompactSet<AnfNodePtr> users;
for (auto &user : iter->second) { for (auto &user : iter->second) {
auto &user_node = user.first; auto &user_node = user.first;
if (!IsSpecialParallelPrimitive(user_node)) { if (!IsSpecialParallelPrimitive(user_node)) {
(void)users.emplace(user_node); users.insert(user_node);
} else { } else {
mindspore::HashSet<AnfNodePtr> real_users; mindspore::CompactSet<AnfNodePtr> real_users;
real_users = FindParallelNodeUsers(user_node); real_users = FindParallelNodeUsers(user_node);
users.insert(real_users.begin(), real_users.end()); users.insert(real_users.begin(), real_users.end());
} }
@ -386,7 +387,7 @@ class OrderEnforcer {
} }
// Find Load or parameter users as the candidate nodes to enforce order of execution. // Find Load or parameter users as the candidate nodes to enforce order of execution.
mindspore::HashSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) { mindspore::CompactSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) { return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
// Skip processed nodes. // Skip processed nodes.
return processed_nodes_.find(user_node) == processed_nodes_.end(); return processed_nodes_.find(user_node) == processed_nodes_.end();
@ -394,7 +395,7 @@ class OrderEnforcer {
} }
// Find Load nodes for a parameter. // Find Load nodes for a parameter.
mindspore::HashSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr &param) { mindspore::CompactSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr &param) {
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) { return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
// Search for Load nodes only. // Search for Load nodes only.
return IsPrimitiveCNode(user_node, prim::kPrimLoad); return IsPrimitiveCNode(user_node, prim::kPrimLoad);
@ -548,6 +549,7 @@ class OrderEnforcer {
const FuncGraphPtr &func_graph_; const FuncGraphPtr &func_graph_;
FuncGraphManagerPtr manager_; FuncGraphManagerPtr manager_;
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_; mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
// As of now it's no requirement for insertion order, so use the unordered set.
mindspore::HashSet<AnfNodePtr> processed_nodes_; mindspore::HashSet<AnfNodePtr> processed_nodes_;
}; };
} // namespace } // namespace

View File

@ -46,6 +46,13 @@ class CompactSet {
} }
} }
template <class InputIt>
void insert(InputIt first, InputIt last) {
for (; first != last; ++first) {
insert(*first);
}
}
iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); } iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); }
const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); } const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }