!20715 [ME][Auto-Monad] Add ref users to UpdateState to ensure the execution order

Merge pull request !20715 from Margaret_wangrui/auto_monad_ref_users_master
i-robot 2021-07-26 11:43:43 +00:00 committed by Gitee
commit 022c1c4583
7 changed files with 136 additions and 22 deletions

View File

@@ -20,7 +20,6 @@
#include <string>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {

View File

@@ -53,6 +53,7 @@
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
+#include "backend/optimizer/pass/optimize_updatestate.h"
 #include "common/trans.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/data_dump/e2e_dump.h"
@@ -184,6 +185,8 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph)
   pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
   pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
   pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
+  // Remove nodes that are only used by UpdateState, to ensure the correct execution sequence in CudnnInplaceAggregate.
+  pm->AddPass(std::make_shared<opt::OptimizeUpdateState>());
   pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
   pm->AddPass(std::make_shared<opt::ReluV2Pass>());
   pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
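The ordering above is the point of this change: OptimizeUpdateState must run before CudnnInplaceAggregate so that nodes kept alive only by UpdateState do not distort the inplace-aggregation sequence. Below is a minimal toy sketch of the property the pass targets, with an invented Node type rather than the MindSpore pass API:

#include <iostream>
#include <string>
#include <vector>

// Toy model: a node whose users are all UpdateState nodes carries only
// ordering information, no data, so it can be detached before inplace
// aggregation reasons about real producers and consumers.
struct Node {
  std::string op;
  std::vector<const Node *> users;
};

bool UsedOnlyByUpdateState(const Node &node) {
  if (node.users.empty()) {
    return false;
  }
  for (const auto *user : node.users) {
    if (user->op != "UpdateState") {
      return false;
    }
  }
  return true;
}

int main() {
  Node update_state{"UpdateState", {}};
  Node relu{"ReLU", {}};
  Node load{"Load", {&update_state}};           // ordering-only user
  Node conv{"Conv2D", {&update_state, &relu}};  // also has a data user
  std::cout << UsedOnlyByUpdateState(load) << '\n';  // 1: candidate for removal
  std::cout << UsedOnlyByUpdateState(conv) << '\n';  // 0: must be kept
  return 0;
}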

View File

@@ -628,7 +628,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
   MS_EXCEPTION_IF_NULL(cnode_inputs);
   auto origin_inputs = cnode->inputs();
   const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
-  const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState);
   // If there are multiple Depend nodes, only select the first Depend as the parameter.
   for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
     auto anf = origin_inputs[input_idx];
@@ -637,8 +636,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
     if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
       (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
       continue;
-    } else if ((is_depend && input_idx > kRealInputIndexInDepend) ||
-               (is_updatestate && input_idx > kUpdateStateRealInput)) {
+    } else if ((is_depend && input_idx > kRealInputIndexInDepend)) {
       cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
       continue;
     } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
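With is_updatestate gone, only extra Depend inputs still fall back to placeholder index values; UpdateState inputs are resolved like ordinary inputs, so the ordering operands they carry are kept. A toy sketch of the simplified branch, using invented string nodes instead of session types:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

constexpr size_t kRealInputIndexInDepend = 1;

// Toy model, not SessionBasic: inputs already known to the backend are kept,
// extra Depend inputs become placeholder indices, everything else is left for
// the later resolution branches.
std::vector<std::string> RewriteInputs(const std::string &op, const std::vector<std::string> &inputs,
                                       const std::vector<std::string> &known) {
  std::vector<std::string> result;
  const bool is_depend = (op == "Depend");
  for (size_t input_idx = 1; input_idx < inputs.size(); ++input_idx) {
    if (std::find(known.begin(), known.end(), inputs[input_idx]) != known.end()) {
      result.push_back(inputs[input_idx]);  // already mapped to a backend node
    } else if (is_depend && input_idx > kRealInputIndexInDepend) {
      result.push_back("index:" + std::to_string(input_idx));  // placeholder value node
    } else {
      result.push_back("resolve(" + inputs[input_idx] + ")");  // handled elsewhere
    }
  }
  return result;
}

int main() {
  // Depend(x, u): "x" is known, the attached "u" is untracked -> placeholder.
  for (const auto &s : RewriteInputs("Depend", {"Depend", "x", "u"}, {"x"})) {
    std::cout << s << '\n';  // prints: x, then index:2
  }
  return 0;
}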

View File

@@ -79,6 +79,31 @@ class OrderEnforcer {
     return abs != nullptr && abs->isa<abstract::AbstractRef>();
   }

+  // Find Load or parameter users as the candidate nodes to enforce order of execution.
+  std::unordered_set<AnfNodePtr> GetSpecialOperatorRealUsers(const AnfNodePtr &node) {
+    auto &node_users = manager_->node_users();
+    auto iter = node_users.find(node);
+    if (iter == node_users.end()) {
+      return {};
+    }
+    std::unordered_set<AnfNodePtr> real_users;
+    auto &users = iter->second;
+    for (auto &user : users) {
+      auto &user_node = user.first;
+      real_users.insert(user_node);
+    }
+    return real_users;
+  }
+
+  bool IsOneOfPrimitive(const AnfNodePtr &node, const std::set<PrimitivePtr> &special_node_types) {
+    for (const auto &type : special_node_types) {
+      if (IsPrimitiveCNode(node, type)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
     // Find refs from the cnode inputs.
     auto &inputs = cnode->inputs();
@@ -87,6 +112,7 @@ class OrderEnforcer {
     if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
       return;
     }
+    const std::set<PrimitivePtr> special_operators = {prim::kPrimExpandDims};
     for (size_t i = 1; i < inputs.size(); ++i) {
       auto &input = inputs.at(i);
       if (!IsRef(input)) {
@@ -96,7 +122,17 @@ class OrderEnforcer {
       auto loads = FindLoadUsers(input);
       for (auto load : loads) {
         std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
-        AddInputEdges(last_input->cast<CNodePtr>(), load_users);
+        std::unordered_set<AnfNodePtr> real_users;
+        for (auto load_user : load_users) {
+          // Check for special operators; only one level of users is considered for now.
+          if (IsOneOfPrimitive(load_user, special_operators)) {
+            std::unordered_set<AnfNodePtr> special_real_users = GetSpecialOperatorRealUsers(load_user);
+            real_users.insert(special_real_users.begin(), special_real_users.end());
+          } else {
+            real_users.insert(load_user);
+          }
+        }
+        AddInputEdges(last_input->cast<CNodePtr>(), real_users);
       }
     }
   }
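The intended effect, on a toy graph with an invented Node type instead of AnfNode: when a Load feeds a special view-like operator such as ExpandDims, the order edge should land on whatever consumes that operator's output, since the operator itself does not copy the parameter data:

#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<const Node *> users;
};

// Toy model of the one-level hop: for Load -> ExpandDims -> Add, the
// UpdateState order edge should be attached to Add, the node that actually
// consumes the parameter data.
std::set<const Node *> RealUsers(const Node &load, const std::set<std::string> &special_ops) {
  std::set<const Node *> real_users;
  for (const auto *user : load.users) {
    if (special_ops.count(user->op) > 0) {
      // One level only, matching the pass: take the special operator's users.
      real_users.insert(user->users.begin(), user->users.end());
    } else {
      real_users.insert(user);
    }
  }
  return real_users;
}

int main() {
  Node add{"Add", {}};
  Node expand_dims{"ExpandDims", {&add}};
  Node load{"Load", {&expand_dims}};
  for (const auto *user : RealUsers(load, {"ExpandDims"})) {
    std::cout << user->op << '\n';  // prints: Add
  }
  return 0;
}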
@@ -126,7 +162,10 @@ class OrderEnforcer {
   void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
     auto sorted_load_users = SortLoadUsers(load_users);
     for (auto &load_user : sorted_load_users) {
-      if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
+      if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
+        continue;
+      }
+      if (!IsDependOn(load_user, update_state)) {
         processed_nodes_.insert(load_user);
         if (!IsInUpdateState(load_user, update_state)) {
           manager_->AddEdge(update_state, load_user);
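A plausible reading of the new guard, sketched with invented types rather than the pass itself: a MakeTuple or UpdateState user typically already feeds the UpdateState being processed, so an edge back to it would close a cycle:

#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<const Node *> inputs;
};

// Toy reachability check: if `to` is already reachable from `from` through
// input edges, adding the edge from -> to would create a cycle.
bool Reaches(const Node *from, const Node *to, std::set<const Node *> *seen) {
  if (from == to) {
    return true;
  }
  if (!seen->insert(from).second) {
    return false;
  }
  for (const auto *input : from->inputs) {
    if (Reaches(input, to, seen)) {
      return true;
    }
  }
  return false;
}

int main() {
  Node load{"Load", {}};
  Node make_tuple{"MakeTuple", {&load}};
  Node update_state{"UpdateState", {&make_tuple}};
  std::set<const Node *> seen;
  std::cout << Reaches(&update_state, &make_tuple, &seen) << '\n';  // 1: edge would be cyclic
  return 0;
}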
@@ -225,7 +264,6 @@ class OrderEnforcer {
     return loads;
   }

 private:
-  const FuncGraphPtr &func_graph_;
   FuncGraphManagerPtr manager_;
   std::unordered_map<AnfNodePtr, size_t> topo_sort_map_;

View File

@@ -38,6 +38,7 @@
 #include "debug/rdr/running_data_recorder.h"
 #include "utils/comm_manager.h"
 #include "debug/debugger/debugger.h"
+#include "backend/optimizer/pass/optimize_updatestate.h"

 namespace mindspore {
 namespace device {
@@ -251,6 +252,8 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
   pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
   pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
+    // Remove nodes that are only used by UpdateState, to ensure the correct execution sequence in CudnnInplaceAggregate.
+    pm->AddPass(std::make_shared<opt::OptimizeUpdateState>());
     pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
   }
   pm->AddPass(std::make_shared<opt::ReluV2Pass>());

View File

@@ -59,11 +59,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
       continue;
     }
     auto &node_users = iter->second;
-    const bool has_outer_user = std::any_of(
-      std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
-        const bool is_outer_user = (seen.find(u.first) == seen.end());
-        return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2);
-      });
+    const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
+                                            [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
+                                              const bool is_outer_user = (seen.find(u.first) == seen.end());
+                                              return is_outer_user;
+                                            });
     if (has_outer_user) {
       output.emplace_back(node);
     }
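A minimal, self-contained illustration of the simplified predicate, with toy string nodes in place of AnfNodePtr: any user outside the segment now marks the node as a segment output, UpdateState users included, so monad ordering survives segment boundaries:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Nodes inside the segment, and the (user, input-index) pairs of one node.
  const std::set<std::string> seen = {"conv", "relu"};
  const std::vector<std::pair<std::string, int64_t>> node_users = {{"relu", 1}, {"update_state", 3}};
  const bool has_outer_user =
      std::any_of(std::begin(node_users), std::end(node_users),
                  [&seen](const std::pair<std::string, int64_t> &u) -> bool { return seen.find(u.first) == seen.end(); });
  std::cout << has_outer_user << '\n';  // 1: "update_state" is outside, so the node is an output
  return 0;
}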
@@ -127,16 +127,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph
     for (size_t i = value_start_index; i < inps.size(); ++i) {
       args.emplace_back(NewValueNode(MakeValue(0)));
     }
-  } else if (IsPrimitive(fn, prim::kPrimUpdateState)) {
-    args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv));
-    args.emplace_back(RefSubGraphNode(fg, inps[kUpdateStateRealInput], &inputs, &eqv));
-    const size_t additional_input_index = 3;
-    for (size_t i = additional_input_index; i < inps.size(); ++i) {
-      auto &input = inps[i];
-      if (eqv.find(input) != eqv.end()) {
-        args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv));
-      }
-    }
   } else {
     (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
                          [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });

View File

@@ -0,0 +1,83 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
from mindspore.nn import Cell
from mindspore import context, Tensor, Parameter
import mindspore.ops.operations as P
import mindspore as ms
import numpy as np

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class AutoMonadAddnAdamNet(Cell):
    def __init__(self, var, m, v):
        super().__init__()
        self.apply_adam = P.Adam()
        self.var = Parameter(var, name="var")
        self.m = Parameter(m, name="m")
        self.v = Parameter(v, name="v")
        self.addn = P.AddN()
        self.mul = P.Mul()

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        out = self.addn((self.var, self.m, self.v))
        self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
        return out, self.var, self.m, self.v


def _count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert True


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_auto_monad_addn_adam():
    var = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    m = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    v = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AutoMonadAddnAdamNet(var, m, v)
    beta1_power = Tensor(0.9, ms.float32)
    beta2_power = Tensor(0.999, ms.float32)
    lr = Tensor(0.1, ms.float32)
    beta1 = Tensor(0.9, ms.float32)
    beta2 = Tensor(0.999, ms.float32)
    epsilon = Tensor(1e-8, ms.float32)
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    out, new_var, new_m, new_v = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
    net = AutoMonadAddnAdamNet(var, m, v)
    context.set_context(mode=context.PYNATIVE_MODE)
    out_pyn, new_var_pyn, new_m_pyn, new_v_pyn = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
    allclose_nparray(out_pyn.asnumpy(), out.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
    allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)