forked from mindspore-Ecosystem/mindspore
fix sync GraphCell params data from device bug
This commit is contained in:
parent
258809d98f
commit
d7dc539b6e
|
@ -20,6 +20,45 @@
|
|||
#include "pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict ¶ms_init) {
|
||||
py::dict hyper_params;
|
||||
for (const auto ¶m : func_graph->parameters()) {
|
||||
auto param_node = param->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
py::str param_name = py::str(param_node->name());
|
||||
if (param_node->has_default()) {
|
||||
const char kModelName[] = "mindspore";
|
||||
const char kClassName[] = "Parameter";
|
||||
const py::module &mod = py::module::import(kModelName);
|
||||
const py::object &fn = mod.attr(kClassName);
|
||||
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(old_value);
|
||||
py::object new_param;
|
||||
|
||||
if (params_init.contains(param_name)) {
|
||||
const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_value);
|
||||
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
|
||||
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
|
||||
"The parameter '"
|
||||
<< param_name.cast<std::string>() << "' has shape " << old_value->shape()
|
||||
<< " and dtype " << TypeIdLabel(old_value->data_type())
|
||||
<< ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
|
||||
<< TypeIdLabel(new_value->data_type()) << ".";
|
||||
}
|
||||
new_param = fn(*new_value);
|
||||
} else {
|
||||
new_param = fn(*old_value);
|
||||
}
|
||||
auto new_default_param = new_param.cast<tensor::TensorPtr>();
|
||||
new_default_param->set_param_info(old_value->param_info());
|
||||
param_node->set_default_param(new_default_param);
|
||||
hyper_params[param_name] = new_param;
|
||||
}
|
||||
}
|
||||
return hyper_params;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
||||
// Define python "MetaFuncGraph_" class
|
||||
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
|
||||
|
@ -28,8 +67,11 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
|||
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
|
||||
.def(py::init())
|
||||
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
|
||||
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph")
|
||||
.def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"),
|
||||
"Update FuncGraph hyper parameters, and return the updated parameters.");
|
||||
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
|
||||
}));
|
||||
REGISTER_PYBIND_DEFINE(_c_expression, ([](pybind11::module *const m) {
|
||||
(void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams,
|
||||
py::arg("func_graph"), py::arg("params_init"),
|
||||
"Update FuncGraph hyper parameters, and return the updated parameters.");
|
||||
}));
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -556,38 +556,6 @@ size_t FuncGraph::GetDefaultValueCount() {
|
|||
return parameter_default_value_.size() - LongToSize(null_count);
|
||||
}
|
||||
|
||||
std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams(
|
||||
const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init) {
|
||||
std::map<std::string, ValuePtr> hyper_params;
|
||||
for (const auto ¶ : parameters_) {
|
||||
auto param_node = para->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
const std::string ¶m_name = param_node->name();
|
||||
|
||||
if (param_node->has_default()) {
|
||||
if (params_init.find(param_name) != params_init.end()) {
|
||||
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||
const auto &new_value = params_init.at(param_name);
|
||||
MS_EXCEPTION_IF_NULL(old_value);
|
||||
MS_EXCEPTION_IF_NULL(new_value);
|
||||
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
|
||||
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
|
||||
"The parameter '"
|
||||
<< param_name << "' has shape " << old_value->shape() << " and dtype "
|
||||
<< TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape "
|
||||
<< new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
|
||||
}
|
||||
|
||||
auto new_default_param = std::make_shared<tensor::Tensor>(*new_value);
|
||||
new_default_param->set_param_info(old_value->param_info());
|
||||
param_node->set_default_param(new_default_param);
|
||||
}
|
||||
hyper_params[param_name] = param_node->default_param();
|
||||
}
|
||||
}
|
||||
return hyper_params;
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraph::GetVariableArgParameter() {
|
||||
if (!has_vararg_) {
|
||||
return nullptr;
|
||||
|
|
|
@ -215,8 +215,6 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
|
|||
void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
|
||||
void ClearDefaultValues();
|
||||
size_t GetDefaultValueCount();
|
||||
std::map<std::string, ValuePtr> UpdateHyperParams(
|
||||
const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init);
|
||||
std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; }
|
||||
void set_has_vararg(bool has_) { has_vararg_ = has_; }
|
||||
bool has_vararg() const { return has_vararg_; }
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore import log as logger
|
|||
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
||||
from mindspore.context import ParallelMode
|
||||
from .. import context
|
||||
from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType
|
||||
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
|
||||
from .._checkparam import Validator
|
||||
from ..common import dtype as mstype
|
||||
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor
|
||||
|
@ -1703,10 +1703,8 @@ class GraphCell(Cell):
|
|||
raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
|
||||
f"but got the key type: {type(name)}, and the value type: {type(value)}")
|
||||
|
||||
params_dict = self.graph.update_hyper_params(params_init)
|
||||
for name, value in params_dict.items():
|
||||
param = Parameter(value)
|
||||
param.param_info = value.param_info
|
||||
params_dict = update_func_graph_hyper_params(self.graph, params_init)
|
||||
for name, param in params_dict.items():
|
||||
self._params[name] = param
|
||||
|
||||
def construct(self, *inputs):
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""test get and init GraphCell parameters"""
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import context
|
||||
from mindspore import export, load, save_checkpoint, load_checkpoint
|
||||
from mindspore import nn
|
||||
|
||||
|
||||
class TestNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(TestNet, self).__init__()
|
||||
self.flag = False
|
||||
self.weight = Parameter(np_param, requires_grad=True)
|
||||
self.dense = nn.Dense(3, 4)
|
||||
|
||||
def construct(self, x, y):
|
||||
if self.flag:
|
||||
ret = self.dense(x * self.weight)
|
||||
else:
|
||||
ret = x * y * self.weight
|
||||
self.weight += 1.0
|
||||
return ret
|
||||
|
||||
|
||||
np_a = np.ones((2, 3), np.float32) + 2
|
||||
np_b = np.ones((2, 3), np.float32) + 3
|
||||
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
|
||||
input_a = Tensor(np_a)
|
||||
input_b = Tensor(np_b)
|
||||
|
||||
|
||||
def load_mindir_and_update_params(mindir_name, ckpt_name):
|
||||
net = TestNet()
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
load_net = nn.GraphCell(graph=load(mindir_name))
|
||||
ret = load_net(input_a, input_b)
|
||||
save_checkpoint(load_net, ckpt_name)
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
||||
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 1.0)
|
||||
|
||||
params_init = load_checkpoint(ckpt_name)
|
||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
|
||||
return load_net_with_new_params
|
||||
|
||||
|
||||
def get_and_init_graph_cell_parameters():
|
||||
mindir_name = f"{context.get_context('mode')}_test_graph_cell_net.mindir"
|
||||
ckpt_name = f"{context.get_context('mode')}_test_graph_cell_net.ckpt"
|
||||
load_net = load_mindir_and_update_params(mindir_name, ckpt_name)
|
||||
ret = load_net(input_a, input_b)
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
|
||||
assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
|
||||
|
||||
if os.path.isfile(mindir_name):
|
||||
os.remove(mindir_name)
|
||||
if os.path.isfile(ckpt_name):
|
||||
os.remove(ckpt_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_and_init_graph_cell_parameters_in_graph_mode():
|
||||
"""
|
||||
Description: load mind ir and update parameters in graph mode.
|
||||
Expectation: generate a graph with updated parameters.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
get_and_init_graph_cell_parameters()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_and_init_graph_cell_parameters_in_pynative_mode():
|
||||
"""
|
||||
Description: load mind ir and update parameters in pynative mode.
|
||||
Expectation: generate a graph with updated parameters.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
get_and_init_graph_cell_parameters()
|
|
@ -13,18 +13,17 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""test get and init GraphCell parameters"""
|
||||
"""test init GraphCell parameters with illegal data"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import export, load, save_checkpoint, load_checkpoint
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
from mindspore import context
|
||||
from mindspore import export, load
|
||||
from mindspore import nn
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -50,44 +49,17 @@ input_a = Tensor(np_a)
|
|||
input_b = Tensor(np_b)
|
||||
|
||||
|
||||
def load_mindir_and_update_params():
|
||||
net = Net()
|
||||
mindir_name = "net_0.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
||||
load_net = nn.GraphCell(graph=load(mindir_name))
|
||||
ret = load_net(input_a, input_b)
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
||||
|
||||
ckpt_name = "net_0.ckpt"
|
||||
save_checkpoint(load_net, ckpt_name)
|
||||
params_init = load_checkpoint(ckpt_name)
|
||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
|
||||
return load_net_with_new_params
|
||||
def remove_generated_file(file_name):
|
||||
if os.path.isfile(file_name):
|
||||
os.remove(file_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_and_init_graph_cell_parameters():
|
||||
"""
|
||||
Description: load mind ir and update parameters.
|
||||
Expectation: generate a graph with updated parameters.
|
||||
"""
|
||||
load_net = load_mindir_and_update_params()
|
||||
ret = load_net(input_a, input_b)
|
||||
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_type():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong type.
|
||||
Expectation: raise a ValueError indicating the params type error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_1.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
@ -99,16 +71,15 @@ def test_init_graph_cell_parameters_with_wrong_type():
|
|||
load_net(input_a, input_b)
|
||||
|
||||
assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_shape():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor shape.
|
||||
Expectation: raise a ValueError indicating the tensor shape error.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_2.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
@ -120,16 +91,15 @@ def test_init_graph_cell_parameters_with_wrong_shape():
|
|||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_dtype():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor dtype.
|
||||
Expectation: raise a ValueError indicating the tensor dtype error.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
mindir_name = "net_3.mindir"
|
||||
export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
|
||||
|
@ -141,3 +111,4 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
|
|||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
||||
remove_generated_file(mindir_name)
|
Loading…
Reference in New Issue