forked from mindspore-Ecosystem/mindspore
!20715 [ME][Auto-Monad] Add ref users to UpdateState to ensure execution order
Merge pull request !20715 from Margaret_wangrui/auto_monad_ref_users_master
commit 022c1c4583
@@ -20,7 +20,6 @@
 #include <string>
 #include "base/core_ops.h"
 #include "utils/utils.h"
 #include "backend/session/kernel_graph.h"
 
 namespace mindspore {
 namespace opt {
@@ -53,6 +53,7 @@
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/gpu/concat_outputs_for_all_gather.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
+#include "backend/optimizer/pass/optimize_updatestate.h"
 #include "common/trans.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/data_dump/e2e_dump.h"
@@ -184,6 +185,8 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
   pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
   pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
+  // Remove nodes only used by UpdateState, in order to ensure the correct execution sequence in CudnnInplaceAggregate.
+  pm->AddPass(std::make_shared<opt::OptimizeUpdateState>());
   pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
   pm->AddPass(std::make_shared<opt::ReluV2Pass>());
   pm->AddPass(std::make_shared<opt::AddReluV2Fusion>());
@@ -628,7 +628,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
   MS_EXCEPTION_IF_NULL(cnode_inputs);
   auto origin_inputs = cnode->inputs();
   const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
-  const bool is_updatestate = IsPrimitiveCNode(cnode, prim::kPrimUpdateState);
   // If there are multiple Depends, only select the first Depend as the parameter.
   for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
     auto anf = origin_inputs[input_idx];
@@ -637,8 +636,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
     if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
       (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
       continue;
-    } else if ((is_depend && input_idx > kRealInputIndexInDepend) ||
-               (is_updatestate && input_idx > kUpdateStateRealInput)) {
+    } else if ((is_depend && input_idx > kRealInputIndexInDepend)) {
      cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
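
The two hunks above drop the UpdateState special case from SessionBasic::GetNewCNodeInputs: trailing UpdateState inputs are no longer replaced with placeholder value nodes when backend CNode inputs are built, and only Depend keeps that treatment.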
@@ -79,6 +79,31 @@ class OrderEnforcer {
     return abs != nullptr && abs->isa<abstract::AbstractRef>();
   }
 
+  // Find Load or parameter users as the candidate nodes to enforce order of execution.
+  std::unordered_set<AnfNodePtr> GetSpecialOperatorRealUsers(const AnfNodePtr &node) {
+    auto &node_users = manager_->node_users();
+    auto iter = node_users.find(node);
+    if (iter == node_users.end()) {
+      return {};
+    }
+    std::unordered_set<AnfNodePtr> real_users;
+    auto &users = iter->second;
+    for (auto &user : users) {
+      auto &user_node = user.first;
+      real_users.insert(user_node);
+    }
+    return real_users;
+  }
+
+  bool IsOneOfPrimitive(const AnfNodePtr &node, const std::set<PrimitivePtr> &special_node_types) {
+    for (const auto &type : special_node_types) {
+      if (IsPrimitiveCNode(node, type)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   void EnforceOrderForOtherCNode(const CNodePtr &cnode) {
     // Find refs from the cnode inputs.
     auto &inputs = cnode->inputs();
@@ -87,6 +112,7 @@ class OrderEnforcer {
     if (!IsPrimitiveCNode(last_input, prim::kPrimUpdateState)) {
       return;
     }
+    const std::set<PrimitivePtr> special_operators = {prim::kPrimExpandDims};
     for (size_t i = 1; i < inputs.size(); ++i) {
       auto &input = inputs.at(i);
       if (!IsRef(input)) {
@@ -96,7 +122,17 @@ class OrderEnforcer {
       auto loads = FindLoadUsers(input);
       for (auto load : loads) {
         std::unordered_set<AnfNodePtr> load_users = FindUsers(load);
-        AddInputEdges(last_input->cast<CNodePtr>(), load_users);
+        std::unordered_set<AnfNodePtr> real_users;
+        for (auto load_user : load_users) {
+          // Check for special operators; only one level of users is considered for now.
+          if (IsOneOfPrimitive(load_user, special_operators)) {
+            std::unordered_set<AnfNodePtr> special_real_users = GetSpecialOperatorRealUsers(load_user);
+            real_users.insert(special_real_users.begin(), special_real_users.end());
+          } else {
+            real_users.insert(load_user);
+          }
+        }
+        AddInputEdges(last_input->cast<CNodePtr>(), real_users);
       }
     }
   }
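
These two hunks are the heart of the change: a Load user such as ExpandDims merely forwards the ref value, so the ordering edge must reach the ExpandDims node's own users (its "real users") rather than stop at the first level. Below is a minimal Python sketch of the read-write hazard this ordering guards against; the network and names are hypothetical, not part of this PR:

import numpy as np
import mindspore.ops.operations as P
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell

context.set_context(mode=context.GRAPH_MODE)

class RefUserNet(Cell):
    def __init__(self, var):
        super().__init__()
        self.var = Parameter(var, name="var")
        self.expand_dims = P.ExpandDims()  # pass-through user of the ref via a Load
        self.assign = P.Assign()           # in-place write to the same ref

    def construct(self, x):
        read = self.expand_dims(self.var, 0)  # the read must execute before the write
        self.assign(self.var, x)              # so UpdateState must be ordered after the read's users
        return read

net = RefUserNet(Tensor(np.ones((3, 3), np.float32)))
print(net(Tensor(np.zeros((3, 3), np.float32))))  # should print ones: the pre-Assign value of var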
@@ -126,7 +162,10 @@ class OrderEnforcer {
   void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
     auto sorted_load_users = SortLoadUsers(load_users);
     for (auto &load_user : sorted_load_users) {
-      if (!IsDependOn(load_user, update_state) && !IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
+      if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
+        continue;
+      }
+      if (!IsDependOn(load_user, update_state)) {
        processed_nodes_.insert(load_user);
        if (!IsInUpdateState(load_user, update_state)) {
          manager_->AddEdge(update_state, load_user);
@@ -225,7 +264,6 @@ class OrderEnforcer {
     return loads;
   }
 
  private:
   const FuncGraphPtr &func_graph_;
   FuncGraphManagerPtr manager_;
   std::unordered_map<AnfNodePtr, size_t> topo_sort_map_;
@@ -38,6 +38,7 @@
 #include "debug/rdr/running_data_recorder.h"
 #include "utils/comm_manager.h"
 #include "debug/debugger/debugger.h"
+#include "backend/optimizer/pass/optimize_updatestate.h"
 
 namespace mindspore {
 namespace device {
@@ -251,6 +252,8 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
   pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
   pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
+    // Remove nodes only used by UpdateState, in order to ensure the correct execution sequence in CudnnInplaceAggregate.
+    pm->AddPass(std::make_shared<opt::OptimizeUpdateState>());
     pm->AddPass(std::make_shared<opt::CudnnInplaceAggregate>());
   }
   pm->AddPass(std::make_shared<opt::ReluV2Pass>());
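
Unlike the session path earlier in the diff, the new-runtime GPU device context only schedules OptimizeUpdateState and CudnnInplaceAggregate under graph mode; in PyNative mode side effects execute eagerly, so there is no UpdateState chain to optimize. For reference, the execution mode is selected through the public API (a usage sketch, not code from this PR):

import mindspore.context as context

# UpdateState nodes only exist when the network is compiled as a whole graph.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")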
@@ -59,11 +59,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
       continue;
     }
     auto &node_users = iter->second;
-    const bool has_outer_user = std::any_of(
-      std::begin(node_users), std::end(node_users), [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
-        const bool is_outer_user = (seen.find(u.first) == seen.end());
-        return is_outer_user && !(IsPrimitiveCNode(u.first, prim::kPrimUpdateState) && u.second > 2);
-      });
+    const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
+                                            [&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
+                                              const bool is_outer_user = (seen.find(u.first) == seen.end());
+                                              return is_outer_user;
+                                            });
     if (has_outer_user) {
       output.emplace_back(node);
     }
@@ -127,16 +127,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph
     for (size_t i = value_start_index; i < inps.size(); ++i) {
       args.emplace_back(NewValueNode(MakeValue(0)));
     }
-  } else if (IsPrimitive(fn, prim::kPrimUpdateState)) {
-    args.emplace_back(RefSubGraphNode(fg, inps[1], &inputs, &eqv));
-    args.emplace_back(RefSubGraphNode(fg, inps[kUpdateStateRealInput], &inputs, &eqv));
-    const size_t additional_input_index = 3;
-    for (size_t i = additional_input_index; i < inps.size(); ++i) {
-      auto &input = inps[i];
-      if (eqv.find(input) != eqv.end()) {
-        args.emplace_back(RefSubGraphNode(fg, input, &inputs, &eqv));
-      }
-    }
   } else {
     (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
                          [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
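
The last two C++ hunks retire UpdateState special cases in graph-segment handling: GetOutput no longer exempts UpdateState users (beyond its first real input) when deciding whether a node has users outside the segment, and TransformSegmentToAnfGraph no longer rebuilds UpdateState argument lists by hand.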
@@ -0,0 +1,83 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pytest
+from mindspore.nn import Cell
+from mindspore import context, Tensor, Parameter
+import mindspore.ops.operations as P
+import mindspore as ms
+import numpy as np
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+class AutoMonadAddnAdamNet(Cell):
+    def __init__(self, var, m, v):
+        super().__init__()
+        self.apply_adam = P.Adam()
+        self.var = Parameter(var, name="var")
+        self.m = Parameter(m, name="m")
+        self.v = Parameter(v, name="v")
+        self.addn = P.AddN()
+        self.mul = P.Mul()
+
+    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
+        out = self.addn((self.var, self.m, self.v))
+        self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+        return out, self.var, self.m, self.v
+
+
+def _count_unequal_element(data_expected, data_me, rtol, atol):
+    assert data_expected.shape == data_me.shape
+    total_count = len(data_expected.flatten())
+    error = np.abs(data_expected - data_me)
+    greater = np.greater(error, atol + np.abs(data_me) * rtol)
+    loss_count = np.count_nonzero(greater)
+    assert (loss_count / total_count) < rtol, \
+        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
+        format(data_expected[greater], data_me[greater], error[greater])
+
+
+def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
+    if np.any(np.isnan(data_expected)):
+        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
+    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
+        _count_unequal_element(data_expected, data_me, rtol, atol)
+    else:
+        assert True
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_auto_monad_addn_adam():
+    var = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    m = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    v = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    net = AutoMonadAddnAdamNet(var, m, v)
+    beta1_power = Tensor(0.9, ms.float32)
+    beta2_power = Tensor(0.999, ms.float32)
+    lr = Tensor(0.1, ms.float32)
+    beta1 = Tensor(0.9, ms.float32)
+    beta2 = Tensor(0.999, ms.float32)
+    epsilon = Tensor(1e-8, ms.float32)
+    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
+    out, new_var, new_m, new_v = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+    net = AutoMonadAddnAdamNet(var, m, v)
+    context.set_context(mode=context.PYNATIVE_MODE)
+    out_pyn, new_var_pyn, new_m_pyn, new_v_pyn = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
+    allclose_nparray(out_pyn.asnumpy(), out.asnumpy(), 0.001, 0.001)
+    allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
+    allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
+    allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)