fix bug in do signature

This commit is contained in:
Wei Luning 2020-06-01 16:59:23 +08:00
parent 584641180f
commit 7e4d972f6f
16 changed files with 153 additions and 59 deletions

1
.gitignore vendored
View File

@ -65,6 +65,7 @@ test_temp_summary_event_file/
*.ckpt *.ckpt
*.shp *.shp
*.pkl *.pkl
*.pb
.clangd .clangd
mindspore/version.py mindspore/version.py
mindspore/default_config.py mindspore/default_config.py

View File

@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
std::string TypeId2String(TypeId type_id) { std::string TypeId2String(TypeId type_id) {
auto iter = type_id_str_map.find(type_id); auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) { if (iter == type_id_str_map.end()) {
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id); return std::string(TypeIdLabel(type_id));
} }
return iter->second; return iter->second;
} }

View File

@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
return empty; return empty;
} }
const std::string GetOpName(const ValuePtr &function) {
std::string name = "";
if (function->isa<Primitive>()) {
name = function->cast<PrimitivePyPtr>()->name();
} else if (function->isa<MetaFuncGraph>()) {
name = function->cast<MetaFuncGraphPtr>()->name();
}
return name;
}
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) { const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
std::size_t sig_size = signature.size(); std::size_t sig_size = signature.size();
@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
*max_type_number = type_number; *max_type_number = type_number;
} }
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) { TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
const std::set<size_t> &write_indexs) {
TypeId max_type_id = kTypeUnknown; TypeId max_type_id = kTypeUnknown;
TypeId max_type = kTypeUnknown; TypeId max_type = kTypeUnknown;
size_t max_type_number = 0; size_t max_type_number = 0;
@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId arg_type = kTypeUnknown; TypeId arg_type = kTypeUnknown;
AbstractBasePtr arg_value = args_spec_list[index]; AbstractBasePtr arg_value = args_spec_list[index];
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); auto is_write = (write_indexs.find(index) != write_indexs.end());
if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
} }
if (arg_value->isa<abstract::AbstractTensor>()) { if (arg_value->isa<abstract::AbstractTensor>()) {
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>(); auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments. // Get the largest type of index in the same SignatureEnumDType of arguments.
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes, std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
const abstract::AbstractBasePtrList &args_spec_list) { const abstract::AbstractBasePtrList &args_spec_list,
const std::set<size_t> &write_indexs) {
// record index for signature.dtypes of the same type // record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(void)dst_type.insert(std::make_pair(type, kTypeUnknown)); (void)dst_type.insert(std::make_pair(type, kTypeUnknown));
continue; continue;
} }
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs))); (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
} }
return dst_type; return dst_type;
} }
@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr &param, const TypeId &type_id, const FuncGrap
return NewCNode({cast_node, param, dtype_node}, graph); return NewCNode({cast_node, param, dtype_node}, graph);
} }
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list, void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *const op_inputs, const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
const std::set<size_t> &write_indexs) { std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
std::vector<SignatureEnumDType> dtypes; std::vector<SignatureEnumDType> dtypes;
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
[](const Signature &sig) { return sig.dtype; }); [](const Signature &sig) { return sig.dtype; });
@ -216,16 +213,23 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
return; return;
} }
// Stat the index of the arguments with the largest type in the same SignatureEnumDType. // Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list); std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
// Identify which arg requires auto cast // Identify which arg requires auto cast
for (size_t i = 0; i < args_spec_list.size(); ++i) { for (size_t i = 0; i < args_spec_list.size(); ++i) {
auto it = dst_type.find(dtypes[i]); auto it = dst_type.find(dtypes[i]);
if (it == dst_type.end() || it->second == kTypeUnknown) { if (it == dst_type.end() || it->second == kTypeUnknown) {
continue; continue;
} }
auto rw_it = write_indexs.find(i);
auto is_write = (rw_it != write_indexs.end());
AbstractBasePtr arg_value = args_spec_list[i]; AbstractBasePtr arg_value = args_spec_list[i];
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); if (is_write) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
} else {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
}
} }
TypeId arg_type_id = kTypeUnknown; TypeId arg_type_id = kTypeUnknown;
if (arg_value->isa<abstract::AbstractTensor>()) { if (arg_value->isa<abstract::AbstractTensor>()) {
@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
if (it_map == type_map.end()) { if (it_map == type_map.end()) {
continue; continue;
} }
auto rw_it = write_indexs.find(i); if (is_write) {
if (rw_it != write_indexs.end()) {
if (arg_type_id != it->second) { if (arg_type_id != it->second) {
MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i] MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i]
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '" << "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
<< TypeIdLabel(it->second) << "' automatically."; << TypeIdLabel(it->second) << "' automatically.";
} }
@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
if (sig == SignatureEnumRW::kRWRead) { if (sig == SignatureEnumRW::kRWRead) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
} else if (sig == SignatureEnumRW::kRWWrite) { } else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
write_indexs.insert(i); write_indexs.insert(i);
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
} }
// If sig is SignatureEnumRW::kRWRef, not do anything. // If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
} }
// process default // process default
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs); DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
return func_graph->NewCNode(op_inputs); return func_graph->NewCNode(op_inputs);
} }
} // namespace } // namespace

