forked from mindspore-Ecosystem/mindspore
fix bug in do signature
This commit is contained in:
parent
584641180f
commit
7e4d972f6f
|
@ -65,6 +65,7 @@ test_temp_summary_event_file/
|
||||||
*.ckpt
|
*.ckpt
|
||||||
*.shp
|
*.shp
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.pb
|
||||||
.clangd
|
.clangd
|
||||||
mindspore/version.py
|
mindspore/version.py
|
||||||
mindspore/default_config.py
|
mindspore/default_config.py
|
||||||
|
|
|
@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
|
||||||
std::string TypeId2String(TypeId type_id) {
|
std::string TypeId2String(TypeId type_id) {
|
||||||
auto iter = type_id_str_map.find(type_id);
|
auto iter = type_id_str_map.find(type_id);
|
||||||
if (iter == type_id_str_map.end()) {
|
if (iter == type_id_str_map.end()) {
|
||||||
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
|
return std::string(TypeIdLabel(type_id));
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
||||||
return empty;
|
return empty;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string GetOpName(const ValuePtr &function) {
|
|
||||||
std::string name = "";
|
|
||||||
if (function->isa<Primitive>()) {
|
|
||||||
name = function->cast<PrimitivePyPtr>()->name();
|
|
||||||
} else if (function->isa<MetaFuncGraph>()) {
|
|
||||||
name = function->cast<MetaFuncGraphPtr>()->name();
|
|
||||||
}
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
||||||
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
||||||
std::size_t sig_size = signature.size();
|
std::size_t sig_size = signature.size();
|
||||||
|
@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
|
||||||
*max_type_number = type_number;
|
*max_type_number = type_number;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) {
|
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
|
||||||
|
const std::set<size_t> &write_indexs) {
|
||||||
TypeId max_type_id = kTypeUnknown;
|
TypeId max_type_id = kTypeUnknown;
|
||||||
TypeId max_type = kTypeUnknown;
|
TypeId max_type = kTypeUnknown;
|
||||||
size_t max_type_number = 0;
|
size_t max_type_number = 0;
|
||||||
|
@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
||||||
TypeId arg_type = kTypeUnknown;
|
TypeId arg_type = kTypeUnknown;
|
||||||
AbstractBasePtr arg_value = args_spec_list[index];
|
AbstractBasePtr arg_value = args_spec_list[index];
|
||||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
auto is_write = (write_indexs.find(index) != write_indexs.end());
|
||||||
|
if (is_write) {
|
||||||
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||||
|
} else {
|
||||||
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||||
|
@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
||||||
|
|
||||||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||||
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||||
const abstract::AbstractBasePtrList &args_spec_list) {
|
const abstract::AbstractBasePtrList &args_spec_list,
|
||||||
|
const std::set<size_t> &write_indexs) {
|
||||||
// record index for signature.dtypes of the same type
|
// record index for signature.dtypes of the same type
|
||||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||||
|
@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
|
||||||
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs)));
|
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
|
||||||
}
|
}
|
||||||
return dst_type;
|
return dst_type;
|
||||||
}
|
}
|
||||||
|
@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
||||||
return NewCNode({cast_node, param, dtype_node}, graph);
|
return NewCNode({cast_node, param, dtype_node}, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
|
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
||||||
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *const op_inputs,
|
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
||||||
const std::set<size_t> &write_indexs) {
|
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
|
||||||
std::vector<SignatureEnumDType> dtypes;
|
std::vector<SignatureEnumDType> dtypes;
|
||||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||||
[](const Signature &sig) { return sig.dtype; });
|
[](const Signature &sig) { return sig.dtype; });
|
||||||
|
@ -216,16 +213,23 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
||||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list);
|
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
|
||||||
// Identify which arg requires auto cast
|
// Identify which arg requires auto cast
|
||||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||||
auto it = dst_type.find(dtypes[i]);
|
auto it = dst_type.find(dtypes[i]);
|
||||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
auto rw_it = write_indexs.find(i);
|
||||||
|
auto is_write = (rw_it != write_indexs.end());
|
||||||
|
|
||||||
AbstractBasePtr arg_value = args_spec_list[i];
|
AbstractBasePtr arg_value = args_spec_list[i];
|
||||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
if (is_write) {
|
||||||
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||||
|
} else {
|
||||||
|
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
TypeId arg_type_id = kTypeUnknown;
|
TypeId arg_type_id = kTypeUnknown;
|
||||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||||
|
@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
||||||
if (it_map == type_map.end()) {
|
if (it_map == type_map.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto rw_it = write_indexs.find(i);
|
if (is_write) {
|
||||||
if (rw_it != write_indexs.end()) {
|
|
||||||
if (arg_type_id != it->second) {
|
if (arg_type_id != it->second) {
|
||||||
MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i]
|
MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i]
|
||||||
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
|
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
|
||||||
<< TypeIdLabel(it->second) << "' automatically.";
|
<< TypeIdLabel(it->second) << "' automatically.";
|
||||||
}
|
}
|
||||||
|
@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
||||||
if (sig == SignatureEnumRW::kRWRead) {
|
if (sig == SignatureEnumRW::kRWRead) {
|
||||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
||||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||||
|
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
||||||
write_indexs.insert(i);
|
write_indexs.insert(i);
|
||||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
|
|
||||||
}
|
}
|
||||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
||||||
|
@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
||||||
}
|
}
|
||||||
// process default
|
// process default
|
||||||
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
||||||
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
||||||
return func_graph->NewCNode(op_inputs);
|
return func_graph->NewCNode(op_inputs);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
// arguments: value
|
// arguments: value
|
||||||
if (args_spec_list.size() != 1) {
|
if (args_spec_list.size() != 1) {
|
||||||
MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
|
MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
|
||||||
<< ".";
|
<< ".";
|
||||||
}
|
}
|
||||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||||
|
|
|
@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
|
|
||||||
// Ref eliminate
|
// Ref eliminate
|
||||||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||||
get_make_ref_eliminate_ =
|
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
||||||
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
|
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||||
|
|
||||||
replace_refkey_by_param_ =
|
replace_refkey_by_param_ =
|
||||||
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||||
|
|
|
@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
|
||||||
|
|
||||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||||
|
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||||
class GetMakeRefEliminater : public AnfVisitor {
|
class GetMakeRefEliminater : public AnfVisitor {
|
||||||
public:
|
public:
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
|
||||||
return ref->input(2);
|
return ref->input(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
|
||||||
|
return ref->input(3);
|
||||||
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
||||||
|
|
||||||
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
||||||
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
||||||
ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey);
|
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
|
||||||
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
||||||
const std::string primitive_name("assign");
|
const std::string primitive_name("assign");
|
||||||
const std::string module_name("mindspore.ops.functional");
|
const std::string module_name("mindspore.ops.functional");
|
||||||
|
@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
||||||
vec_states.emplace_back(make_tuple_op);
|
vec_states.emplace_back(make_tuple_op);
|
||||||
for (auto &item : state_assign_) {
|
for (auto &item : state_assign_) {
|
||||||
auto source = ReadVariable(item.second);
|
auto source = ReadVariable(item.second);
|
||||||
auto refkey = func_graph()->NewCNode({get_refkey_op, item.first});
|
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
|
||||||
auto assign = func_graph()->NewCNode({assign_op, refkey, source});
|
auto assign = func_graph()->NewCNode({assign_op, origin, source});
|
||||||
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
|
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
|
||||||
vec_states.emplace_back(assign);
|
vec_states.emplace_back(assign);
|
||||||
}
|
}
|
||||||
|
|
|
@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
|
||||||
std::string AbstractRef::ToString() const {
|
std::string AbstractRef::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << type_name() << "("
|
buffer << type_name() << "("
|
||||||
<< "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString()
|
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString()
|
||||||
<< "origin_value: " << ref_origin_->ToString();
|
<< " origin_value: " << ref_origin_->ToString();
|
||||||
auto value = GetValueTrack();
|
auto value = GetValueTrack();
|
||||||
if (value) {
|
if (value) {
|
||||||
buffer << ", value: " << value->ToString();
|
buffer << ", value: " << value->ToString();
|
||||||
|
|
|
@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
|
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
|
||||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||||
if (ref_abs == nullptr) {
|
if (ref_abs == nullptr) {
|
||||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
|
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto key_abs = ref_abs->ref_key();
|
auto key_abs = ref_abs->ref_key();
|
||||||
|
|
|
@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
|
||||||
Type of MindSpore type.
|
Type of MindSpore type.
|
||||||
"""
|
"""
|
||||||
# Tensor
|
# Tensor
|
||||||
if hasattr(obj, 'dtype'):
|
if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type):
|
||||||
return tensor_type(obj.dtype())
|
return tensor_type(obj.dtype())
|
||||||
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
|
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
|
||||||
return function
|
return function
|
||||||
|
|
|
@ -31,7 +31,9 @@ from ...common.tensor import Tensor
|
||||||
from ..operations.math_ops import _infer_shape_reduce
|
from ..operations.math_ops import _infer_shape_reduce
|
||||||
from .._utils import get_concat_offset
|
from .._utils import get_concat_offset
|
||||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||||
|
from ..._c_expression import signature_rw as sig_rw
|
||||||
|
from ..._c_expression import signature_kind as sig_kind
|
||||||
|
from ..._c_expression import signature_dtype as sig_dtype
|
||||||
|
|
||||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||||
|
@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
|
||||||
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||||
>>> op = P.ScatterNdUpdate()
|
>>> op = P.ScatterUpdate()
|
||||||
>>> output = op(input_x, indices, update)
|
>>> output = op(input_x, indices, update)
|
||||||
"""
|
"""
|
||||||
|
__mindspore_signature__ = (
|
||||||
|
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||||
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||||
|
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||||
|
)
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, use_locking=True):
|
def __init__(self, use_locking=True):
|
||||||
"""Init ScatterNdUpdate"""
|
"""Init ScatterUpdate"""
|
||||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||||
|
|
||||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||||
|
@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
||||||
>>> op = P.ScatterNdUpdate()
|
>>> op = P.ScatterNdUpdate()
|
||||||
>>> output = op(input_x, indices, update)
|
>>> output = op(input_x, indices, update)
|
||||||
"""
|
"""
|
||||||
|
__mindspore_signature__ = (
|
||||||
|
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||||
|
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||||
|
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||||
|
)
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, use_locking=True):
|
def __init__(self, use_locking=True):
|
||||||
"""Init ScatterNdUpdate"""
|
"""Init ScatterNdUpdate"""
|
||||||
|
|
|
@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def infer_dtype(self, variable, value):
|
def infer_dtype(self, variable, value):
|
||||||
args = {"value": value}
|
args = {"variable": variable, "value": value}
|
||||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def infer_dtype(self, variable, value):
|
def infer_dtype(self, variable, value):
|
||||||
args = {"value": value}
|
args = {"variable": variable, "value": value}
|
||||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
def infer_dtype(self, variable, value):
|
def infer_dtype(self, variable, value):
|
||||||
|
args = {"variable": variable, "value": value}
|
||||||
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""test layer switch"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import mindspore
|
import mindspore
|
||||||
|
|
|
@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
|
||||||
return self.flatten(self.conv(input_x, self.weight))
|
return self.flatten(self.conv(input_x, self.weight))
|
||||||
|
|
||||||
|
|
||||||
class MakeRefKeyNet(nn.Cell):
|
|
||||||
""" MakeRefKeyNet definition """
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(MakeRefKeyNet, self).__init__()
|
|
||||||
self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
key = P.MakeRefKey("y")()
|
|
||||||
P.Assign()(key, x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class StateNet(nn.Cell):
|
class StateNet(nn.Cell):
|
||||||
""" StateTestTensor definition """
|
""" StateTestTensor definition """
|
||||||
|
|
||||||
|
@ -538,10 +525,6 @@ test_cases = [
|
||||||
'block': Grad(NetWithLossClass(Conv2dNativeNet())),
|
'block': Grad(NetWithLossClass(Conv2dNativeNet())),
|
||||||
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
|
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
|
||||||
}),
|
}),
|
||||||
('MakeRefKey', {
|
|
||||||
'block': MakeRefKeyNet(),
|
|
||||||
'desc_inputs': [Tensor([2.0], mindspore.float32)],
|
|
||||||
}),
|
|
||||||
('StateTest', {
|
('StateTest', {
|
||||||
'block': StateNet(),
|
'block': StateNet(),
|
||||||
'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
|
'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
test assign sub
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.common.parameter import Parameter
|
||||||
|
import mindspore as ms
|
||||||
|
|
||||||
|
class AssignW(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(AssignW, self).__init__()
|
||||||
|
self.assign = P.Assign()
|
||||||
|
|
||||||
|
def construct(self, x, w):
|
||||||
|
self.assign(x, w)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.b = Parameter(initializer('ones', [5]), name='b')
|
||||||
|
self.assign = AssignW()
|
||||||
|
|
||||||
|
def construct(self, value):
|
||||||
|
return self.assign(self.b, value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_assign_through_cell():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
net = Net()
|
||||||
|
net.to_float(ms.float16)
|
||||||
|
net.add_flags_recursive(fp16=False)
|
||||||
|
input_data = Tensor(np.ones([5]).astype(np.float32))
|
||||||
|
net(input_data)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
net(None)
|
||||||
|
|
||||||
|
|
||||||
|
class NetScatterNdUpdate(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetScatterNdUpdate, self).__init__()
|
||||||
|
self.b = Parameter(initializer('ones', [5, 5]), name='b')
|
||||||
|
self.scatter = P.ScatterNdUpdate()
|
||||||
|
|
||||||
|
def construct(self, idx, x):
|
||||||
|
return self.scatter(self.b, idx, x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatter_nd_update():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = NetScatterNdUpdate()
|
||||||
|
x = Tensor(np.ones([5]).astype(np.float16))
|
||||||
|
idx = Tensor(np.ones([1]).astype(np.int32))
|
||||||
|
net(idx, x)
|
Loading…
Reference in New Issue