!12621 Eliminate all redundant nodes related to UpdateStates.
From: @zh_qh Reviewed-by: @ginfung,@hwhewei Signed-off-by: @ginfung
This commit is contained in:
commit
b13cabeb10
|
@ -1878,10 +1878,9 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
|
|||
|
||||
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
|
||||
const AnfNodePtr &front_node) {
|
||||
auto node_users = front_func_graph_manager->node_users();
|
||||
auto users = node_users[front_node];
|
||||
auto &users = front_func_graph_manager->node_users()[front_node];
|
||||
std::vector<AnfNodePtr> result;
|
||||
for (auto user : users) {
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -433,12 +433,12 @@ class IncorporateGetitemSwitch : public AnfVisitor {
|
|||
MS_EXCEPTION_IF_NULL(switch_call_cnode);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto node_users_map = manager->node_users();
|
||||
auto &node_users_map = manager->node_users();
|
||||
auto it = node_users_map.find(switch_call);
|
||||
if (it == node_users_map.end()) {
|
||||
return false;
|
||||
}
|
||||
auto node_users = it->second;
|
||||
auto &node_users = it->second;
|
||||
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s
|
||||
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) {
|
||||
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem);
|
||||
|
|
|
@ -291,6 +291,50 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
// Remove all nodes related to UpdateStates, if they're redundant.
|
||||
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
|
||||
if (update_states.empty()) {
|
||||
return;
|
||||
}
|
||||
auto mgr = GetManager(update_states[0]);
|
||||
|
||||
// 1. Remove the use of UpdateState nodes, except the last one.
|
||||
for (auto i = update_states.size() - 1; i > 0; i--) {
|
||||
auto &us = update_states[i];
|
||||
mgr->Replace(us, us->input(kInputIndex));
|
||||
}
|
||||
|
||||
// 2. Remove the Depend users of last UpdateState node.
|
||||
auto &node_users = mgr->node_users();
|
||||
auto iter = node_users.find(update_states[0]);
|
||||
if (iter == node_users.end()) {
|
||||
return;
|
||||
}
|
||||
auto &us_users = iter->second;
|
||||
if (us_users.size() < 2) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> depend_nodes;
|
||||
for (auto &user : us_users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) {
|
||||
depend_nodes.emplace_back(user.first);
|
||||
}
|
||||
}
|
||||
if (depend_nodes.empty()) {
|
||||
return;
|
||||
}
|
||||
ssize_t end = 0;
|
||||
// If all users are Depend CNode.
|
||||
if (depend_nodes.size() == us_users.size()) {
|
||||
end = 1;
|
||||
}
|
||||
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
|
||||
const auto &depend_node = depend_nodes[i];
|
||||
const auto &depend_cnode = depend_node->cast<CNodePtr>();
|
||||
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
|
||||
}
|
||||
}
|
||||
|
||||
// Eliminate UpdateStates for consecutive Loads.
|
||||
// Convert:
|
||||
// x1 = Load(x1, u)
|
||||
|
@ -336,10 +380,9 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
|
|||
mgr->SetEdge(load, kAttachIndex, input_monad);
|
||||
}
|
||||
}
|
||||
for (auto i = update_states.size() - 1; i > 0; i--) {
|
||||
auto &us = update_states[i];
|
||||
mgr->Replace(us, us->input(kInputIndex));
|
||||
}
|
||||
|
||||
EliminateUselessNodesForUpdateStates(update_states);
|
||||
|
||||
if (make_tuple_inputs.size() == 1) {
|
||||
// This should not happen.
|
||||
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
|
||||
|
|
|
@ -52,25 +52,24 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) {
|
||||
auto node_users = node_users_map[node];
|
||||
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
|
||||
const auto &node_users = manager->node_users()[node];
|
||||
for (auto &user_pair : node_users) {
|
||||
auto user_node = user_pair.first;
|
||||
if (!user_node->grad()) {
|
||||
user_node->set_grad(true);
|
||||
SetGradTag(user_node, node_users_map);
|
||||
SetGradTag(user_node, manager);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PipelineTransformer::LabelRequiredGradCNode() {
|
||||
auto parameters = root_->parameters();
|
||||
auto node_users_map = manager_->node_users();
|
||||
for (auto parameter : parameters) {
|
||||
if (!ParameterRequireGrad(parameter)) {
|
||||
continue;
|
||||
}
|
||||
SetGradTag(parameter, node_users_map);
|
||||
SetGradTag(parameter, manager_);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,7 +242,7 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
|
|||
while (need_coloring) {
|
||||
need_coloring = false;
|
||||
auto all_nodes = func->nodes();
|
||||
auto node_users = manager_->node_users();
|
||||
auto &node_users = manager_->node_users();
|
||||
for (auto &node : all_nodes) {
|
||||
if (node->isa<CNode>() || node->stage() == -1) {
|
||||
continue;
|
||||
|
|
|
@ -58,7 +58,7 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto node_users = manager->node_users();
|
||||
auto &node_users = manager->node_users();
|
||||
if (node_users[anf_node_ptr].empty()) {
|
||||
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
|
||||
}
|
||||
|
|
|
@ -391,7 +391,7 @@ TEST_F(TestManager, test_nested_manual) {
|
|||
ASSERT_EQ(2, f->nodes().size());
|
||||
ASSERT_EQ(1, g->nodes().size());
|
||||
|
||||
auto users = mng->node_users();
|
||||
auto &users = mng->node_users();
|
||||
for (auto& iter : users) {
|
||||
ASSERT_EQ(1, iter.second.size());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue