Fix bug in syncing GraphCell params data from device
commit d7dc539b6e
parent 258809d98f
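For orientation, a minimal end-to-end sketch of the flow this fix targets, distilled from the new test file added at the end of this diff. The SmallNet cell and the file names are illustrative only; the APIs used (export, load, save_checkpoint, load_checkpoint, nn.GraphCell with params_init) are the ones exercised by the change.

# Illustrative sketch only; SmallNet and the file names are made up for this example.
import numpy as np
from mindspore import Tensor, Parameter, context, export, load, save_checkpoint, load_checkpoint, nn

class SmallNet(nn.Cell):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.weight = Parameter(np.ones((2, 3), np.float32), requires_grad=True)

    def construct(self, x):
        out = x * self.weight
        self.weight += 1.0          # the exported graph mutates its own Parameter on device
        return out

context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.ones((2, 3), np.float32))
export(SmallNet(), x, file_name="small_net", file_format="MINDIR")

load_net = nn.GraphCell(graph=load("small_net.mindir"))
load_net(x)
# The new test asserts that the device-side update is now visible from Python:
# trainable_params() returns the incremented weight, and save_checkpoint() stores it.
save_checkpoint(load_net, "small_net.ckpt")
reloaded = nn.GraphCell(graph=load("small_net.mindir"), params_init=load_checkpoint("small_net.ckpt"))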
@@ -20,6 +20,45 @@
 #include "pybind_api/api_register.h"
 
 namespace mindspore {
+py::dict UpdateFuncGraphHyperParams(const FuncGraphPtr &func_graph, const py::dict &params_init) {
+  py::dict hyper_params;
+  for (const auto &param : func_graph->parameters()) {
+    auto param_node = param->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_node);
+    py::str param_name = py::str(param_node->name());
+    if (param_node->has_default()) {
+      const char kModelName[] = "mindspore";
+      const char kClassName[] = "Parameter";
+      const py::module &mod = py::module::import(kModelName);
+      const py::object &fn = mod.attr(kClassName);
+      const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(old_value);
+      py::object new_param;
+
+      if (params_init.contains(param_name)) {
+        const auto &new_value = params_init[param_name].cast<tensor::TensorPtr>();
+        MS_EXCEPTION_IF_NULL(new_value);
+        if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
+          MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
+                                      "The parameter '"
+                                   << param_name.cast<std::string>() << "' has shape " << old_value->shape()
+                                   << " and dtype " << TypeIdLabel(old_value->data_type())
+                                   << ", but got the update Tensor with shape " << new_value->shape() << " and dtype "
+                                   << TypeIdLabel(new_value->data_type()) << ".";
+        }
+        new_param = fn(*new_value);
+      } else {
+        new_param = fn(*old_value);
+      }
+      auto new_default_param = new_param.cast<tensor::TensorPtr>();
+      new_default_param->set_param_info(old_value->param_info());
+      param_node->set_default_param(new_default_param);
+      hyper_params[param_name] = new_param;
+    }
+  }
+  return hyper_params;
+}
+
 REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
                          // Define python "MetaFuncGraph_" class
                          (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
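The function added above wraps each parameter's default value in a Python mindspore.Parameter (substituting the matching entry from params_init after a shape/dtype check), installs it back as the parameter's default, and returns a name-to-Parameter dict. Below is a rough sketch of driving the binding registered in the next hunk directly from Python; update_func_graph_hyper_params lives in the internal _c_expression module, and the "weight" key and .mindir file reuse the illustrative names from the sketch above.

# Sketch of the internal API only; normal code goes through nn.GraphCell(params_init=...) instead.
import numpy as np
from mindspore import Tensor, Parameter, load
from mindspore._c_expression import update_func_graph_hyper_params

graph = load("small_net.mindir")                              # load() returns a FuncGraph
init = {"weight": Tensor(np.zeros((2, 3), np.float32))}       # key must name a parameter of the graph
params_dict = update_func_graph_hyper_params(graph, init)

for name, param in params_dict.items():
    assert isinstance(param, Parameter)                       # defaults come back wrapped as Parameter
# A Tensor whose shape or dtype differs from the graph's parameter raises ValueError:
# "Only support update parameter by Tensor with same shape and dtype as it. ..."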
@@ -28,8 +67,11 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
                          (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
                            .def(py::init())
                            .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
-                           .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph")
-                           .def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"),
-                                "Update FuncGraph hyper parameters, and return the updated parameters.");
+                           .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
+                       }));
+
+REGISTER_PYBIND_DEFINE(_c_expression, ([](pybind11::module *const m) {
+                         (void)m->def("update_func_graph_hyper_params", &UpdateFuncGraphHyperParams,
+                                      py::arg("func_graph"), py::arg("params_init"),
+                                      "Update FuncGraph hyper parameters, and return the updated parameters.");
                        }));
 }  // namespace mindspore

@@ -556,38 +556,6 @@ size_t FuncGraph::GetDefaultValueCount() {
   return parameter_default_value_.size() - LongToSize(null_count);
 }
 
-std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams(
-  const std::unordered_map<std::string, tensor::TensorPtr> &params_init) {
-  std::map<std::string, ValuePtr> hyper_params;
-  for (const auto &para : parameters_) {
-    auto param_node = para->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param_node);
-    const std::string &param_name = param_node->name();
-
-    if (param_node->has_default()) {
-      if (params_init.find(param_name) != params_init.end()) {
-        const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
-        const auto &new_value = params_init.at(param_name);
-        MS_EXCEPTION_IF_NULL(old_value);
-        MS_EXCEPTION_IF_NULL(new_value);
-        if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
-          MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
-                                      "The parameter '"
-                                   << param_name << "' has shape " << old_value->shape() << " and dtype "
-                                   << TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape "
-                                   << new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
-        }
-
-        auto new_default_param = std::make_shared<tensor::Tensor>(*new_value);
-        new_default_param->set_param_info(old_value->param_info());
-        param_node->set_default_param(new_default_param);
-      }
-      hyper_params[param_name] = param_node->default_param();
-    }
-  }
-  return hyper_params;
-}
-
 AnfNodePtr FuncGraph::GetVariableArgParameter() {
   if (!has_vararg_) {
     return nullptr;

@@ -215,8 +215,6 @@ class FuncGraph : public deprecated::api::FuncGraph, public FuncGraphBase, publi
   void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
   void ClearDefaultValues();
   size_t GetDefaultValueCount();
-  std::map<std::string, ValuePtr> UpdateHyperParams(
-    const std::unordered_map<std::string, tensor::TensorPtr> &params_init);
   std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
   void set_has_vararg(bool has_) { has_vararg_ = has_; }
   bool has_vararg() const { return has_vararg_; }

@@ -26,7 +26,7 @@ from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.context import ParallelMode
 from .. import context
-from .._c_expression import init_pipeline, Cell_, FuncGraph, MixedPrecisionType
+from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from .._checkparam import Validator
 from ..common import dtype as mstype
 from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor

@@ -1703,10 +1703,8 @@ class GraphCell(Cell):
                 raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
                                 f"but got the key type: {type(name)}, and the value type: {type(value)}")
 
-        params_dict = self.graph.update_hyper_params(params_init)
-        for name, value in params_dict.items():
-            param = Parameter(value)
-            param.param_info = value.param_info
+        params_dict = update_func_graph_hyper_params(self.graph, params_init)
+        for name, param in params_dict.items():
             self._params[name] = param
 
     def construct(self, *inputs):

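On the Python side, GraphCell previously called the per-graph update_hyper_params binding, got plain tensor values back, and re-wrapped each one in Parameter while copying param_info by hand. The new module-level call already returns Parameter objects with param_info attached, so they are stored in self._params unchanged. A small usage sketch, reusing the illustrative small_net.mindir and "weight" names from the sketches above:

# Sketch only; file and parameter names are illustrative.
import numpy as np
from mindspore import Tensor, load, nn

load_net = nn.GraphCell(graph=load("small_net.mindir"),
                        params_init={"weight": Tensor(np.zeros((2, 3), np.float32))})
# Keys must be str and values Tensor or Parameter, otherwise the TypeError above is raised;
# matching entries replace the graph's defaults, anything else keeps its exported value.
print(load_net.trainable_params())      # the stored values are Parameter objects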
@@ -0,0 +1,107 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""test get and init GraphCell parameters"""
+
+import os
+
+import numpy as np
+import pytest
+
+from mindspore import Tensor, Parameter
+from mindspore import context
+from mindspore import export, load, save_checkpoint, load_checkpoint
+from mindspore import nn
+
+
+class TestNet(nn.Cell):
+    def __init__(self):
+        super(TestNet, self).__init__()
+        self.flag = False
+        self.weight = Parameter(np_param, requires_grad=True)
+        self.dense = nn.Dense(3, 4)
+
+    def construct(self, x, y):
+        if self.flag:
+            ret = self.dense(x * self.weight)
+        else:
+            ret = x * y * self.weight
+        self.weight += 1.0
+        return ret
+
+
+np_a = np.ones((2, 3), np.float32) + 2
+np_b = np.ones((2, 3), np.float32) + 3
+np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
+input_a = Tensor(np_a)
+input_b = Tensor(np_b)
+
+
+def load_mindir_and_update_params(mindir_name, ckpt_name):
+    net = TestNet()
+    export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
+
+    load_net = nn.GraphCell(graph=load(mindir_name))
+    ret = load_net(input_a, input_b)
+    save_checkpoint(load_net, ckpt_name)
+    assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
+    assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 1.0)
+
+    params_init = load_checkpoint(ckpt_name)
+    load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
+    return load_net_with_new_params
+
+
+def get_and_init_graph_cell_parameters():
+    mindir_name = f"{context.get_context('mode')}_test_graph_cell_net.mindir"
+    ckpt_name = f"{context.get_context('mode')}_test_graph_cell_net.ckpt"
+    load_net = load_mindir_and_update_params(mindir_name, ckpt_name)
+    ret = load_net(input_a, input_b)
+    assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
+    assert np.array_equal(load_net.trainable_params()[0].asnumpy(), np_param + 2.0)
+
+    if os.path.isfile(mindir_name):
+        os.remove(mindir_name)
+    if os.path.isfile(ckpt_name):
+        os.remove(ckpt_name)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_get_and_init_graph_cell_parameters_in_graph_mode():
+    """
+    Description: load mind ir and update parameters in graph mode.
+    Expectation: generate a graph with updated parameters.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    get_and_init_graph_cell_parameters()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_get_and_init_graph_cell_parameters_in_pynative_mode():
+    """
+    Description: load mind ir and update parameters in pynative mode.
+    Expectation: generate a graph with updated parameters.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE)
+    get_and_init_graph_cell_parameters()

@@ -13,18 +13,17 @@
 # limitations under the License.
 # ============================================================================
 
-"""test get and init GraphCell parameters"""
+"""test init GraphCell parameters with illegal data"""
 
+import os
+
 import numpy as np
 import pytest
 
-from mindspore import nn
-from mindspore import context
 from mindspore import Tensor, Parameter
-from mindspore import export, load, save_checkpoint, load_checkpoint
-
-context.set_context(mode=context.GRAPH_MODE)
+from mindspore import context
+from mindspore import export, load
+from mindspore import nn
 
 
 class Net(nn.Cell):
@@ -50,44 +49,17 @@ input_a = Tensor(np_a)
 input_b = Tensor(np_b)
 
 
-def load_mindir_and_update_params():
-    net = Net()
-    mindir_name = "net_0.mindir"
-    export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')
-
-    load_net = nn.GraphCell(graph=load(mindir_name))
-    ret = load_net(input_a, input_b)
-    assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
-
-    ckpt_name = "net_0.ckpt"
-    save_checkpoint(load_net, ckpt_name)
-    params_init = load_checkpoint(ckpt_name)
-    load_net_with_new_params = nn.GraphCell(graph=load(mindir_name), params_init=params_init)
-    return load_net_with_new_params
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_get_and_init_graph_cell_parameters():
-    """
-    Description: load mind ir and update parameters.
-    Expectation: generate a graph with updated parameters.
-    """
-    load_net = load_mindir_and_update_params()
-    ret = load_net(input_a, input_b)
-
-    assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
+def remove_generated_file(file_name):
+    if os.path.isfile(file_name):
+        os.remove(file_name)
+
+
 def test_init_graph_cell_parameters_with_wrong_type():
     """
     Description: load mind ir and update parameters with wrong type.
     Expectation: raise a ValueError indicating the params type error.
     """
+    context.set_context(mode=context.GRAPH_MODE)
     net = Net()
     mindir_name = "net_1.mindir"
     export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

@@ -99,16 +71,15 @@ def test_init_graph_cell_parameters_with_wrong_type():
         load_net(input_a, input_b)
 
     assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
+    remove_generated_file(mindir_name)
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
 def test_init_graph_cell_parameters_with_wrong_shape():
     """
    Description: load mind ir and update parameters with wrong tensor shape.
     Expectation: raise a ValueError indicating the tensor shape error.
     """
+    context.set_context(mode=context.PYNATIVE_MODE)
     net = Net()
     mindir_name = "net_2.mindir"
     export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

@@ -120,16 +91,15 @@ def test_init_graph_cell_parameters_with_wrong_shape():
         load_net(input_a, input_b)
 
     assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
+    remove_generated_file(mindir_name)
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
 def test_init_graph_cell_parameters_with_wrong_dtype():
     """
     Description: load mind ir and update parameters with wrong tensor dtype.
     Expectation: raise a ValueError indicating the tensor dtype error.
     """
+    context.set_context(mode=context.GRAPH_MODE)
     net = Net()
     mindir_name = "net_3.mindir"
     export(net, input_a, input_b, file_name=mindir_name[:-7], file_format='MINDIR')

@@ -141,3 +111,4 @@ def test_init_graph_cell_parameters_with_wrong_dtype():
         load_net(input_a, input_b)
 
     assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
+    remove_generated_file(mindir_name)