forked from mindspore-Ecosystem/mindspore
fix bug in do signature
This commit is contained in:
parent
584641180f
commit
7e4d972f6f
|
@ -65,6 +65,7 @@ test_temp_summary_event_file/
|
|||
*.ckpt
|
||||
*.shp
|
||||
*.pkl
|
||||
*.pb
|
||||
.clangd
|
||||
mindspore/version.py
|
||||
mindspore/default_config.py
|
||||
|
|
|
@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
|
|||
std::string TypeId2String(TypeId type_id) {
|
||||
auto iter = type_id_str_map.find(type_id);
|
||||
if (iter == type_id_str_map.end()) {
|
||||
MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
|
||||
return std::string(TypeIdLabel(type_id));
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
|
|
@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
|
|||
return empty;
|
||||
}
|
||||
|
||||
const std::string GetOpName(const ValuePtr &function) {
|
||||
std::string name = "";
|
||||
if (function->isa<Primitive>()) {
|
||||
name = function->cast<PrimitivePyPtr>()->name();
|
||||
} else if (function->isa<MetaFuncGraph>()) {
|
||||
name = function->cast<MetaFuncGraphPtr>()->name();
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list,
|
||||
const std::vector<Signature> &signature, bool has_var, std::vector<AnfNodePtr> *const op_inputs) {
|
||||
std::size_t sig_size = signature.size();
|
||||
|
@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
|
|||
*max_type_number = type_number;
|
||||
}
|
||||
|
||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs) {
|
||||
TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector<size_t> indexs,
|
||||
const std::set<size_t> &write_indexs) {
|
||||
TypeId max_type_id = kTypeUnknown;
|
||||
TypeId max_type = kTypeUnknown;
|
||||
size_t max_type_number = 0;
|
||||
|
@ -103,8 +94,13 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|||
TypeId arg_type = kTypeUnknown;
|
||||
AbstractBasePtr arg_value = args_spec_list[index];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
auto is_write = (write_indexs.find(index) != write_indexs.end());
|
||||
if (is_write) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||
} else {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
}
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
auto tensor_type = tensor->element()->BuildType();
|
||||
|
@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
|
|||
|
||||
// Get the largest type of index in the same SignatureEnumDType of arguments.
|
||||
std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnumDType> &dtypes,
|
||||
const abstract::AbstractBasePtrList &args_spec_list) {
|
||||
const abstract::AbstractBasePtrList &args_spec_list,
|
||||
const std::set<size_t> &write_indexs) {
|
||||
// record index for signature.dtypes of the same type
|
||||
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||
|
@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
|
|||
(void)dst_type.insert(std::make_pair(type, kTypeUnknown));
|
||||
continue;
|
||||
}
|
||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs)));
|
||||
(void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indexs, write_indexs)));
|
||||
}
|
||||
return dst_type;
|
||||
}
|
||||
|
@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
|
|||
return NewCNode({cast_node, param, dtype_node}, graph);
|
||||
}
|
||||
|
||||
void DoAutoCast(const std::vector<Signature> &signature, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
const FuncGraphPtr &graph, std::vector<AnfNodePtr> *const op_inputs,
|
||||
const std::set<size_t> &write_indexs) {
|
||||
void DoAutoCast(const std::string &func_name, const std::vector<Signature> &signature,
|
||||
const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph,
|
||||
std::vector<AnfNodePtr> *const op_inputs, const std::set<size_t> &write_indexs) {
|
||||
std::vector<SignatureEnumDType> dtypes;
|
||||
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
||||
[](const Signature &sig) { return sig.dtype; });
|
||||
|
@ -216,17 +213,24 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
|||
return;
|
||||
}
|
||||
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
|
||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list);
|
||||
std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs);
|
||||
// Identify which arg requires auto cast
|
||||
for (size_t i = 0; i < args_spec_list.size(); ++i) {
|
||||
auto it = dst_type.find(dtypes[i]);
|
||||
if (it == dst_type.end() || it->second == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
auto rw_it = write_indexs.find(i);
|
||||
auto is_write = (rw_it != write_indexs.end());
|
||||
|
||||
AbstractBasePtr arg_value = args_spec_list[i];
|
||||
if (arg_value->isa<abstract::AbstractRef>()) {
|
||||
if (is_write) {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin();
|
||||
} else {
|
||||
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref();
|
||||
}
|
||||
}
|
||||
TypeId arg_type_id = kTypeUnknown;
|
||||
if (arg_value->isa<abstract::AbstractTensor>()) {
|
||||
auto tensor = arg_value->cast<abstract::AbstractTensorPtr>();
|
||||
|
@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
|
|||
if (it_map == type_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto rw_it = write_indexs.find(i);
|
||||
if (rw_it != write_indexs.end()) {
|
||||
if (is_write) {
|
||||
if (arg_type_id != it->second) {
|
||||
MS_LOG(EXCEPTION) << "In op '" << GetOpName(graph) << "', argument '" << args_spec_list[i]
|
||||
MS_LOG(EXCEPTION) << "In op '" << func_name << "', argument '" << args_spec_list[i]
|
||||
<< "' can not cast type from '" << TypeIdLabel(arg_type_id) << "' to '"
|
||||
<< TypeIdLabel(it->second) << "' automatically.";
|
||||
}
|
||||
|
@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
if (sig == SignatureEnumRW::kRWRead) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
|
||||
} else if (sig == SignatureEnumRW::kRWWrite) {
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
|
||||
write_indexs.insert(i);
|
||||
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
|
||||
}
|
||||
// If sig is SignatureEnumRW::kRWRef, not do anything.
|
||||
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
|
||||
|
@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
|
|||
}
|
||||
// process default
|
||||
ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
|
||||
DoAutoCast(signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
||||
DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs);
|
||||
return func_graph->NewCNode(op_inputs);
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
|
|||
const AbstractBasePtrList &args_spec_list) {
|
||||
// arguments: value
|
||||
if (args_spec_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size()
|
||||
MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
|
||||
<< ".";
|
||||
}
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
|
|
|
@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
// Ref eliminate
|
||||
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
|
||||
get_make_ref_eliminate_ =
|
||||
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
|
||||
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate",
|
||||
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
|
||||
|
||||
replace_refkey_by_param_ =
|
||||
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
|
||||
|
|
|
@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
|
|||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||
class GetMakeRefEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
|
@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
|
|||
return ref->input(2);
|
||||
}
|
||||
|
||||
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
|
||||
return ref->input(3);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
|||
|
||||
ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple);
|
||||
ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend);
|
||||
ValueNodePtr get_refkey_op = NewValueNode(prim::kPrimGetRefKey);
|
||||
ValueNodePtr get_ref_origin_op = NewValueNode(prim::kPrimGetRefOrigin);
|
||||
ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient);
|
||||
const std::string primitive_name("assign");
|
||||
const std::string module_name("mindspore.ops.functional");
|
||||
|
@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
|
|||
vec_states.emplace_back(make_tuple_op);
|
||||
for (auto &item : state_assign_) {
|
||||
auto source = ReadVariable(item.second);
|
||||
auto refkey = func_graph()->NewCNode({get_refkey_op, item.first});
|
||||
auto assign = func_graph()->NewCNode({assign_op, refkey, source});
|
||||
auto origin = func_graph()->NewCNode({get_ref_origin_op, item.first});
|
||||
auto assign = func_graph()->NewCNode({assign_op, origin, source});
|
||||
MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second;
|
||||
vec_states.emplace_back(assign);
|
||||
}
|
||||
|
|
|
@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
|
||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||
if (ref_abs == nullptr) {
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
|
||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
auto key_abs = ref_abs->ref_key();
|
||||
|
|
|
@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
|
|||
Type of MindSpore type.
|
||||
"""
|
||||
# Tensor
|
||||
if hasattr(obj, 'dtype'):
|
||||
if hasattr(obj, 'dtype') and callable(obj.dtype) and isinstance(obj.dtype(), typing.Type):
|
||||
return tensor_type(obj.dtype())
|
||||
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
|
||||
return function
|
||||
|
|
|
@ -31,7 +31,9 @@ from ...common.tensor import Tensor
|
|||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from .._utils import get_concat_offset
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||
|
@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
|
|||
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
|
||||
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> op = P.ScatterNdUpdate()
|
||||
>>> op = P.ScatterUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
"""Init ScatterNdUpdate"""
|
||||
"""Init ScatterUpdate"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
|
@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
|
|||
>>> op = P.ScatterNdUpdate()
|
||||
>>> output = op(input_x, indices, update)
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
@prim_attr_register
|
||||
def __init__(self, use_locking=True):
|
||||
"""Init ScatterNdUpdate"""
|
||||
|
|
|
@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
|
|||
return value
|
||||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
args = {"variable": variable, "value": value}
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return value
|
||||
|
||||
|
@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
|
|||
return value
|
||||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"value": value}
|
||||
args = {"variable": variable, "value": value}
|
||||
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
|
||||
return value
|
||||
|
||||
|
|
|
@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
|
|||
return variable
|
||||
|
||||
def infer_dtype(self, variable, value):
|
||||
args = {"variable": variable, "value": value}
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return variable
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test layer switch"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
|
|
|
@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
|
|||
return self.flatten(self.conv(input_x, self.weight))
|
||||
|
||||
|
||||
class MakeRefKeyNet(nn.Cell):
|
||||
""" MakeRefKeyNet definition """
|
||||
|
||||
def __init__(self):
|
||||
super(MakeRefKeyNet, self).__init__()
|
||||
self.y = Parameter(Tensor([1.0], mindspore.float32), name="y")
|
||||
|
||||
def construct(self, x):
|
||||
key = P.MakeRefKey("y")()
|
||||
P.Assign()(key, x)
|
||||
return x
|
||||
|
||||
|
||||
class StateNet(nn.Cell):
|
||||
""" StateTestTensor definition """
|
||||
|
||||
|
@ -538,10 +525,6 @@ test_cases = [
|
|||
'block': Grad(NetWithLossClass(Conv2dNativeNet())),
|
||||
'desc_inputs': [Tensor(np.ones([1, 3, 16, 16], np.float32)), Tensor(np.zeros([1, 1764], np.float32))],
|
||||
}),
|
||||
('MakeRefKey', {
|
||||
'block': MakeRefKeyNet(),
|
||||
'desc_inputs': [Tensor([2.0], mindspore.float32)],
|
||||
}),
|
||||
('StateTest', {
|
||||
'block': StateNet(),
|
||||
'desc_inputs': [Tensor(np.ones([2, 1, 2, 2]).astype(np.float32))],
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
test assign sub
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
import mindspore as ms
|
||||
|
||||
class AssignW(nn.Cell):
|
||||
def __init__(self):
|
||||
super(AssignW, self).__init__()
|
||||
self.assign = P.Assign()
|
||||
|
||||
def construct(self, x, w):
|
||||
self.assign(x, w)
|
||||
return x
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.b = Parameter(initializer('ones', [5]), name='b')
|
||||
self.assign = AssignW()
|
||||
|
||||
def construct(self, value):
|
||||
return self.assign(self.b, value)
|
||||
|
||||
|
||||
def test_assign_through_cell():
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
net = Net()
|
||||
net.to_float(ms.float16)
|
||||
net.add_flags_recursive(fp16=False)
|
||||
input_data = Tensor(np.ones([5]).astype(np.float32))
|
||||
net(input_data)
|
||||
with pytest.raises(TypeError):
|
||||
net(None)
|
||||
|
||||
|
||||
class NetScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetScatterNdUpdate, self).__init__()
|
||||
self.b = Parameter(initializer('ones', [5, 5]), name='b')
|
||||
self.scatter = P.ScatterNdUpdate()
|
||||
|
||||
def construct(self, idx, x):
|
||||
return self.scatter(self.b, idx, x)
|
||||
|
||||
|
||||
def test_scatter_nd_update():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = NetScatterNdUpdate()
|
||||
x = Tensor(np.ones([5]).astype(np.float16))
|
||||
idx = Tensor(np.ones([1]).astype(np.int32))
|
||||
net(idx, x)
|
Loading…
Reference in New Issue