!12621 Eliminate all redundant nodes related to UpdateStates.

From: @zh_qh
Reviewed-by: @ginfung,@hwhewei
Signed-off-by: @ginfung
This commit is contained in:
mindspore-ci-bot 2021-02-26 16:14:47 +08:00 committed by Gitee
commit b13cabeb10
6 changed files with 58 additions and 17 deletions

View File

@ -1878,10 +1878,9 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &front_node) {
auto node_users = front_func_graph_manager->node_users();
auto users = node_users[front_node];
auto &users = front_func_graph_manager->node_users()[front_node];
std::vector<AnfNodePtr> result;
for (auto user : users) {
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
continue;
}

View File

@ -433,12 +433,12 @@ class IncorporateGetitemSwitch : public AnfVisitor {
MS_EXCEPTION_IF_NULL(switch_call_cnode);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users_map = manager->node_users();
auto &node_users_map = manager->node_users();
auto it = node_users_map.find(switch_call);
if (it == node_users_map.end()) {
return false;
}
auto node_users = it->second;
auto &node_users = it->second;
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem);

View File

@ -291,6 +291,50 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
return make_tuple;
}
// Remove all nodes related to UpdateStates, if they're redundant.
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
if (update_states.empty()) {
return;
}
auto mgr = GetManager(update_states[0]);
// 1. Remove the use of UpdateState nodes, except the last one.
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}
// 2. Remove the Depend users of last UpdateState node.
auto &node_users = mgr->node_users();
auto iter = node_users.find(update_states[0]);
if (iter == node_users.end()) {
return;
}
auto &us_users = iter->second;
if (us_users.size() < 2) {
return;
}
std::vector<AnfNodePtr> depend_nodes;
for (auto &user : us_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) {
depend_nodes.emplace_back(user.first);
}
}
if (depend_nodes.empty()) {
return;
}
ssize_t end = 0;
// If all users are Depend CNode.
if (depend_nodes.size() == us_users.size()) {
end = 1;
}
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
const auto &depend_node = depend_nodes[i];
const auto &depend_cnode = depend_node->cast<CNodePtr>();
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
}
}
// Eliminate UpdateStates for consecutive Loads.
// Convert:
// x1 = Load(x1, u)
@ -336,10 +380,9 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
mgr->SetEdge(load, kAttachIndex, input_monad);
}
}
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}
EliminateUselessNodesForUpdateStates(update_states);
if (make_tuple_inputs.size() == 1) {
// This should not happen.
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);

View File

@ -52,25 +52,24 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
return false;
}
static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) {
auto node_users = node_users_map[node];
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
const auto &node_users = manager->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
if (!user_node->grad()) {
user_node->set_grad(true);
SetGradTag(user_node, node_users_map);
SetGradTag(user_node, manager);
}
}
}
void PipelineTransformer::LabelRequiredGradCNode() {
auto parameters = root_->parameters();
auto node_users_map = manager_->node_users();
for (auto parameter : parameters) {
if (!ParameterRequireGrad(parameter)) {
continue;
}
SetGradTag(parameter, node_users_map);
SetGradTag(parameter, manager_);
}
}
@ -243,7 +242,7 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
while (need_coloring) {
need_coloring = false;
auto all_nodes = func->nodes();
auto node_users = manager_->node_users();
auto &node_users = manager_->node_users();
for (auto &node : all_nodes) {
if (node->isa<CNode>() || node->stage() == -1) {
continue;

View File

@ -58,7 +58,7 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
auto &node_users = manager->node_users();
if (node_users[anf_node_ptr].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
}

View File

@ -391,7 +391,7 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(1, g->nodes().size());
auto users = mng->node_users();
auto &users = mng->node_users();
for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size());
}