View File

@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// arguments: value // arguments: value
if (args_spec_list.size() != 1) { if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size() MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
<< "."; << ".";
} }
TypePtr type = args_spec_list[0]->GetTypeTrack(); TypePtr type = args_spec_list[0]->GetTypeTrack();

View File

@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate // Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_make_ref_eliminate_ = get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ = replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);

View File

@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public AnfVisitor { class GetMakeRefEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
return ref->input(2); return ref->input(2);
} }
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
return ref->input(3);
}
return nullptr; return nullptr;
} }
}; };

View File

@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey); ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
const std::string primitive_name("assign"); const std::string primitive_name("assign");
const std::string module_name("mindspore.ops.functional"); const std::string module_name("mindspore.ops.functional");
@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states.emplace_back(make_tuple_op); vec_states.emplace_back(make_tuple_op);
for (auto &item : state_assign_) { for (auto &item : state_assign_) {
auto source = ReadVariable(item.second); auto source = ReadVariable(item.second);
auto refkey = func_graph()->NewCNode({get_refkey_op, item.first}); auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
auto assign = func_graph()->NewCNode({assign_op, refkey, source}); auto assign = func_graph()->NewCNode({assign_op, origin, source});
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
vec_states.emplace_back(assign); vec_states.emplace_back(assign);
} }

View File

@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
std::string AbstractRef::ToString() const { std::string AbstractRef::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
buffer << type_name() << "(" buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString() << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
<< "origin_value: " << ref_origin_->ToString(); << " origin_value: " << ref_origin_->ToString();
auto value = GetValueTrack(); auto value = GetValueTrack();
if (value) { if (value) {
buffer << ", value: " << value->ToString(); buffer << ", value: " << value->ToString();

View File

@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) { if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
return nullptr; return nullptr;
} }
auto key_abs = ref_abs->ref_key(); auto key_abs = ref_abs->ref_key();

View File

@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
Type of MindSpore type. Type of MindSpore type.
""" """
# Tensor # Tensor
if hasattr(obj, 'dtype'): if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type):
return tensor_type(obj.dtype()) return tensor_type(obj.dtype())
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function return function

View File

@ -31,7 +31,9 @@ from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from .._utils import get_concat_offset from .._utils import get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
def _check_infer_attr_reduce(axis, keep_dims, prim_name): def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)) >>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32) >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate() >>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update) >>> output = op(input_x, indices, update)
""" """
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterNdUpdate""" """Init ScatterUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate() >>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update) >>> output = op(input_x, indices, update)
""" """
__mindspore_signature__ = (
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register @prim_attr_register
def __init__(self, use_locking=True): def __init__(self, use_locking=True):
"""Init ScatterNdUpdate""" """Init ScatterNdUpdate"""

View File

@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
return value return value
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value
@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
return value return value
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"variable": variable, "value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value

View File

@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
return variable return variable
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return variable return variable

View File

@ -1,3 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test layer switch"""
import numpy as np import numpy as np
import mindspore import mindspore

View File

@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
return self.flatten(self.conv(input_x, self.weight)) return self.flatten(self.conv(input_x, self.weight))
class MakeRefKeyNet(nn.Cell):
""" MakeRefKeyNet definition """
def __init__(self):
super(MakeRefKeyNet, self).__init__()
self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
def construct(self, x):
key = P.MakeRefKey("y")()
P.Assign()(key, x)
return x
class StateNet(nn.Cell): class StateNet(nn.Cell):
""" StateTestTensor definition """ """ StateTestTensor definition """
@ -538,10 +525,6 @@ test_cases = [
'block': Grad(NetWithLossClass(Conv2dNativeNet())), 'block': Grad(NetWithLossClass(Conv2dNativeNet())),
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))], 'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
}), }),
('MakeRefKey', {
'block': MakeRefKeyNet(),
'desc_inputs': [Tensor([2.0], mindspore.float32)],
}),
('StateTest', { ('StateTest', {
'block': StateNet(), 'block': StateNet(),
'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))], 'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],

View File

@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test assign sub
"""
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore as ms
class AssignW(nn.Cell):
def __init__(self):
super(AssignW, self).__init__()
self.assign = P.Assign()
def construct(self, x, w):
self.assign(x, w)
return x
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')
self.assign = AssignW()
def construct(self, value):
return self.assign(self.b, value)
def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = Net()
net.to_float(ms.float16)
net.add_flags_recursive(fp16=False)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)
with pytest.raises(TypeError):
net(None)
class NetScatterNdUpdate(nn.Cell):
def __init__(self):
super(NetScatterNdUpdate, self).__init__()
self.b = Parameter(initializer('ones', [5, 5]), name='b')
self.scatter = P.ScatterNdUpdate()
def construct(self, idx, x):
return self.scatter(self.b, idx, x)
def test_scatter_nd_update():
context.set_context(mode=context.GRAPH_MODE)
net = NetScatterNdUpdate()
x = Tensor(np.ones([5]).astype(np.float16))
idx = Tensor(np.ones([1]).astype(np.int32))
net(idx, x)