diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc index 65d3bdc0c49..4e23b61afc5 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc @@ -66,13 +66,40 @@ class OrderEnforcer { // Skip UpdateStates for IO. return; } - const size_t attach_index = 2; - auto &attach = update_state->input(attach_index); - if (IsPrimitiveCNode(attach, prim::kPrimLoad) || IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { + auto updated_refs = FindUpdatedRefs(update_state); + if (updated_refs.empty()) { + // Skip UpdateStates that do not have updated refs. return; - } else if (attach->isa()) { - EnforceOrderForOtherCNode(attach->cast()); } + auto &attach = update_state->input(2); + if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { + // Handle UpdateState with Load. + EnforceOrderForLoad(update_state, attach->cast(), updated_refs); + } else if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { + // Handle UpdateState with MakeTuple. + EnforceOrderForTuple(update_state, attach->cast(), updated_refs); + } + } + + std::unordered_set FindUpdatedRefs(const CNodePtr &update_state) { + std::unordered_set updated_refs; + auto &users = manager_->node_users()[update_state]; + for (auto &user : users) { + auto cnode = dyn_cast(user.first); + if (cnode == nullptr) { + continue; + } + if (cnode->IsApply(prim::kPrimLoad) || cnode->IsApply(prim::kPrimDepend) || + cnode->IsApply(prim::kPrimUpdateState)) { + continue; + } + for (auto &input : cnode->inputs()) { + if (IsRef(input)) { + updated_refs.insert(input); + } + } + } + return updated_refs; } bool IsRef(const AnfNodePtr &node) { @@ -80,26 +107,52 @@ class OrderEnforcer { return abs != nullptr && abs->isa(); } - void EnforceOrderForOtherCNode(const CNodePtr &cnode) { - // Find refs from the cnode inputs. - auto &inputs = cnode->inputs(); - const size_t last_index = inputs.size() - 1; - auto last_input = cnode->input(last_index); - if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) { + void EnforceOrderForLoad(const CNodePtr &update_state, const CNodePtr &load, + const std::unordered_set &refs) { + auto parameter = load->input(1); + if (refs.find(parameter) == refs.end()) { + // Skip if loaded parameter is not updated. return; } + // Find load users, ignore processed nodes. + auto load_users = FindUsers(load, update_state); + auto parameter_users = FindUsers(parameter, update_state); + load_users.insert(parameter_users.begin(), parameter_users.end()); + // Find load users that not depend on the UpdateState, + // and than let UpdateState depend on them. + AddInputEdges(update_state, load_users); + } + + void EnforceOrderForTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, + const std::unordered_set &refs) { + // The UpdateState should be the only one user of the make_tuple. + // for performance, we only check the number of output edges. + if (manager_->node_users()[make_tuple].size() != 1) { + return; + } + // Find load users from the tuple of Load nodes. + std::unordered_set all_load_users; + auto &inputs = make_tuple->inputs(); for (size_t i = 1; i < inputs.size(); ++i) { auto &input = inputs.at(i); - if (!IsRef(input)) { + if (!IsPrimitiveCNode(input, prim::kPrimLoad)) { + // Skip non-Load nodes. continue; } - // load ref users - auto loads = FindLoadUsers(input); - for (auto load : loads) { - std::unordered_set load_users = FindUsers(load); - AddInputEdges(last_input->cast(), load_users); + auto load = input->cast(); + auto parameter = load->input(1); + if (refs.find(parameter) == refs.end()) { + // Skip if loaded parameter is not updated. + continue; } + auto load_users = FindUsers(load, make_tuple); + auto parameter_users = FindUsers(parameter, make_tuple); + all_load_users.insert(parameter_users.begin(), parameter_users.end()); + all_load_users.insert(load_users.begin(), load_users.end()); } + // Find load users that not depend on the UpdateState, + // and than let UpdateState depend on them. + AddInputEdges(update_state, all_load_users); } bool IsInUpdateState(const AnfNodePtr &load_user, const CNodePtr &update_state) { @@ -107,9 +160,6 @@ class OrderEnforcer { const size_t input_size = update_state->inputs().size(); for (size_t index = attach_index; index < input_size; index++) { auto attach = update_state->input(attach_index); - if (attach == load_user) { - return true; - } if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { auto attach_cnode = attach->cast(); auto inputs = attach_cnode->inputs(); @@ -118,6 +168,8 @@ class OrderEnforcer { if (has_load_user) { return true; } + } else if (attach == load_user) { + return true; } } return false; @@ -127,7 +179,7 @@ class OrderEnforcer { void AddInputEdges(const CNodePtr &update_state, const std::unordered_set &load_users) { auto sorted_load_users = SortLoadUsers(load_users); for (auto &load_user : sorted_load_users) { - if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) { + if (!IsDependOn(load_user, update_state)) { processed_nodes_.insert(load_user); if (!IsInUpdateState(load_user, update_state)) { manager_->AddEdge(update_state, load_user); @@ -188,7 +240,7 @@ class OrderEnforcer { } // Find Load or parameter users as the candidate nodes to enforce order of execution. - std::unordered_set FindUsers(const AnfNodePtr &load_or_param) { + std::unordered_set FindUsers(const AnfNodePtr &load_or_param, const AnfNodePtr &exclude) { auto &node_users = manager_->node_users(); auto iter = node_users.find(load_or_param); if (iter == node_users.end()) { @@ -198,6 +250,10 @@ class OrderEnforcer { auto &users = iter->second; for (auto &user : users) { auto &user_node = user.first; + if (user_node == exclude) { + // Skip excluded node. + continue; + } if (processed_nodes_.find(user_node) != processed_nodes_.end()) { // Skip processed nodes. continue; @@ -209,23 +265,6 @@ class OrderEnforcer { return load_param_users; } - std::unordered_set FindLoadUsers(const AnfNodePtr ¶m) { - auto &node_users = manager_->node_users(); - auto iter = node_users.find(param); - if (iter == node_users.end()) { - return {}; - } - std::unordered_set loads; - auto &users = iter->second; - for (auto &user : users) { - auto &user_node = user.first; - if (IsPrimitiveCNode(user_node, prim::kPrimLoad)) { - loads.insert(user_node); - } - } - return loads; - } - private: const FuncGraphPtr &func_graph_; FuncGraphManagerPtr manager_;