fix random order in OrderEnforceAction by replace HashSet which is not fixed order with CompactSet

This commit is contained in:
zhousiyi 2022-05-25 03:27:53 +00:00
parent 0ee52f9f30
commit 7e51edd277
2 changed files with 22 additions and 13 deletions

View File

@ -24,6 +24,7 @@
#include <memory>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "utils/compact_set.h"
#include "include/common/utils/utils.h"
#include "mindspore/core/ops/core_ops.h"
@ -239,7 +240,7 @@ class OrderEnforcer {
for (auto &load : loads) {
// Find user nodes of the Load.
auto load_users = FindLoadUsers(load);
mindspore::HashSet<AnfNodePtr> real_users;
mindspore::CompactSet<AnfNodePtr> real_users;
for (auto &load_user : load_users) {
// Check the special operator, only one level of user is considered for now.
if (IsSpecialPrimitive(load_user)) {
@ -249,7 +250,7 @@ class OrderEnforcer {
auto parallel__users = FindParallelNodeUsers(load_user);
real_users.insert(parallel__users.begin(), parallel__users.end());
} else {
(void)real_users.insert(load_user);
real_users.insert(load_user);
}
}
AddInputEdges(update_state, real_users);
@ -279,7 +280,7 @@ class OrderEnforcer {
}
// Add load users as input edges of the update_state node.
void AddInputEdges(const CNodePtr &update_state, const mindspore::HashSet<AnfNodePtr> &load_users) {
void AddInputEdges(const CNodePtr &update_state, const mindspore::CompactSet<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users);
for (auto &load_user : sorted_load_users) {
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
@ -295,7 +296,7 @@ class OrderEnforcer {
}
// Sort load users by their topo sort order.
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::HashSet<AnfNodePtr> &load_users) {
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::CompactSet<AnfNodePtr> &load_users) {
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
return vec;
@ -348,36 +349,36 @@ class OrderEnforcer {
using PredFunc = std::function<bool(const AnfNodePtr &)>;
// Find user nodes for the given node.
mindspore::HashSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
mindspore::CompactSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return {};
}
mindspore::HashSet<AnfNodePtr> users;
mindspore::CompactSet<AnfNodePtr> users;
for (auto &user : iter->second) {
auto &user_node = user.first;
if (pred == nullptr || pred(user_node)) {
(void)users.emplace(user_node);
users.insert(user_node);
}
}
return users;
}
// Find real user nodes for the given parallel nodes.
mindspore::HashSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
mindspore::CompactSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return {};
}
mindspore::HashSet<AnfNodePtr> users;
mindspore::CompactSet<AnfNodePtr> users;
for (auto &user : iter->second) {
auto &user_node = user.first;
if (!IsSpecialParallelPrimitive(user_node)) {
(void)users.emplace(user_node);
users.insert(user_node);
} else {
mindspore::HashSet<AnfNodePtr> real_users;
mindspore::CompactSet<AnfNodePtr> real_users;
real_users = FindParallelNodeUsers(user_node);
users.insert(real_users.begin(), real_users.end());
}
@ -386,7 +387,7 @@ class OrderEnforcer {
}
// Find Load or parameter users as the candidate nodes to enforce order of execution.
mindspore::HashSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
mindspore::CompactSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
// Skip processed nodes.
return processed_nodes_.find(user_node) == processed_nodes_.end();
@ -394,7 +395,7 @@ class OrderEnforcer {
}
// Find Load nodes for a parameter.
mindspore::HashSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr &param) {
mindspore::CompactSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr &param) {
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
// Search for Load nodes only.
return IsPrimitiveCNode(user_node, prim::kPrimLoad);
@ -548,6 +549,7 @@ class OrderEnforcer {
const FuncGraphPtr &func_graph_;
FuncGraphManagerPtr manager_;
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
// As of now it's no requirement for insertion order, so use the unordered set.
mindspore::HashSet<AnfNodePtr> processed_nodes_;
};
} // namespace

View File

@ -46,6 +46,13 @@ class CompactSet {
}
}
template <class InputIt>
void insert(InputIt first, InputIt last) {
for (; first != last; ++first) {
insert(*first);
}
}
iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); }
const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }