diff --git a/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc b/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc index a0ed710ff2e..ff79d2842f6 100644 --- a/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc @@ -20,6 +20,45 @@ #include "pybind_api/api_register.h" namespace mindspore { +py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict ¶ms_init) { + py::dict hyper_params; + for (const auto ¶m : func_graph->parameters()) { + auto param_node = param->cast(); + MS_EXCEPTION_IF_NULL(param_node); + py::str param_name = py::str(param_node->name()); + if (param_node->has_default()) { + const char kModelName[] = "mindspore"; + const char kClassName[] = "Parameter"; + const py::module &mod = py::module::import(kModelName); + const py::object &fn = mod.attr(kClassName); + const auto &old_value = param_node->default_param()->cast(); + MS_EXCEPTION_IF_NULL(old_value); + py::object new_param; + + if (params_init.contains(param_name)) { + const auto &new_value = params_init[param_name].cast(); + MS_EXCEPTION_IF_NULL(new_value); + if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) { + MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. " + "The parameter '" + << param_name.cast() << "' has shape " << old_value->shape() + << " and dtype " << TypeIdLabel(old_value->data_type()) + << ", but got the update Tensor with shape " << new_value->shape() << " and dtype " + << TypeIdLabel(new_value->data_type()) << "."; + } + new_param = fn(*new_value); + } else { + new_param = fn(*old_value); + } + auto new_default_param = new_param.cast(); + new_default_param->set_param_info(old_value->param_info()); + param_node->set_default_param(new_default_param); + hyper_params[param_name] = new_param; + } + } + return hyper_params; +} + REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { // Define python "MetaFuncGraph_" class (void)py::class_>(*m, "MetaFuncGraph_") @@ -28,8 +67,11 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { (void)py::class_(*m, "FuncGraph") .def(py::init()) .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") - .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph") - .def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"), - "Update FuncGraph hyper parameters, and return the updated parameters."); + .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); + })); +REGISTER_PYBIND_DEFINE(_c_expression, ([](pybind11::module *const m) { + (void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams, + py::arg("func_graph"), py::arg("params_init"), + "Update FuncGraph hyper parameters, and return the updated parameters."); })); } // namespace mindspore diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 9bf15c57962..732b4dcfcc8 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -556,38 +556,6 @@ size_t FuncGraph::GetDefaultValueCount() { return parameter_default_value_.size() - LongToSize(null_count); } -std::map FuncGraph::UpdateHyperParams( - const std::unordered_map ¶ms_init) { - std::map hyper_params; - for (const auto ¶ : parameters_) { - auto param_node = para->cast(); - MS_EXCEPTION_IF_NULL(param_node); - const std::string ¶m_name = param_node->name(); - - if (param_node->has_default()) { - if (params_init.find(param_name) != params_init.end()) { - const auto &old_value = param_node->default_param()->cast(); - const auto &new_value = params_init.at(param_name); - MS_EXCEPTION_IF_NULL(old_value); - MS_EXCEPTION_IF_NULL(new_value); - if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) { - MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. " - "The parameter '" - << param_name << "' has shape " << old_value->shape() << " and dtype " - << TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape " - << new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << "."; - } - - auto new_default_param = std::make_shared(*new_value); - new_default_param->set_param_info(old_value->param_info()); - param_node->set_default_param(new_default_param); - } - hyper_params[param_name] = param_node->default_param(); - } - } - return hyper_params; -} - AnfNodePtr FuncGraph::GetVariableArgParameter() { if (!has_vararg_) { return nullptr; diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 62974c36feb..7ea6e09183c 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -215,8 +215,6 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi void SetDefaultValues(const std::vector &name_list, const std::vector &value_list); void ClearDefaultValues(); size_t GetDefaultValueCount(); - std::map UpdateHyperParams( - const std::unordered_map ¶ms_init); std::map ¶meter_default_value() { return parameter_default_value_; } void set_has_vararg(bool has_) { has_vararg_ = has_; } bool has_vararg() const { return has_vararg_; } diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5f7479900d9..41892b02f24 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -26,7 +26,7 @@ from mindspore import log as logger from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.context import ParallelMode from .. import context -from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType +from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType from .._checkparam import Validator from ..common import dtype as mstype from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor @@ -1703,10 +1703,8 @@ class GraphCell(Cell): raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, " f"but got the key type: {type(name)}, and the value type: {type(value)}") - params_dict = self.graph.update_hyper_params(params_init) - for name, value in params_dict.items(): - param = Parameter(value) - param.param_info = value.param_info + params_dict = update_func_graph_hyper_params(self.graph, params_init) + for name, param in params_dict.items(): self._params[name] = param def construct(self, *inputs): diff --git a/tests/st/export_and_load/test_get_and_init_graph_cell_parameters.py b/tests/st/export_and_load/test_get_and_init_graph_cell_parameters.py new file mode 100644 index 00000000000..7bae5ee81f3 --- /dev/null +++ b/tests/st/export_and_load/test_get_and_init_graph_cell_parameters.py @@ -0,0 +1,107 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test get and init GraphCell parameters""" +import os + +import numpy as np +import pytest + +from mindspore import Tensor, Parameter +from mindspore import context +from mindspore import export, load, save_checkpoint, load_checkpoint +from mindspore import nn + + +class TestNet(nn.Cell): + def __init__(self): + super(TestNet, self).__init__() + self.flag = False + self.weight = Parameter(np_param, requires_grad=True) + self.dense = nn.Dense(3, 4) + + def construct(self, x, y): + if self.flag: + ret = self.dense(x * self.weight) + else: + ret = x * y * self.weight + self.weight += 1.0 + return ret + + +np_a = np.ones((2, 3), np.float32) + 2 +np_b = np.ones((2, 3), np.float32) + 3 +np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32) +input_a = Tensor(np_a) +input_b = Tensor(np_b) + + +def load_mindir_and_update_params(mindir_name, ckpt_name): + net = TestNet() + export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') + + load_net = nn.GraphCell(graph=load(mindir_name)) + ret = load_net(input_a, input_b) + save_checkpoint(load_net, ckpt_name) + assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param) + assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 1.0) + + params_init = load_checkpoint(ckpt_name) + load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init) + return load_net_with_new_params + + +def get_and_init_graph_cell_parameters(): + mindir_name = f"{context.get_context('mode')}_test_graph_cell_net.mindir" + ckpt_name = f"{context.get_context('mode')}_test_graph_cell_net.ckpt" + load_net = load_mindir_and_update_params(mindir_name, ckpt_name) + ret = load_net(input_a, input_b) + assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0)) + assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0) + + if os.path.isfile(mindir_name): + os.remove(mindir_name) + if os.path.isfile(ckpt_name): + os.remove(ckpt_name) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_get_and_init_graph_cell_parameters_in_graph_mode(): + """ + Description: load mind ir and update parameters in graph mode. + Expectation: generate a graph with updated parameters. + """ + context.set_context(mode=context.GRAPH_MODE) + get_and_init_graph_cell_parameters() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_get_and_init_graph_cell_parameters_in_pynative_mode(): + """ + Description: load mind ir and update parameters in pynative mode. + Expectation: generate a graph with updated parameters. + """ + context.set_context(mode=context.PYNATIVE_MODE) + get_and_init_graph_cell_parameters() diff --git a/tests/st/export_and_load/test_get_and_init_GraphCell_parameters.py b/tests/ut/python/mindir/test_init_graph_cell_parameters_with_illegal_data.py similarity index 72% rename from tests/st/export_and_load/test_get_and_init_GraphCell_parameters.py rename to tests/ut/python/mindir/test_init_graph_cell_parameters_with_illegal_data.py index 86de28352e6..1284af08698 100644 --- a/tests/st/export_and_load/test_get_and_init_GraphCell_parameters.py +++ b/tests/ut/python/mindir/test_init_graph_cell_parameters_with_illegal_data.py @@ -13,18 +13,17 @@ # limitations under the License. # ============================================================================ -"""test get and init GraphCell parameters""" +"""test init GraphCell parameters with illegal data""" + +import os import numpy as np import pytest -from mindspore import nn -from mindspore import context from mindspore import Tensor, Parameter -from mindspore import export, load, save_checkpoint, load_checkpoint - - -context.set_context(mode=context.GRAPH_MODE) +from mindspore import context +from mindspore import export, load +from mindspore import nn class Net(nn.Cell): @@ -50,44 +49,17 @@ input_a = Tensor(np_a) input_b = Tensor(np_b) -def load_mindir_and_update_params(): - net = Net() - mindir_name = "net_0.mindir" - export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') - - load_net = nn.GraphCell(graph=load(mindir_name)) - ret = load_net(input_a, input_b) - assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param) - - ckpt_name = "net_0.ckpt" - save_checkpoint(load_net, ckpt_name) - params_init = load_checkpoint(ckpt_name) - load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init) - return load_net_with_new_params +def remove_generated_file(file_name): + if os.path.isfile(file_name): + os.remove(file_name) -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_get_and_init_graph_cell_parameters(): - """ - Description: load mind ir and update parameters. - Expectation: generate a graph with updated parameters. - """ - load_net = load_mindir_and_update_params() - ret = load_net(input_a, input_b) - - assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0)) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard def test_init_graph_cell_parameters_with_wrong_type(): """ Description: load mind ir and update parameters with wrong type. Expectation: raise a ValueError indicating the params type error. """ + context.set_context(mode=context.GRAPH_MODE) net = Net() mindir_name = "net_1.mindir" export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') @@ -99,16 +71,15 @@ def test_init_graph_cell_parameters_with_wrong_type(): load_net(input_a, input_b) assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value) + remove_generated_file(mindir_name) -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard def test_init_graph_cell_parameters_with_wrong_shape(): """ Description: load mind ir and update parameters with wrong tensor shape. Expectation: raise a ValueError indicating the tensor shape error. """ + context.set_context(mode=context.PYNATIVE_MODE) net = Net() mindir_name = "net_2.mindir" export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') @@ -120,16 +91,15 @@ def test_init_graph_cell_parameters_with_wrong_shape(): load_net(input_a, input_b) assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) + remove_generated_file(mindir_name) -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard def test_init_graph_cell_parameters_with_wrong_dtype(): """ Description: load mind ir and update parameters with wrong tensor dtype. Expectation: raise a ValueError indicating the tensor dtype error. """ + context.set_context(mode=context.GRAPH_MODE) net = Net() mindir_name = "net_3.mindir" export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') @@ -141,3 +111,4 @@ def test_init_graph_cell_parameters_with_wrong_dtype(): load_net(input_a, input_b) assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) + remove_generated_file(mindir_name)