!20203 Revert "Ensure the correct execution order of the users of parameter or load"

Merge pull request !20203 from Margaret_wangrui/revert_order_enforce_r1.3
This commit is contained in:
zhangzhenghai 2021-07-14 01:03:43 +00:00 committed by Gitee
commit 9f3cfb8a4b
1 changed files with 78 additions and 39 deletions

View File

@ -66,13 +66,40 @@ class OrderEnforcer {
// Skip UpdateStates for IO.
return;
}
const size_t attach_index = 2;
auto &attach = update_state->input(attach_index);
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
auto updated_refs = FindUpdatedRefs(update_state);
if (updated_refs.empty()) {
// Skip UpdateStates that do not have updated refs.
return;
} else if (attach->isa<CNode>()) {
EnforceOrderForOtherCNode(attach->cast<CNodePtr>());
}
auto &attach = update_state->input(2);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
// Handle UpdateState with Load.
EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
// Handle UpdateState with MakeTuple.
EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
}
}
std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
std::unordered_set<AnfNodePtr> updated_refs;
auto &users = manager_->node_users()[update_state];
for (auto &user : users) {
auto cnode = dyn_cast<CNode>(user.first);
if (cnode == nullptr) {
continue;
}
if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
cnode->IsApply(prim::kPrimUpdateState)) {
continue;
}
for (auto &input : cnode->inputs()) {
if (IsRef(input)) {
updated_refs.insert(input);
}
}
}
return updated_refs;
}
bool IsRef(const AnfNodePtr &node) {
@ -80,26 +107,52 @@ class OrderEnforcer {
return abs != nullptr && abs->isa<abstract::AbstractRef>();
}
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
// Find refs from the cnode inputs.
auto &inputs = cnode->inputs();
const size_t last_index = inputs.size() - 1;
auto last_input = cnode->input(last_index);
if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
const std::unordered_set<AnfNodePtr> &refs) {
auto parameter = load->input(1);
if (refs.find(parameter) == refs.end()) {
// Skip if loaded parameter is not updated.
return;
}
// Find load users, ignore processed nodes.
auto load_users = FindUsers(load, update_state);
auto parameter_users = FindUsers(parameter, update_state);
load_users.insert(parameter_users.begin(), parameter_users.end());
// Find load users that not depend on the UpdateState,
// and than let UpdateState depend on them.
AddInputEdges(update_state, load_users);
}
void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
const std::unordered_set<AnfNodePtr> &refs) {
// The UpdateState should be the only one user of the make_tuple.
// for performance, we only check the number of output edges.
if (manager_->node_users()[make_tuple].size() != 1) {
return;
}
// Find load users from the tuple of Load nodes.
std::unordered_set<AnfNodePtr> all_load_users;
auto &inputs = make_tuple->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto &input = inputs.at(i);
if (!IsRef(input)) {
if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
// Skip non-Load nodes.
continue;
}
// load ref users
auto loads = FindLoadUsers(input);
for (auto load : loads) {
std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
AddInputEdges(last_input->cast<CNodePtr>(), load_users);
auto load = input->cast<CNodePtr>();
auto parameter = load->input(1);
if (refs.find(parameter) == refs.end()) {
// Skip if loaded parameter is not updated.
continue;
}
auto load_users = FindUsers(load, make_tuple);
auto parameter_users = FindUsers(parameter, make_tuple);
all_load_users.insert(parameter_users.begin(), parameter_users.end());
all_load_users.insert(load_users.begin(), load_users.end());
}
// Find load users that not depend on the UpdateState,
// and than let UpdateState depend on them.
AddInputEdges(update_state, all_load_users);
}
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
@ -107,9 +160,6 @@ class OrderEnforcer {
const size_t input_size = update_state->inputs().size();
for (size_t index = attach_index; index < input_size; index++) {
auto attach = update_state->input(attach_index);
if (attach == load_user) {
return true;
}
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
auto attach_cnode = attach->cast<CNodePtr>();
auto inputs = attach_cnode->inputs();
@ -118,6 +168,8 @@ class OrderEnforcer {
if (has_load_user) {
return true;
}
} else if (attach == load_user) {
return true;
}
}
return false;
@ -127,7 +179,7 @@ class OrderEnforcer {
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users);
for (auto &load_user : sorted_load_users) {
if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
if (!IsDependOn(load_user, update_state)) {
processed_nodes_.insert(load_user);
if (!IsInUpdateState(load_user, update_state)) {
manager_->AddEdge(update_state, load_user);
@ -188,7 +240,7 @@ class OrderEnforcer {
}
// Find Load or parameter users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param) {
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param, const AnfNodePtr &exclude) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(load_or_param);
if (iter == node_users.end()) {
@ -198,6 +250,10 @@ class OrderEnforcer {
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (user_node == exclude) {
// Skip excluded node.
continue;
}
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
// Skip processed nodes.
continue;
@ -209,23 +265,6 @@ class OrderEnforcer {
return load_param_users;
}
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr &param) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(param);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> loads;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
loads.insert(user_node);
}
}
return loads;
}
private:
const FuncGraphPtr &func_graph_;
FuncGraphManagerPtr manager_;