diff --git a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc index f522984065a..516a880e153 100644 --- a/mindspore/ccsrc/frontend/parallel/parameter_manager.cc +++ b/mindspore/ccsrc/frontend/parallel/parameter_manager.cc @@ -432,7 +432,7 @@ void InitOptimizerState(const FuncGraphPtr &root) { auto graph_executor = pipeline::GraphExecutorPy::GetInstance(); MS_EXCEPTION_IF_NULL(graph_executor); auto phase = graph_executor->phase(); - auto py_obj = GetPyParameterObj(param_info, CLONED_OBJ); + auto py_obj = GetPyParameterObj(param_info, OBJ); if (py::isinstance(py_obj)) { MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj."; continue; diff --git a/mindspore/python/mindspore/common/parameter.py b/mindspore/python/mindspore/common/parameter.py index 92ecd4be646..eb8ea268e00 100644 --- a/mindspore/python/mindspore/common/parameter.py +++ b/mindspore/python/mindspore/common/parameter.py @@ -491,7 +491,7 @@ class Parameter(Tensor_): else: info.cloned_obj = [x] self.param_info = info - param_info_clone.cloned_obj = x + param_info_clone.obj = x x.param_info = param_info_clone x.is_init = False x.init = self.init diff --git a/tests/ut/python/parallel/test_parameter_clone.py b/tests/ut/python/parallel/test_parameter_clone.py new file mode 100644 index 00000000000..01315145172 --- /dev/null +++ b/tests/ut/python/parallel/test_parameter_clone.py @@ -0,0 +1,41 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore import Tensor +from mindspore import Parameter +from mindspore.common.initializer import Normal +import mindspore as ms + + +def test_parameter_clone(): + """ + Feature: test parameter clone api + Description: assert data and repr + Expectation: success + """ + tensor = Tensor(input_data=None, shape=(16, 32), dtype=ms.float32, init=Normal()) + param = Parameter(tensor, requires_grad=False) + param2 = param.clone() + + data1 = param.asnumpy() + data2 = param2.asnumpy() + repr1 = repr(param2) + assert (data1 == data2).all() + assert "requires_grad=False" in repr1 + assert "shape=(16, 32)" in repr1 + param3 = param2.clone() + data3 = param3.asnumpy() + repr2 = repr(param3) + assert (data1 == data3).all() + assert repr1 == repr2