fix random order in OrderEnforceAction by replace HashSet which is not fixed order with CompactSet
This commit is contained in:
parent
0ee52f9f30
commit
7e51edd277
|
@ -24,6 +24,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/hash_map.h"
|
#include "utils/hash_map.h"
|
||||||
#include "utils/hash_set.h"
|
#include "utils/hash_set.h"
|
||||||
|
#include "utils/compact_set.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "mindspore/core/ops/core_ops.h"
|
#include "mindspore/core/ops/core_ops.h"
|
||||||
|
|
||||||
|
@ -239,7 +240,7 @@ class OrderEnforcer {
|
||||||
for (auto &load : loads) {
|
for (auto &load : loads) {
|
||||||
// Find user nodes of the Load.
|
// Find user nodes of the Load.
|
||||||
auto load_users = FindLoadUsers(load);
|
auto load_users = FindLoadUsers(load);
|
||||||
mindspore::HashSet<AnfNodePtr> real_users;
|
mindspore::CompactSet<AnfNodePtr> real_users;
|
||||||
for (auto &load_user : load_users) {
|
for (auto &load_user : load_users) {
|
||||||
// Check the special operator, only one level of user is considered for now.
|
// Check the special operator, only one level of user is considered for now.
|
||||||
if (IsSpecialPrimitive(load_user)) {
|
if (IsSpecialPrimitive(load_user)) {
|
||||||
|
@ -249,7 +250,7 @@ class OrderEnforcer {
|
||||||
auto parallel__users = FindParallelNodeUsers(load_user);
|
auto parallel__users = FindParallelNodeUsers(load_user);
|
||||||
real_users.insert(parallel__users.begin(), parallel__users.end());
|
real_users.insert(parallel__users.begin(), parallel__users.end());
|
||||||
} else {
|
} else {
|
||||||
(void)real_users.insert(load_user);
|
real_users.insert(load_user);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AddInputEdges(update_state, real_users);
|
AddInputEdges(update_state, real_users);
|
||||||
|
@ -279,7 +280,7 @@ class OrderEnforcer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add load users as input edges of the update_state node.
|
// Add load users as input edges of the update_state node.
|
||||||
void AddInputEdges(const CNodePtr &update_state, const mindspore::HashSet<AnfNodePtr> &load_users) {
|
void AddInputEdges(const CNodePtr &update_state, const mindspore::CompactSet<AnfNodePtr> &load_users) {
|
||||||
auto sorted_load_users = SortLoadUsers(load_users);
|
auto sorted_load_users = SortLoadUsers(load_users);
|
||||||
for (auto &load_user : sorted_load_users) {
|
for (auto &load_user : sorted_load_users) {
|
||||||
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
|
if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
|
||||||
|
@ -295,7 +296,7 @@ class OrderEnforcer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort load users by their topo sort order.
|
// Sort load users by their topo sort order.
|
||||||
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::HashSet<AnfNodePtr> &load_users) {
|
std::vector<AnfNodePtr> SortLoadUsers(const mindspore::CompactSet<AnfNodePtr> &load_users) {
|
||||||
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
|
std::vector<AnfNodePtr> vec{load_users.begin(), load_users.end()};
|
||||||
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
|
std::sort(vec.begin(), vec.end(), [this](const AnfNodePtr &a, const AnfNodePtr &b) { return IsBefore(a, b); });
|
||||||
return vec;
|
return vec;
|
||||||
|
@ -348,36 +349,36 @@ class OrderEnforcer {
|
||||||
using PredFunc = std::function<bool(const AnfNodePtr &)>;
|
using PredFunc = std::function<bool(const AnfNodePtr &)>;
|
||||||
|
|
||||||
// Find user nodes for the given node.
|
// Find user nodes for the given node.
|
||||||
mindspore::HashSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
|
mindspore::CompactSet<AnfNodePtr> FindNodeUsers(const AnfNodePtr &node, const PredFunc &pred = nullptr) {
|
||||||
auto &node_users = manager_->node_users();
|
auto &node_users = manager_->node_users();
|
||||||
auto iter = node_users.find(node);
|
auto iter = node_users.find(node);
|
||||||
if (iter == node_users.end()) {
|
if (iter == node_users.end()) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
mindspore::HashSet<AnfNodePtr> users;
|
mindspore::CompactSet<AnfNodePtr> users;
|
||||||
for (auto &user : iter->second) {
|
for (auto &user : iter->second) {
|
||||||
auto &user_node = user.first;
|
auto &user_node = user.first;
|
||||||
if (pred == nullptr || pred(user_node)) {
|
if (pred == nullptr || pred(user_node)) {
|
||||||
(void)users.emplace(user_node);
|
users.insert(user_node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return users;
|
return users;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find real user nodes for the given parallel nodes.
|
// Find real user nodes for the given parallel nodes.
|
||||||
mindspore::HashSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
|
mindspore::CompactSet<AnfNodePtr> FindParallelNodeUsers(const AnfNodePtr &node) {
|
||||||
auto &node_users = manager_->node_users();
|
auto &node_users = manager_->node_users();
|
||||||
auto iter = node_users.find(node);
|
auto iter = node_users.find(node);
|
||||||
if (iter == node_users.end()) {
|
if (iter == node_users.end()) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
mindspore::HashSet<AnfNodePtr> users;
|
mindspore::CompactSet<AnfNodePtr> users;
|
||||||
for (auto &user : iter->second) {
|
for (auto &user : iter->second) {
|
||||||
auto &user_node = user.first;
|
auto &user_node = user.first;
|
||||||
if (!IsSpecialParallelPrimitive(user_node)) {
|
if (!IsSpecialParallelPrimitive(user_node)) {
|
||||||
(void)users.emplace(user_node);
|
users.insert(user_node);
|
||||||
} else {
|
} else {
|
||||||
mindspore::HashSet<AnfNodePtr> real_users;
|
mindspore::CompactSet<AnfNodePtr> real_users;
|
||||||
real_users = FindParallelNodeUsers(user_node);
|
real_users = FindParallelNodeUsers(user_node);
|
||||||
users.insert(real_users.begin(), real_users.end());
|
users.insert(real_users.begin(), real_users.end());
|
||||||
}
|
}
|
||||||
|
@ -386,7 +387,7 @@ class OrderEnforcer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||||
mindspore::HashSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
|
mindspore::CompactSet<AnfNodePtr> FindLoadUsers(const AnfNodePtr &load_or_param) {
|
||||||
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
|
return FindNodeUsers(load_or_param, [this](const AnfNodePtr &user_node) {
|
||||||
// Skip processed nodes.
|
// Skip processed nodes.
|
||||||
return processed_nodes_.find(user_node) == processed_nodes_.end();
|
return processed_nodes_.find(user_node) == processed_nodes_.end();
|
||||||
|
@ -394,7 +395,7 @@ class OrderEnforcer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find Load nodes for a parameter.
|
// Find Load nodes for a parameter.
|
||||||
mindspore::HashSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr ¶m) {
|
mindspore::CompactSet<AnfNodePtr> FindLoadNodes(const AnfNodePtr ¶m) {
|
||||||
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
|
return FindNodeUsers(param, [this](const AnfNodePtr &user_node) {
|
||||||
// Search for Load nodes only.
|
// Search for Load nodes only.
|
||||||
return IsPrimitiveCNode(user_node, prim::kPrimLoad);
|
return IsPrimitiveCNode(user_node, prim::kPrimLoad);
|
||||||
|
@ -548,6 +549,7 @@ class OrderEnforcer {
|
||||||
const FuncGraphPtr &func_graph_;
|
const FuncGraphPtr &func_graph_;
|
||||||
FuncGraphManagerPtr manager_;
|
FuncGraphManagerPtr manager_;
|
||||||
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
|
mindspore::HashMap<AnfNodePtr, size_t> topo_sort_map_;
|
||||||
|
// As of now it's no requirement for insertion order, so use the unordered set.
|
||||||
mindspore::HashSet<AnfNodePtr> processed_nodes_;
|
mindspore::HashSet<AnfNodePtr> processed_nodes_;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -46,6 +46,13 @@ class CompactSet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class InputIt>
|
||||||
|
void insert(InputIt first, InputIt last) {
|
||||||
|
for (; first != last; ++first) {
|
||||||
|
insert(*first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); }
|
iterator find(const T &e) { return std::find(data_.begin(), data_.end(), e); }
|
||||||
|
|
||||||
const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }
|
const_iterator find(const T &e) const { return std::find(data_.begin(), data_.end(), e); }
|
||||||
|
|
Loading…
Reference in New Issue