!34935 fix random order in OrderEnforceAction by replace HashSet which is not fixed order with CompactSet
Merge pull request !34935 from xychow/fix-random-order-in-order-enforce-action
This commit is contained in:
commit
c7648b7cda
|
@ -24,6 +24,7 @@
|
|||
#include <memory>
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "utils/compact_set.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
|
@ -239,7 +240,7 @@ class OrderEnforcer {
|
|||
for (auto &load : loads) {
|
||||
// Find user nodes of the Load.
|
||||
auto load_users = FindLoadUsers(load);
|
||||
mindspore::HashSet<AnfNodePtr> real_users;
|
||||
mindspore::CompactSet<AnfNodePtr> real_users;
|
||||
for (auto &load_user : load_users) {
|
||||
// Check the special operator, only one level of user is considered for now.
|
||||
if (IsSpecialPrimitive(load_user)) {
|
||||
|
@ -249,7 +250,7 @@ class OrderEnforcer {
|
|||
auto parallel__users = FindParallelNodeUsers(load_user);
|
||||
real_users.insert(parallel__users.begin(), parallel__users.end());
|
||||
} else {
|
||||
(void)real_users.insert(load_user);
|
||||
real_users.insert(load_user);
|
||||
}
|
||||
}
|
||||
AddInputEdges(update_state, real_users);
|
||||
|
@ -279,7 +280,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
// Add load users as input edges of the update_state node.
|
||||
void AddInputEdges(const CNodePtr &update_state, const mindspore::HashSet<AnfNodePtr> &load_users) {
|
||||
void AddInputEdges(const CNodePtr &update_state, const mindspore::CompactSet<AnfNodePtr> &load_users) {
|
||||
auto sorted_load_users = SortLoadUsers(load_users);
|
||||
for (auto &load_user : sorted_load_users) {
|
||||
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
|
||||
|
@ -295,7 +296,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
// Sort load users by their topo sort order.
|
||||
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::HashSet<AnfNodePtr> &load_users) {
|
||||
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::CompactSet<AnfNodePtr> &load_users) {
|
||||
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
|
||||
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
|
||||
return vec;
|
||||
|
@ -348,36 +349,36 @@ class OrderEnforcer {
|
|||
using PredFunc = std::function<bool(const AnfNodePtr &)>;
|
||||
|
||||
// Find user nodes for the given node.
|
||||
mindspore::HashSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
|
||||
mindspore::CompactSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(node);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
mindspore::HashSet<AnfNodePtr> users;
|
||||
mindspore::CompactSet<AnfNodePtr> users;
|
||||
for (auto &user : iter->second) {
|
||||
auto &user_node = user.first;
|
||||
if (pred == nullptr || pred(user_node)) {
|
||||
(void)users.emplace(user_node);
|
||||
users.insert(user_node);
|
||||
}
|
||||
}
|
||||
return users;
|
||||
}
|
||||
|
||||
// Find real user nodes for the given parallel nodes.
|
||||
mindspore::HashSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
|
||||
mindspore::CompactSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(node);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
mindspore::HashSet<AnfNodePtr> users;
|
||||
mindspore::CompactSet<AnfNodePtr> users;
|
||||
for (auto &user : iter->second) {
|
||||
auto &user_node = user.first;
|
||||
if (!IsSpecialParallelPrimitive(user_node)) {
|
||||
(void)users.emplace(user_node);
|
||||
users.insert(user_node);
|
||||
} else {
|
||||
mindspore::HashSet<AnfNodePtr> real_users;
|
||||
mindspore::CompactSet<AnfNodePtr> real_users;
|
||||
real_users = FindParallelNodeUsers(user_node);
|
||||
users.insert(real_users.begin(), real_users.end());
|
||||
}
|
||||
|
@ -386,7 +387,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||
mindspore::HashSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
|
||||
mindspore::CompactSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
|
||||
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
|
||||
// Skip processed nodes.
|
||||
return processed_nodes_.find(user_node) == processed_nodes_.end();
|
||||
|
@ -394,7 +395,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
// Find Load nodes for a parameter.
|
||||
mindspore::HashSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr ¶m) {
|
||||
mindspore::CompactSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr ¶m) {
|
||||
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
|
||||
// Search for Load nodes only.
|
||||
return IsPrimitiveCNode(user_node, prim::kPrimLoad);
|
||||
|
@ -548,6 +549,7 @@ class OrderEnforcer {
|
|||
const FuncGraphPtr &func_graph_;
|
||||
FuncGraphManagerPtr manager_;
|
||||
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
|
||||
// As of now it's no requirement for insertion order, so use the unordered set.
|
||||
mindspore::HashSet<AnfNodePtr> processed_nodes_;
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -46,6 +46,13 @@ class CompactSet {
|
|||
}
|
||||
}
|
||||
|
||||
template <class InputIt>
|
||||
void insert(InputIt first, InputIt last) {
|
||||
for (; first != last; ++first) {
|
||||
insert(*first);
|
||||
}
|
||||
}
|
||||
|
||||
iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); }
|
||||
|
||||
const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }
|
||||
|
|
Loading…
Reference in New Issue