fix sync GraphCell params data from device bug

This commit is contained in:
buxue 2021-12-02 16:41:04 +08:00
parent 258809d98f
commit d7dc539b6e
6 changed files with 170 additions and 86 deletions

View File

@ -20,6 +20,45 @@
#include "pybind_api/api_register.h" #include "pybind_api/api_register.h"
namespace mindspore { namespace mindspore {
py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict &params_init) {
py::dict hyper_params;
for (const auto &param : func_graph->parameters()) {
auto param_node = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
py::str param_name = py::str(param_node->name());
if (param_node->has_default()) {
const char kModelName[] = "mindspore";
const char kClassName[] = "Parameter";
const py::module &mod = py::module::import(kModelName);
const py::object &fn = mod.attr(kClassName);
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(old_value);
py::object new_param;
if (params_init.contains(param_name)) {
const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(new_value);
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
"The parameter '"
<< param_name.cast<std::string>() << "' has shape " << old_value->shape()
<< " and dtype " << TypeIdLabel(old_value->data_type())
<< ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
<< TypeIdLabel(new_value->data_type()) << ".";
}
new_param = fn(*new_value);
} else {
new_param = fn(*old_value);
}
auto new_default_param = new_param.cast<tensor::TensorPtr>();
new_default_param->set_param_info(old_value->param_info());
param_node->set_default_param(new_default_param);
hyper_params[param_name] = new_param;
}
}
return hyper_params;
}
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
// Define python "MetaFuncGraph_" class // Define python "MetaFuncGraph_" class
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
@ -28,8 +67,11 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
.def(py::init()) .def(py::init())
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph") .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
.def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"), }));
"Update FuncGraph hyper parameters, and return the updated parameters."); REGISTER_PYBIND_DEFINE(_c_expression, ([](pybind11::module *const m) {
(void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams,
py::arg("func_graph"), py::arg("params_init"),
"Update FuncGraph hyper parameters, and return the updated parameters.");
})); }));
} // namespace mindspore } // namespace mindspore

View File

@ -556,38 +556,6 @@ size_t FuncGraph::GetDefaultValueCount() {
return parameter_default_value_.size() - LongToSize(null_count); return parameter_default_value_.size() - LongToSize(null_count);
} }
std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams(
const std::unordered_map<std::string, tensor::TensorPtr> &params_init) {
std::map<std::string, ValuePtr> hyper_params;
for (const auto &para : parameters_) {
auto param_node = para->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
const std::string &param_name = param_node->name();
if (param_node->has_default()) {
if (params_init.find(param_name) != params_init.end()) {
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
const auto &new_value = params_init.at(param_name);
MS_EXCEPTION_IF_NULL(old_value);
MS_EXCEPTION_IF_NULL(new_value);
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
"The parameter '"
<< param_name << "' has shape " << old_value->shape() << " and dtype "
<< TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape "
<< new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
}
auto new_default_param = std::make_shared<tensor::Tensor>(*new_value);
new_default_param->set_param_info(old_value->param_info());
param_node->set_default_param(new_default_param);
}
hyper_params[param_name] = param_node->default_param();
}
}
return hyper_params;
}
AnfNodePtr FuncGraph::GetVariableArgParameter() { AnfNodePtr FuncGraph::GetVariableArgParameter() {
if (!has_vararg_) { if (!has_vararg_) {
return nullptr; return nullptr;

View File

@ -215,8 +215,6 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list); void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
void ClearDefaultValues(); void ClearDefaultValues();
size_t GetDefaultValueCount(); size_t GetDefaultValueCount();
std::map<std::string, ValuePtr> UpdateHyperParams(
const std::unordered_map<std::string, tensor::TensorPtr> &params_init);
std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; } std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
void set_has_vararg(bool has_) { has_vararg_ = has_; } void set_has_vararg(bool has_) { has_vararg_ = has_; }
bool has_vararg() const { return has_vararg_; } bool has_vararg() const { return has_vararg_; }

View File

@ -26,7 +26,7 @@ from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from .. import context from .. import context
from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from .._checkparam import Validator from .._checkparam import Validator
from ..common import dtype as mstype from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor
@ -1703,10 +1703,8 @@ class GraphCell(Cell):
raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, " raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
f"but got the key type: {type(name)}, and the value type: {type(value)}") f"but got the key type: {type(name)}, and the value type: {type(value)}")
params_dict = self.graph.update_hyper_params(params_init) params_dict = update_func_graph_hyper_params(self.graph, params_init)
for name, value in params_dict.items(): for name, param in params_dict.items():
param = Parameter(value)
param.param_info = value.param_info
self._params[name] = param self._params[name] = param
def construct(self, *inputs): def construct(self, *inputs):

View File

