!19862 Ensure the correct execution order of the users of parameter or load

Merge pull request !19862 from Margaret_wangrui/order_enforce_2
This commit is contained in:
i-robot 2021-07-10 07:36:55 +00:00 committed by Gitee
commit 1ab6ee831e
1 changed files with 36 additions and 76 deletions

View File

@ -66,40 +66,13 @@ class OrderEnforcer {
// Skip UpdateStates for IO.
return;
}
auto updated_refs = FindUpdatedRefs(update_state);
if (updated_refs.empty()) {
// Skip UpdateStates that do not have updated refs.
const size_t attach_index = 2;
auto &attach = update_state->input(attach_index);
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
return;
} else if (attach->isa<CNode>()) {
EnforceOrderForOtherCNode(attach->cast<CNodePtr>());
}
auto &attach = update_state->input(2);
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
// Handle UpdateState with Load.
EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
// Handle UpdateState with MakeTuple.
EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
}
}
std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
std::unordered_set<AnfNodePtr> updated_refs;
auto &users = manager_->node_users()[update_state];
for (auto &user : users) {
auto cnode = dyn_cast<CNode>(user.first);
if (cnode == nullptr) {
continue;
}
if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
cnode->IsApply(prim::kPrimUpdateState)) {
continue;
}
for (auto &input : cnode->inputs()) {
if (IsRef(input)) {
updated_refs.insert(input);
}
}
}
return updated_refs;
}
bool IsRef(const AnfNodePtr &node) {
@ -107,52 +80,26 @@ class OrderEnforcer {
return abs != nullptr && abs->isa<abstract::AbstractRef>();
}
void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
const std::unordered_set<AnfNodePtr> &refs) {
auto parameter = load->input(1);
if (refs.find(parameter) == refs.end()) {
// Skip if loaded parameter is not updated.
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
// Find refs from the cnode inputs.
auto &inputs = cnode->inputs();
const size_t last_index = inputs.size() - 1;
auto last_input = cnode->input(last_index);
if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
return;
}
// Find load users, ignore processed nodes.
auto load_users = FindUsers(load, update_state);
auto parameter_users = FindUsers(parameter, update_state);
load_users.insert(parameter_users.begin(), parameter_users.end());
// Find load users that not depend on the UpdateState,
// and than let UpdateState depend on them.
AddInputEdges(update_state, load_users);
}
void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
const std::unordered_set<AnfNodePtr> &refs) {
// The UpdateState should be the only one user of the make_tuple.
// for performance, we only check the number of output edges.
if (manager_->node_users()[make_tuple].size() != 1) {
return;
}
// Find load users from the tuple of Load nodes.
std::unordered_set<AnfNodePtr> all_load_users;
auto &inputs = make_tuple->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto &input = inputs.at(i);
if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
// Skip non-Load nodes.
if (!IsRef(input)) {
continue;
}
auto load = input->cast<CNodePtr>();
auto parameter = load->input(1);
if (refs.find(parameter) == refs.end()) {
// Skip if loaded parameter is not updated.
continue;
// load ref users
auto loads = FindLoadUsers(input);
for (auto load : loads) {
std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
AddInputEdges(last_input->cast<CNodePtr>(), load_users);
}
auto load_users = FindUsers(load, make_tuple);
auto parameter_users = FindUsers(parameter, make_tuple);
all_load_users.insert(parameter_users.begin(), parameter_users.end());
all_load_users.insert(load_users.begin(), load_users.end());
}
// Find load users that not depend on the UpdateState,
// and than let UpdateState depend on them.
AddInputEdges(update_state, all_load_users);
}
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
@ -179,7 +126,7 @@ class OrderEnforcer {
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
auto sorted_load_users = SortLoadUsers(load_users);
for (auto &load_user : sorted_load_users) {
if (!IsDependOn(load_user, update_state)) {
if (!IsDependOn(load_user, update_state) && load_user != update_state) {
processed_nodes_.insert(load_user);
if (!IsInUpdateState(load_user, update_state)) {
manager_->AddEdge(update_state, load_user);
@ -240,7 +187,7 @@ class OrderEnforcer {
}
// Find Load or parameter users as the candidate nodes to enforce order of execution.
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param, const AnfNodePtr &exclude) {
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(load_or_param);
if (iter == node_users.end()) {
@ -250,10 +197,6 @@ class OrderEnforcer {
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (user_node == exclude) {
// Skip excluded node.
continue;
}
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
// Skip processed nodes.
continue;
@ -265,6 +208,23 @@ class OrderEnforcer {
return load_param_users;
}
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr &param) {
auto &node_users = manager_->node_users();
auto iter = node_users.find(param);
if (iter == node_users.end()) {
return {};
}
std::unordered_set<AnfNodePtr> loads;
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
loads.insert(user_node);
}
}
return loads;
}
private:
const FuncGraphPtr &func_graph_;
FuncGraphManagerPtr manager_;