forked from mindspore-Ecosystem/mindspore
Insert load handling for different cases of sequence parameters.
This commit is contained in:
parent
9818117c76
commit
79281813bd
|
@ -254,6 +254,20 @@ class COMMON_EXPORT AnfAlgo {
|
|||
return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
|
||||
}
|
||||
|
||||
// Check whether the sequence node has Ref abstract.
|
||||
static inline bool SequenceHasAbstractRef(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto &abs = node->abstract();
|
||||
if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
|
||||
auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
|
||||
AbstractBasePtrList elements = abs_seq->elements();
|
||||
return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
|
||||
return (element != nullptr) && element->isa<abstract::AbstractRefTensor>();
|
||||
});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the real output node and indexes of get item, make tuple, depend, load.
|
||||
static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack);
|
||||
static bool IsNopNode(const AnfNodePtr &node);
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "ir/param_info.h"
|
||||
#include "ir/cell.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "frontend/parallel/costmodel_context.h"
|
||||
#include "include/common/utils/parallel_context.h"
|
||||
|
@ -1019,7 +1020,7 @@ bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &ta
|
|||
// If the return value of subgraph is Ref in control flow scenarios, should run graph mode with kernelbykernel.
|
||||
bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
// %1 = switch(cond, func1, func2)
|
||||
// %2 = %1() if the abstract of the node is AbstractRefTensor, return true.
|
||||
// %2 = %1() if the abstract of the node is AbstractRefTensor or Tuple/List(AbstractRefTensor, ...), return true.
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto &node_users = manager->node_users();
|
||||
|
@ -1032,8 +1033,7 @@ bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr
|
|||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
auto &user_node = user.first;
|
||||
const auto &abs = user_node->abstract();
|
||||
if (abs != nullptr && abs->isa<abstract::AbstractRefTensor>()) {
|
||||
if (common::AnfAlgo::HasAbstractRef(user_node) || common::AnfAlgo::SequenceHasAbstractRef(user_node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "frontend/operator/composite/multitype_funcgraph.h"
|
||||
#include "utils/flags.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
@ -103,22 +104,13 @@ int GetSideEffectPropagate(const PrimitivePtr &prim) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// Return true if the node has Ref abstract.
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto &abs = node->abstract();
|
||||
return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
|
||||
}
|
||||
|
||||
// Gets ref inputs and its indexes from a cnode.
|
||||
RefInputs GetRefInputs(const CNodePtr &cnode) {
|
||||
RefInputs ref_inputs;
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->inputs().at(i);
|
||||
if (HasAbstractRef(input)) {
|
||||
if (common::AnfAlgo::HasAbstractRef(input)) {
|
||||
ref_inputs[input].push_back(i);
|
||||
}
|
||||
}
|
||||
|
@ -132,14 +124,32 @@ bool HasRefInput(const CNodePtr &cnode) {
|
|||
}
|
||||
auto &inputs = cnode->inputs();
|
||||
// Return true if any of arguments is ref.
|
||||
return std::any_of(inputs.begin() + 1, inputs.end(), [](const auto &input) { return HasAbstractRef(input); });
|
||||
return std::any_of(inputs.begin() + 1, inputs.end(),
|
||||
[](const auto &input) { return common::AnfAlgo::HasAbstractRef(input); });
|
||||
}
|
||||
|
||||
// Return true if cnode has tuple(ref) or list(ref).
|
||||
bool HasRefSequenceInput(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr || cnode->inputs().empty()) {
|
||||
return false;
|
||||
}
|
||||
auto &inputs = cnode->inputs();
|
||||
for (size_t index = 1; index < inputs.size(); ++index) {
|
||||
auto input = cnode->input(index);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (common::AnfAlgo::SequenceHasAbstractRef(input)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Return true if we don't need Load for the given primitive.
|
||||
// i.e. keep Ref as Ref for some primitives.
|
||||
bool IsKeepRef(const PrimitivePtr &prim) {
|
||||
return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
|
||||
IsPrimitiveEquals(prim, prim::kPrimPull);
|
||||
IsPrimitiveEquals(prim, prim::kPrimPull) || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) ||
|
||||
IsPrimitiveEquals(prim, prim::kPrimMakeList);
|
||||
}
|
||||
|
||||
// Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
|
||||
|
@ -177,6 +187,18 @@ prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// The cnode is non-effect-node, and the cnode is real node, and the inputs of cnode is dynamic.
|
||||
bool IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr &cnode) {
|
||||
static const PrimitiveSet dynamic_input_node_prims = {
|
||||
prim::kPrimStack, prim::kPrimConcat, prim::kPrimAddN, prim::kPrimIdentityN,
|
||||
prim::kPrimSparseConcat, prim::kPrimMeshgrid, prim::kPrimDynamicStitch};
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return dynamic_input_node_prims.find(prim) != dynamic_input_node_prims.end();
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// SCC (Strongly Connected Components) related types.
|
||||
// --------------------------------------------------------------------
|
||||
|
@ -383,11 +405,11 @@ class SideEffectFinder {
|
|||
std::vector<FuncGraphPtr> branches;
|
||||
auto true_branch = GetSwitchBranch(cnode, true_index);
|
||||
if (true_branch != nullptr) {
|
||||
branches.emplace_back(true_branch);
|
||||
(void)branches.emplace_back(true_branch);
|
||||
}
|
||||
auto false_branch = GetSwitchBranch(cnode, false_index);
|
||||
if (false_branch != nullptr) {
|
||||
branches.emplace_back(false_branch);
|
||||
(void)branches.emplace_back(false_branch);
|
||||
}
|
||||
if (branches.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
|
||||
|
@ -494,8 +516,9 @@ class SideEffectFinder {
|
|||
std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
constexpr size_t func_tuple_index = 2;
|
||||
constexpr int recursive_level = 2;
|
||||
if (cnode->size() <= func_tuple_index) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
|
||||
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(recursive_level);
|
||||
}
|
||||
auto func_tuple = cnode->inputs().at(func_tuple_index);
|
||||
return GetGraphsFromTuple(func_tuple);
|
||||
|
@ -528,15 +551,17 @@ class SideEffectFinder {
|
|||
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto &inputs = make_tuple->inputs();
|
||||
constexpr int recursive_level = 2;
|
||||
if (inputs.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
|
||||
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(recursive_level);
|
||||
}
|
||||
std::vector<FuncGraphPtr> graphs;
|
||||
graphs.reserve(inputs.size() - 1);
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(inputs.at(i));
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(2) << " index=" << i;
|
||||
MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(recursive_level)
|
||||
<< " index=" << i;
|
||||
continue;
|
||||
}
|
||||
graphs.push_back(func_graph);
|
||||
|
@ -596,10 +621,11 @@ class SideEffectFinder {
|
|||
MS_EXCEPTION_IF_NULL(tuple_indexes);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
|
||||
constexpr int recursive_level = 2;
|
||||
// Trace MakeTuple.
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
|
||||
if (tuple_indexes->empty()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
|
||||
MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(recursive_level);
|
||||
}
|
||||
// Pop out tuple index.
|
||||
auto top_index = tuple_indexes->top();
|
||||
|
@ -647,11 +673,11 @@ class SideEffectFinder {
|
|||
// %1 = J(primal)
|
||||
// tuple = %1(args)
|
||||
if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
|
||||
MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(2);
|
||||
MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(recursive_level);
|
||||
return {EffectInfo::kDetected, false, false, false};
|
||||
}
|
||||
// Rare case.
|
||||
MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(2);
|
||||
MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(recursive_level);
|
||||
return {EffectInfo::kDetected, false, false, false};
|
||||
}
|
||||
|
||||
|
@ -981,6 +1007,9 @@ class SideEffectFinder {
|
|||
// load is inserted inside the func_graph f.
|
||||
info.load = HasRefInput(cnode);
|
||||
}
|
||||
if (!info.memory && IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
|
||||
info.load = HasRefSequenceInput(cnode);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
|
@ -1293,7 +1322,7 @@ class AutoMonadConverter {
|
|||
// If the node has no side effects but 'no_eliminate' flag is set,
|
||||
// we save it to no_eliminate_nodes and handle them late.
|
||||
if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
|
||||
no_eliminate_nodes_.emplace_back(cnode);
|
||||
(void)no_eliminate_nodes_.emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
cnode->SetEffectHandled(true);
|
||||
|
@ -1334,10 +1363,10 @@ class AutoMonadConverter {
|
|||
AbstractBasePtrList element_abstracts;
|
||||
tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
|
||||
element_abstracts.reserve(no_eliminate_nodes_.size());
|
||||
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (auto &node : no_eliminate_nodes_) {
|
||||
tuple_inputs.emplace_back(node);
|
||||
element_abstracts.emplace_back(node->abstract());
|
||||
(void)tuple_inputs.emplace_back(node);
|
||||
(void)element_abstracts.emplace_back(node->abstract());
|
||||
}
|
||||
auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
|
||||
make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
|
||||
|
@ -1351,8 +1380,9 @@ class AutoMonadConverter {
|
|||
// To: return output
|
||||
void ClearIsolatedNodes() const {
|
||||
auto output = GetGraphOutput();
|
||||
constexpr size_t attach_index = 2;
|
||||
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
|
||||
IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
|
||||
IsPrimitiveCNode(output->cast<CNodePtr>()->input(attach_index), prim::kPrimStopGradient)) {
|
||||
// Replace Depend(orig_output, StopGrad) node with orig_output.
|
||||
// After that, nodes may be eliminated if have no side effects.
|
||||
auto &orig_output = output->cast<CNodePtr>()->input(1);
|
||||
|
@ -1415,6 +1445,10 @@ class AutoMonadConverter {
|
|||
|
||||
void HandleLoad(const CNodePtr &cnode, bool update_state) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Check if a sequence which has ref exists in the inputs of the cnode, and the cnode is a real node.
|
||||
if (IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
|
||||
return InsertLoadForSequenceRef(cnode, update_state);
|
||||
}
|
||||
if (IsValueNode<Primitive>(cnode->input(0))) {
|
||||
// For primitive calls that use Ref as input, insert Loads before them.
|
||||
InsertLoads(cnode, update_state);
|
||||
|
@ -1426,6 +1460,75 @@ class AutoMonadConverter {
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr NewItemNode(const AnfNodePtr &node, const AbstractBasePtr &seq_abs, const AbstractBasePtr &item_abs,
|
||||
size_t index) {
|
||||
std::vector<AnfNodePtr> item_inputs;
|
||||
if (seq_abs->isa<abstract::AbstractTuple>()) {
|
||||
(void)item_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
} else if (seq_abs->isa<abstract::AbstractList>()) {
|
||||
(void)item_inputs.emplace_back(NewValueNode(prim::kPrimListGetItem));
|
||||
}
|
||||
(void)item_inputs.emplace_back(node);
|
||||
(void)item_inputs.emplace_back(NewValueNode(SizeToLong(index)));
|
||||
auto new_item = func_graph_->NewCNode(std::move(item_inputs));
|
||||
new_item->set_abstract(item_abs);
|
||||
if (item_abs->isa<abstract::AbstractRefTensor>()) {
|
||||
// Current u monad.
|
||||
auto current_u = GetUniverse();
|
||||
// Make a Load for item node.
|
||||
new_item = MakeLoad(node, new_item, current_u);
|
||||
}
|
||||
return new_item;
|
||||
}
|
||||
|
||||
// params = (param1, param2, ..., value)
|
||||
// addn(params, xxx) non-effect-node need insert load for params.
|
||||
void InsertLoadForSequenceRef(const CNodePtr &cnode, bool update_state) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
abstract::AbstractBasePtrList new_seq_abstracts;
|
||||
for (size_t index = 1; index < inputs.size(); ++index) {
|
||||
const auto &input = inputs[index];
|
||||
const auto &input_abs = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
if (!input_abs->isa<abstract::AbstractTuple>() && !input_abs->isa<abstract::AbstractList>()) {
|
||||
(void)new_seq_abstracts.emplace_back(input_abs);
|
||||
continue;
|
||||
}
|
||||
// Handle the input which is sequence.
|
||||
std::vector<AnfNodePtr> new_sequence_inputs;
|
||||
if (input_abs->isa<abstract::AbstractTuple>()) {
|
||||
(void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
} else if (input_abs->isa<abstract::AbstractList>()) {
|
||||
(void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
|
||||
}
|
||||
auto seq_abs = input_abs->cast_ptr<abstract::AbstractSequence>();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
const auto &elements = seq_abs->elements();
|
||||
for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
|
||||
const auto &item_abs = elements[item_index];
|
||||
auto item = NewItemNode(input, input_abs, item_abs, item_index);
|
||||
(void)new_sequence_inputs.emplace_back(item);
|
||||
(void)new_seq_abstracts.emplace_back(item->abstract());
|
||||
}
|
||||
auto new_seq = func_graph_->NewCNode(std::move(new_sequence_inputs));
|
||||
MS_LOG(DEBUG) << "Replace the input of non-effect-node:" << cnode->DebugString()
|
||||
<< " with:" << new_seq->DebugString();
|
||||
if (input_abs->isa<abstract::AbstractTuple>()) {
|
||||
new_seq->set_abstract(std::make_shared<abstract::AbstractTuple>(new_seq_abstracts));
|
||||
} else if (input_abs->isa<abstract::AbstractList>()) {
|
||||
new_seq->set_abstract(std::make_shared<abstract::AbstractList>(new_seq_abstracts));
|
||||
}
|
||||
manager_->SetEdge(cnode, index, new_seq);
|
||||
if (update_state) {
|
||||
auto current_u = GetUniverse();
|
||||
// In the order_enforce phase, the cnode will be added to the updatestate to ensure the order,
|
||||
// and the input of the updatestate is maintained here to 2.
|
||||
// to ensure the verification of the updatestate in the relevant pass.
|
||||
u_ = UpdateState(current_u, new_seq);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Insert Loads for a primitive cnode that use Ref as input.
|
||||
// for example, from:
|
||||
|
@ -1478,18 +1581,18 @@ class AutoMonadConverter {
|
|||
for (size_t index : ref_input.second) {
|
||||
manager_->SetEdge(cnode, index, load);
|
||||
}
|
||||
loads.emplace_back(std::move(load));
|
||||
(void)loads.emplace_back(std::move(load));
|
||||
}
|
||||
return loads;
|
||||
}
|
||||
|
||||
CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) {
|
||||
CNodePtr MakeLoad(const AnfNodePtr &node, const AnfNodePtr &ref, const AnfNodePtr &u) {
|
||||
static const std::string primitive_target = "primitive_target";
|
||||
// Create Load cnode.
|
||||
auto load_prim = NewValueNode(prim::kPrimLoad);
|
||||
auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
|
||||
// Set device target for Load CNode.
|
||||
std::string target = GetCNodeTarget(cnode);
|
||||
std::string target = GetCNodeTarget(node);
|
||||
load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
|
||||
// Set load_cnode abstract to Tensor according the input Ref[Tensor].
|
||||
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 5cd93e4d7af7e98e05fdcc98c901c33c971523c8
|
||||
Subproject commit 763ef7f119301f465a75c7adc7b42ed60d3cfde2
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -84,6 +84,7 @@ def test_auto_monad_addn_adam():
|
|||
allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
|
||||
allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class AutoMonadTwoAssignTwoAddnDependencyNet(Cell):
|
||||
|
@ -257,3 +258,133 @@ def test_parameter_tuple_assign():
|
|||
out = net(x)
|
||||
assert out[0] == 2
|
||||
assert out[1] == 0
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_tuple_assign_addn():
|
||||
"""
|
||||
Feature: Auto monad feature.
|
||||
Description: Parameter tuple assign and addn.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assign = P.Assign()
|
||||
self.addn = P.AddN()
|
||||
self.param1 = Parameter(Tensor(1), name="param1")
|
||||
self.param2 = Parameter(Tensor(2), name="param2")
|
||||
|
||||
def construct(self, x):
|
||||
params = (self.param1, self.param2)
|
||||
res1 = self.addn(params)
|
||||
self.assign(params[0], x)
|
||||
res2 = self.addn(params)
|
||||
self.assign(params[1], x * 2)
|
||||
res3 = self.addn(params)
|
||||
res4 = params[0] + params[1]
|
||||
res = (res1, res2, res3, res4)
|
||||
return res
|
||||
|
||||
x = Tensor(3)
|
||||
net = Net()
|
||||
out = net(x)
|
||||
assert out == (3, 5, 9, 9)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_tuple_assign_addn_inner_net():
|
||||
"""
|
||||
Feature: Auto monad feature.
|
||||
Description: Parameter tuple assign and addn.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class InnerNet(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assign = P.Assign()
|
||||
self.addn = P.AddN()
|
||||
self.param1 = Parameter(Tensor(1), name="param1")
|
||||
self.param2 = Parameter(Tensor(2), name="param2")
|
||||
|
||||
def construct(self, x):
|
||||
params = (self.param1, self.param2)
|
||||
res1 = self.addn(params)
|
||||
self.assign(params[0], x)
|
||||
res2 = self.addn(params)
|
||||
res = (res1, res2, self.param1, self.param2)
|
||||
return res
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inner_net = InnerNet()
|
||||
self.addn = P.AddN()
|
||||
self.assign = P.Assign()
|
||||
|
||||
def construct(self, x, y):
|
||||
inner_net_res = self.inner_net(x)
|
||||
params = (inner_net_res[2], inner_net_res[3])
|
||||
out_res1 = self.addn(params)
|
||||
self.assign(inner_net_res[2], y)
|
||||
out_res2 = self.addn(params)
|
||||
self.assign(inner_net_res[3], 2 * y)
|
||||
return out_res1 + out_res2, inner_net_res[2] + inner_net_res[3]
|
||||
|
||||
input_x = Tensor(3)
|
||||
input_y = Tensor(5)
|
||||
net = Net()
|
||||
out = net(input_x, input_y)
|
||||
assert out == (12, 15)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_parameter_tuple_assign_addn_inner_net_control_flow():
|
||||
"""
|
||||
Feature: Auto monad feature.
|
||||
Description: Parameter tuple assign and addn.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class InnerNet(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param1 = Parameter(Tensor(1), name="param1")
|
||||
self.param2 = Parameter(Tensor(2), name="param2")
|
||||
|
||||
def construct(self, x):
|
||||
if x > 0:
|
||||
return self.param1, self.param2
|
||||
return self.param2, self.param1
|
||||
|
||||
class Net(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.inner_net = InnerNet()
|
||||
self.addn = P.AddN()
|
||||
self.assign = P.Assign()
|
||||
|
||||
def construct(self, x, y):
|
||||
inner_params = self.inner_net(x)
|
||||
out_res1 = self.addn(inner_params)
|
||||
self.assign(inner_params[1], y)
|
||||
out_res2 = self.addn(inner_params)
|
||||
self.assign(inner_params[0], 2 * y)
|
||||
return out_res1 + out_res2, inner_params[0] + inner_params[1]
|
||||
|
||||
input_x = Tensor(3)
|
||||
input_y = Tensor(5)
|
||||
net = Net()
|
||||
out = net(input_x, input_y)
|
||||
assert out == (9, 15)
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.common._monad as monad
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -354,6 +355,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
|||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
self.load = P.Load()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
|
@ -399,7 +401,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
|||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if not overflow:
|
||||
self.optimizer(grads)
|
||||
return (loss, cond, scaling_sens)
|
||||
return loss, cond, self.load(scaling_sens, monad.U)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
|
||||
|
|
|
@ -25,6 +25,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
import mindspore.common._monad as monad
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from .bert_model import BertModel
|
||||
|
@ -377,6 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
self.load = P.Load()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
|
@ -433,4 +435,4 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if not overflow:
|
||||
self.optimizer(grads)
|
||||
return (loss, cond, scaling_sens)
|
||||
return loss, cond, self.load(scaling_sens, monad.U)
|
||||
|
|
Loading…
Reference in New Issue