!12041 unify parameter output type

From: @zhangbuxue
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
mindspore-ci-bot, 2021-02-04 00:01:36 +08:00, committed by Gitee
Commit: e6221008a1
5 changed files with 6 additions and 5 deletions


@@ -544,7 +544,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
     }
     if (CanSpecializeNode(func)) {
-      // for primitive node , we build the primitive node with inferred attributes in the first pass
+      // for primitive node, we build the primitive node with inferred attributes in the first pass
       // so we do not build replaced node again here in second pass
       if (IsValueNode<Primitive>(func)) {
         new_inputs[0] = func;


@@ -105,7 +105,7 @@ class Parameter(Tensor_):
         >>> print(net(x))
         [[2.]]
         >>> net.weight.set_data(Tensor(np.zeros((1,2))))
-        Parameter (name=w)
+        Parameter (name=w, shape=(1, 2), dtype=Float64, requires_grad=True)
         >>> print(net(x))
         [[0.]]
     """


@@ -168,6 +168,7 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
     MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0";
   }
   auto depends = args_spec_list[0]->Broaden();
+  // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value.
   if (depends->isa<AbstractScalar>()) {
     depends->set_value(kAnyValue);
   }
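The added comment states the reasoning: broadening a scalar abstract does not change its value (unlike other abstracts, whose concrete value is dropped), so InferImplDepend has to reset the scalar's value to kAnyValue explicitly. Below is a purely illustrative Python model of that reasoning; AbstractTensor, AbstractScalar, and ANY_VALUE are hypothetical stand-ins, not MindSpore's actual abstract-value API.

# Illustrative model only: these classes are hypothetical stand-ins for the C++
# abstract classes, and ANY_VALUE plays the role of kAnyValue.
ANY_VALUE = object()

class AbstractTensor:
    def __init__(self, shape, value=ANY_VALUE):
        self.shape, self.value = shape, value

    def broaden(self):
        # Broadening a tensor abstract drops any concrete value it carried.
        return AbstractTensor(self.shape, value=ANY_VALUE)

class AbstractScalar:
    def __init__(self, value):
        self.value = value

    def broaden(self):
        # Broadening a scalar abstract keeps the value as-is ...
        return AbstractScalar(self.value)

def infer_depend(arg):
    depends = arg.broaden()
    # ... so, mirroring the hunk above, a scalar's value has to be reset by hand,
    # otherwise Depend would still carry a compile-time constant.
    if isinstance(depends, AbstractScalar):
        depends.value = ANY_VALUE
    return depends

assert infer_depend(AbstractScalar(3)).value is ANY_VALUE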


@@ -218,7 +218,7 @@ class AssignAdd(PrimitiveWithInfer):
         >>> value = Tensor(np.ones([1]).astype(np.int64)*100)
         >>> output = net(value)
         >>> print(output)
-        Parameter (name=global_step)
+        Parameter (name=global_step, shape=(1,), dtype=Int64, requires_grad=True)
     """
     __mindspore_signature__ = (
         sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
@@ -273,7 +273,7 @@ class AssignSub(PrimitiveWithInfer):
         >>> value = Tensor(np.ones([1]).astype(np.int32)*100)
         >>> output = net(value)
         >>> print(output)
-        Parameter (name=global_step)
+        Parameter (name=global_step, shape=(1,), dtype=Int32, requires_grad=True)
     """
     __mindspore_signature__ = (


@@ -54,7 +54,7 @@ class Assign(PrimitiveWithCheck):
         >>> net = Net()
         >>> output = net(x)
         >>> print(output)
-        Parameter (name=y)
+        Parameter (name=y, shape=(1,), dtype=Float32, requires_grad=True)
     """
     __mindspore_signature__ = (
         sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
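The three operator docstring hunks above (AssignAdd, AssignSub, Assign) follow the same pattern: the primitive writes into a Parameter in place and the network returns that Parameter, so with the unified repr the doctest output now shows name, shape, dtype, and requires_grad. Below is a hedged sketch of that pattern for Assign; the Net class and the input x are assumptions, since only the last lines of the example are visible in the hunk.

# Hedged sketch of the Assign docstring pattern; Net and x are assumptions
# reconstructed around the ">>>" lines shown in the hunk above.
import numpy as np
from mindspore import Tensor, Parameter, nn, ops

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.y = Parameter(Tensor(np.zeros([1], np.float32)), name="y")
        self.assign = ops.Assign()

    def construct(self, x):
        # Assign writes x into self.y; per the docstring, the result prints as the parameter.
        return self.assign(self.y, x)

x = Tensor(np.ones([1], np.float32))
output = Net()(x)
print(output)
# Parameter (name=y, shape=(1,), dtype=Float32, requires_grad=True)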