@ -0,0 +1,107 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test get and init GraphCell parameters"""
import os
import numpy as np
import pytest
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore import export, load, save_checkpoint, load_checkpoint
from mindspore import nn
class TestNet(nn.Cell):
def __init__(self):
super(TestNet, self).__init__()
self.flag = False
self.weight = Parameter(np_param, requires_grad=True)
self.dense = nn.Dense(3, 4)
def construct(self, x, y):
if self.flag:
ret = self.dense(x * self.weight)
else:
ret = x * y * self.weight
self.weight += 1.0
return ret
np_a = np.ones((2, 3), np.float32) + 2
np_b = np.ones((2, 3), np.float32) + 3
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
input_a = Tensor(np_a)
input_b = Tensor(np_b)
def load_mindir_and_update_params(mindir_name, ckpt_name):
net = TestNet()
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
load_net = nn.GraphCell(graph=load(mindir_name))
ret = load_net(input_a, input_b)
save_checkpoint(load_net, ckpt_name)
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 1.0)
params_init = load_checkpoint(ckpt_name)
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
return load_net_with_new_params
def get_and_init_graph_cell_parameters():
mindir_name = f"{context.get_context('mode')}_test_graph_cell_net.mindir"
ckpt_name = f"{context.get_context('mode')}_test_graph_cell_net.ckpt"
load_net = load_mindir_and_update_params(mindir_name, ckpt_name)
ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
if os.path.isfile(mindir_name):
os.remove(mindir_name)
if os.path.isfile(ckpt_name):
os.remove(ckpt_name)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters_in_graph_mode():
"""
Description: load mind ir and update parameters in graph mode.
Expectation: generate a graph with updated parameters.
"""
context.set_context(mode=context.GRAPH_MODE)
get_and_init_graph_cell_parameters()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters_in_pynative_mode():
"""
Description: load mind ir and update parameters in pynative mode.
Expectation: generate a graph with updated parameters.
"""
context.set_context(mode=context.PYNATIVE_MODE)
get_and_init_graph_cell_parameters()

View File

@ -13,18 +13,17 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test get and init GraphCell parameters""" """test init GraphCell parameters with illegal data"""
import os
import numpy as np import numpy as np
import pytest import pytest
from mindspore import nn
from mindspore import context
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
from mindspore import export, load, save_checkpoint, load_checkpoint from mindspore import context
from mindspore import export, load
from mindspore import nn
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell): class Net(nn.Cell):
@ -50,44 +49,17 @@ input_a = Tensor(np_a)
input_b = Tensor(np_b) input_b = Tensor(np_b)
def load_mindir_and_update_params(): def remove_generated_file(file_name):
net = Net() if os.path.isfile(file_name):
mindir_name = "net_0.mindir" os.remove(file_name)
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
load_net = nn.GraphCell(graph=load(mindir_name))
ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
ckpt_name = "net_0.ckpt"
save_checkpoint(load_net, ckpt_name)
params_init = load_checkpoint(ckpt_name)
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
return load_net_with_new_params
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters():
"""
Description: load mind ir and update parameters.
Expectation: generate a graph with updated parameters.
"""
load_net = load_mindir_and_update_params()
ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_type(): def test_init_graph_cell_parameters_with_wrong_type():
""" """
Description: load mind ir and update parameters with wrong type. Description: load mind ir and update parameters with wrong type.
Expectation: raise a ValueError indicating the params type error. Expectation: raise a ValueError indicating the params type error.
""" """
context.set_context(mode=context.GRAPH_MODE)
net = Net() net = Net()
mindir_name = "net_1.mindir" mindir_name = "net_1.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
@ -99,16 +71,15 @@ def test_init_graph_cell_parameters_with_wrong_type():
load_net(input_a, input_b) load_net(input_a, input_b)
assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value) assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
remove_generated_file(mindir_name)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_shape(): def test_init_graph_cell_parameters_with_wrong_shape():
""" """
Description: load mind ir and update parameters with wrong tensor shape. Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the tensor shape error. Expectation: raise a ValueError indicating the tensor shape error.
""" """
context.set_context(mode=context.PYNATIVE_MODE)
net = Net() net = Net()
mindir_name = "net_2.mindir" mindir_name = "net_2.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
@ -120,16 +91,15 @@ def test_init_graph_cell_parameters_with_wrong_shape():
load_net(input_a, input_b) load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_dtype(): def test_init_graph_cell_parameters_with_wrong_dtype():
""" """
Description: load mind ir and update parameters with wrong tensor dtype. Description: load mind ir and update parameters with wrong tensor dtype.
Expectation: raise a ValueError indicating the tensor dtype error. Expectation: raise a ValueError indicating the tensor dtype error.
""" """
context.set_context(mode=context.GRAPH_MODE)
net = Net() net = Net()
mindir_name = "net_3.mindir" mindir_name = "net_3.mindir"
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR') export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
@ -141,3 +111,4 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
load_net(input_a, input_b) load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value) assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
remove_generated_file(mindir_name)