!20203 Revert "Ensure the correct execution order of the users of parameter or load"
Merge pull request !20203 from Margaret_wangrui/revert_order_enforce_r1.3
This commit is contained in:
commit
9f3cfb8a4b
|
@ -66,13 +66,40 @@ class OrderEnforcer {
|
|||
// Skip UpdateStates for IO.
|
||||
return;
|
||||
}
|
||||
const size_t attach_index = 2;
|
||||
auto &attach = update_state->input(attach_index);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
auto updated_refs = FindUpdatedRefs(update_state);
|
||||
if (updated_refs.empty()) {
|
||||
// Skip UpdateStates that do not have updated refs.
|
||||
return;
|
||||
} else if (attach->isa<CNode>()) {
|
||||
EnforceOrderForOtherCNode(attach->cast<CNodePtr>());
|
||||
}
|
||||
auto &attach = update_state->input(2);
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
|
||||
// Handle UpdateState with Load.
|
||||
EnforceOrderForLoad(update_state, attach->cast<CNodePtr>(), updated_refs);
|
||||
} else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
// Handle UpdateState with MakeTuple.
|
||||
EnforceOrderForTuple(update_state, attach->cast<CNodePtr>(), updated_refs);
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<AnfNodePtr> FindUpdatedRefs(const CNodePtr &update_state) {
|
||||
std::unordered_set<AnfNodePtr> updated_refs;
|
||||
auto &users = manager_->node_users()[update_state];
|
||||
for (auto &user : users) {
|
||||
auto cnode = dyn_cast<CNode>(user.first);
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) ||
|
||||
cnode->IsApply(prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (IsRef(input)) {
|
||||
updated_refs.insert(input);
|
||||
}
|
||||
}
|
||||
}
|
||||
return updated_refs;
|
||||
}
|
||||
|
||||
bool IsRef(const AnfNodePtr &node) {
|
||||
|
@ -80,26 +107,52 @@ class OrderEnforcer {
|
|||
return abs != nullptr && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
|
||||
// Find refs from the cnode inputs.
|
||||
auto &inputs = cnode->inputs();
|
||||
const size_t last_index = inputs.size() - 1;
|
||||
auto last_input = cnode->input(last_index);
|
||||
if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
|
||||
void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load,
|
||||
const std::unordered_set<AnfNodePtr> &refs) {
|
||||
auto parameter = load->input(1);
|
||||
if (refs.find(parameter) == refs.end()) {
|
||||
// Skip if loaded parameter is not updated.
|
||||
return;
|
||||
}
|
||||
// Find load users, ignore processed nodes.
|
||||
auto load_users = FindUsers(load, update_state);
|
||||
auto parameter_users = FindUsers(parameter, update_state);
|
||||
load_users.insert(parameter_users.begin(), parameter_users.end());
|
||||
// Find load users that not depend on the UpdateState,
|
||||
// and than let UpdateState depend on them.
|
||||
AddInputEdges(update_state, load_users);
|
||||
}
|
||||
|
||||
void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
||||
const std::unordered_set<AnfNodePtr> &refs) {
|
||||
// The UpdateState should be the only one user of the make_tuple.
|
||||
// for performance, we only check the number of output edges.
|
||||
if (manager_->node_users()[make_tuple].size() != 1) {
|
||||
return;
|
||||
}
|
||||
// Find load users from the tuple of Load nodes.
|
||||
std::unordered_set<AnfNodePtr> all_load_users;
|
||||
auto &inputs = make_tuple->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto &input = inputs.at(i);
|
||||
if (!IsRef(input)) {
|
||||
if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||
// Skip non-Load nodes.
|
||||
continue;
|
||||
}
|
||||
// load ref users
|
||||
auto loads = FindLoadUsers(input);
|
||||
for (auto load : loads) {
|
||||
std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
|
||||
AddInputEdges(last_input->cast<CNodePtr>(), load_users);
|
||||
auto load = input->cast<CNodePtr>();
|
||||
auto parameter = load->input(1);
|
||||
if (refs.find(parameter) == refs.end()) {
|
||||
// Skip if loaded parameter is not updated.
|
||||
continue;
|
||||
}
|
||||
auto load_users = FindUsers(load, make_tuple);
|
||||
auto parameter_users = FindUsers(parameter, make_tuple);
|
||||
all_load_users.insert(parameter_users.begin(), parameter_users.end());
|
||||
all_load_users.insert(load_users.begin(), load_users.end());
|
||||
}
|
||||
// Find load users that not depend on the UpdateState,
|
||||
// and than let UpdateState depend on them.
|
||||
AddInputEdges(update_state, all_load_users);
|
||||
}
|
||||
|
||||
bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) {
|
||||
|
@ -107,9 +160,6 @@ class OrderEnforcer {
|
|||
const size_t input_size = update_state->inputs().size();
|
||||
for (size_t index = attach_index; index < input_size; index++) {
|
||||
auto attach = update_state->input(attach_index);
|
||||
if (attach == load_user) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
||||
auto attach_cnode = attach->cast<CNodePtr>();
|
||||
auto inputs = attach_cnode->inputs();
|
||||
|
@ -118,6 +168,8 @@ class OrderEnforcer {
|
|||
if (has_load_user) {
|
||||
return true;
|
||||
}
|
||||
} else if (attach == load_user) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
@ -127,7 +179,7 @@ class OrderEnforcer {
|
|||
void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
|
||||
auto sorted_load_users = SortLoadUsers(load_users);
|
||||
for (auto &load_user : sorted_load_users) {
|
||||
if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
|
||||
if (!IsDependOn(load_user, update_state)) {
|
||||
processed_nodes_.insert(load_user);
|
||||
if (!IsInUpdateState(load_user, update_state)) {
|
||||
manager_->AddEdge(update_state, load_user);
|
||||
|
@ -188,7 +240,7 @@ class OrderEnforcer {
|
|||
}
|
||||
|
||||
// Find Load or parameter users as the candidate nodes to enforce order of execution.
|
||||
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param) {
|
||||
std::unordered_set<AnfNodePtr> FindUsers(const AnfNodePtr &load_or_param, const AnfNodePtr &exclude) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(load_or_param);
|
||||
if (iter == node_users.end()) {
|
||||
|
@ -198,6 +250,10 @@ class OrderEnforcer {
|
|||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
if (user_node == exclude) {
|
||||
// Skip excluded node.
|
||||
continue;
|
||||
}
|
||||
if (processed_nodes_.find(user_node) != processed_nodes_.end()) {
|
||||
// Skip processed nodes.
|
||||
continue;
|
||||
|
@ -209,23 +265,6 @@ class OrderEnforcer {
|
|||
return load_param_users;
|
||||
}
|
||||
|
||||
std::unordered_set<AnfNodePtr> FindLoadUsers(const AnfNodePtr ¶m) {
|
||||
auto &node_users = manager_->node_users();
|
||||
auto iter = node_users.find(param);
|
||||
if (iter == node_users.end()) {
|
||||
return {};
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> loads;
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) {
|
||||
loads.insert(user_node);
|
||||
}
|
||||
}
|
||||
return loads;
|
||||
}
|
||||
|
||||
private:
|
||||
const FuncGraphPtr &func_graph_;
|
||||
FuncGraphManagerPtr manager_;
|
||||
|
|
Loading…
Reference in New Issue