forked from mindspore-Ecosystem/mindspore
!5264 [bug]fix bugs when parameters updata
Merge pull request !5264 from vlne-v1/I1SP3I-return-value-not-the-exact-value
This commit is contained in:
commit
fabebb2678
|
@ -751,19 +751,17 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) {
|
|||
return ExecDFGraph(info_, args, phase_s);
|
||||
}
|
||||
#else
|
||||
if (backend == "ms" || backend == "ge") {
|
||||
auto ret_val = std::make_shared<py::object>();
|
||||
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
|
||||
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
|
||||
return *ret_val;
|
||||
}
|
||||
auto ret_val = std::make_shared<py::object>();
|
||||
if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) {
|
||||
if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) {
|
||||
return *ret_val;
|
||||
}
|
||||
if (backend == "ge") {
|
||||
if (args.size() > 0) {
|
||||
return args[0];
|
||||
}
|
||||
return args;
|
||||
}
|
||||
if (backend == "ge") {
|
||||
if (args.size() > 0) {
|
||||
return args[0];
|
||||
}
|
||||
return args;
|
||||
}
|
||||
#endif
|
||||
std::size_t full_arg_size = ArgListSize(phase_s);
|
||||
|
|
|
@ -389,6 +389,8 @@ class Parameter(MetaTensor):
|
|||
raise RuntimeError("Must set or change parallel mode before any Initializer created.")
|
||||
if self.init_mode is None:
|
||||
return self
|
||||
if self.inited_param is not None:
|
||||
return self.inited_param
|
||||
if layout is not None:
|
||||
if not isinstance(layout, list):
|
||||
raise TypeError("The layout should be list! layout is {}.".format(layout))
|
||||
|
|
|
@ -36,8 +36,8 @@ abstract::AbstractBasePtr MetaTensor::ToAbstract() {
|
|||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
||||
|
||||
// if is parameter always no value.
|
||||
if (is_parameter()) {
|
||||
auto param_name = param_info()->name();
|
||||
if (is_parameter_) {
|
||||
auto param_name = param_info_->name();
|
||||
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||
|
|
|
@ -476,8 +476,8 @@ abstract::AbstractBasePtr Tensor::ToAbstract() {
|
|||
auto tensor_shape = tens->shape();
|
||||
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
|
||||
// if is parameter always no value.
|
||||
if (is_parameter()) {
|
||||
auto param_name = param_info()->name();
|
||||
if (is_parameter_) {
|
||||
auto param_name = param_info_->name();
|
||||
auto ref_key = std::make_shared<RefKey>(param_name);
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
abs_tensor = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_tensor);
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import context, Tensor, Parameter, ParameterTuple
|
||||
from mindspore import context, Tensor, Parameter, ParameterTuple, nn
|
||||
from mindspore._checkparam import _check_str_by_regular
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
|
@ -229,3 +229,25 @@ def test_parameter_lazy_init():
|
|||
para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
|
||||
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_parameter_as_output():
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
initial_input = initializer('One', shape=(2,), dtype=mstype.int32)
|
||||
updated_input = Tensor([2, 2], mstype.int32)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, initial, updated):
|
||||
super().__init__()
|
||||
self.initial = initial
|
||||
self.updated = updated
|
||||
self.p = Parameter(self.initial, name="weight")
|
||||
self.new_p = self.p.init_data()
|
||||
self.new_p.set_parameter_data(self.updated)
|
||||
def construct(self):
|
||||
return self.new_p
|
||||
|
||||
net = Net(initial_input, updated_input)
|
||||
output = net()
|
||||
assert np.array_equal(output.asnumpy(), np.array([2, 2], np.int32))
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
Loading…
Reference in New Issue