support upodate parameters after load mindir

This commit is contained in:
buxue 2021-11-11 16:32:17 +08:00
parent 2b2194b0f4
commit 45e0a4b9bf
10 changed files with 229 additions and 23 deletions

View File

@ -124,9 +124,8 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(py::object)
m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -28,6 +28,8 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
.def(py::init())
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph")
.def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"),
"Update FuncGraph hyper parameters, and return the updated parameters.");
}));
} // namespace mindspore

View File

@ -556,6 +556,38 @@ size_t FuncGraph::GetDefaultValueCount() {
return parameter_default_value_.size() - LongToSize(null_count);
}
std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams(
const std::unordered_map<std::string, tensor::TensorPtr> &params_init) {
std::map<std::string, ValuePtr> hyper_params;
for (const auto &para : parameters_) {
auto param_node = para->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
const std::string &param_name = param_node->name();
if (param_node->has_default()) {
if (params_init.find(param_name) != params_init.end()) {
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
const auto &new_value = params_init.at(param_name);
MS_EXCEPTION_IF_NULL(old_value);
MS_EXCEPTION_IF_NULL(new_value);
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
"The parameter '"
<< param_name << "' has shape " << old_value->shape() << " and dtype "
<< TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape "
<< new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
}
auto new_default_param = std::make_shared<tensor::Tensor>(*new_value);
new_default_param->set_param_info(old_value->param_info());
param_node->set_default_param(new_default_param);
}
hyper_params[param_name] = param_node->default_param();
}
}
return hyper_params;
}
AnfNodePtr FuncGraph::GetVariableArgParameter() {
if (!has_vararg_) {
return nullptr;

View File

@ -214,6 +214,8 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
void ClearDefaultValues();
size_t GetDefaultValueCount();
std::map<std::string, ValuePtr> UpdateHyperParams(
const std::unordered_map<std::string, tensor::TensorPtr> &params_init);
std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
void set_has_vararg(bool has_) { has_vararg_ = has_; }
bool has_vararg() const { return has_vararg_; }

View File

@ -123,8 +123,8 @@ bool MindIRLoader::ParseGraphProto(mind_ir::GraphProto *graph, const std::string
return true;
}
std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<std::string> file_names) {
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
std::vector<FuncGraphPtr> MindIRLoader::LoadMindIRs(std::vector<std::string> file_names) {
std::vector<FuncGraphPtr> funcgraph_vec;
MS_LOG(DEBUG) << "Load multiple MindIR files.";
for (const auto &file_name : file_names) {
MS_LOG(DEBUG) << "Load " << file_name;
@ -133,7 +133,7 @@ std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<st
return funcgraph_vec;
}
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
/* mindir -> func_graph
* only support lite */
mind_ir::ModelProto model;
@ -150,7 +150,7 @@ std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const si
return func_graph;
}
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
return nullptr;
@ -305,7 +305,7 @@ std::string LoadPreprocess(const std::string &file_name) {
return origin_model.preprocessor();
}
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
MS_EXCEPTION_IF_NULL(buf);
std::string str(buf, buf_size);
mind_ir::ModelProto model_;

View File

@ -34,9 +34,10 @@ class MindIRLoader {
bool get_need_renormalize() const { return need_renormalize_; }
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
std::shared_ptr<FuncGraph> LoadMindIR(const void *buffer, const size_t &size);
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
FuncGraphPtr LoadMindIR(const std::string &file_name);
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
private:
bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path);
@ -51,6 +52,6 @@ class MindIRLoader {
std::string LoadPreprocess(const std::string &file_name);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore
#endif // MINDSPORE_CORE_LOAD_MODEL_H

View File

@ -54,7 +54,9 @@ class Cell(Cell_):
currently. The bprop method must contain the self parameter.
Args:
auto_prefix (bool): Recursively generate namespaces. Default: True.
auto_prefix (bool): Recursively generate namespaces. It will affect the name of the parameter in the network.
If set to True, the network parameter name will be prefixed, otherwise it will not.
Default: True.
flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
Users can also customize network attributes by this parameter. Default: None.
@ -65,13 +67,23 @@ class Cell(Cell_):
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> class MyCell(nn.Cell):
... def __init__(self):
... super(MyCell, self).__init__()
... self.relu = ops.ReLU()
... def __init__(self, forward_net):
... super(MyCell, self).__init__(auto_prefix=False)
... self.net = forward_net
... self.relu = ops.ReLU()
...
... def construct(self, x):
... return self.relu(x)
... def construct(self, x):
... y = self.net(x)
... return self.relu(y)
>>>
>>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> my_net = MyCell(inner_net)
>>> print(my_net.trainable_params())
... # If the 'auto_prefix' set to True or not set when call the '__init__' method of the parent class,
... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
"""
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
@ -1640,7 +1652,17 @@ class GraphCell(Cell):
diagram, and can only use data that shape and type are the same as the input when exporting the MindIR.
Args:
graph (object): A compiled graph loaded from MindIR.
graph (FuncGraph): A compiled graph loaded from MindIR.
params_init (dict): Parameters need to be inited in the graph.
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
If the parameter exists in the graph according to the name, update it's value.
If the parameter does not exist, ignore it. Default: None.
Raises:
TypeError: If the `graph` is not a FuncGraph.
TypeError: If the `params_init` is not a dict.
TypeError: If the key of the `params_init` is not a str.
TypeError: If the value of the `params_init` is neither a Tensor nor a Parameter.
ValueError: If the initial value's dtype and shape are not consistent with the parameter would be inited.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -1661,12 +1683,26 @@ class GraphCell(Cell):
[6. 9. 6.]
[4. 6. 4.]]]]
"""
def __init__(self, graph):
def __init__(self, graph, params_init=None):
super(GraphCell, self).__init__(auto_prefix=True)
if not isinstance(graph, FuncGraph):
raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
raise TypeError(f"The 'graph' must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
self.graph = graph
params_init = {} if params_init is None else params_init
if not isinstance(params_init, dict):
raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
for name, value in params_init.items():
if not isinstance(name, str) or not isinstance(value, Tensor):
raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
f"but got the key type: {type(name)}, and the value type: {type(value)}")
params_dict = self.graph.update_hyper_params(params_init)
for name, value in params_dict.items():
param = Parameter(value)
param.param_info = value.param_info
self._params[name] = param
def construct(self, *inputs):
return self.graph(*inputs)

View File

@ -64,7 +64,6 @@ def _tensors_cast_datatype(datatype, param):
return F.cast(param, datatype)
class WithLossCell(Cell):
r"""
Cell with loss function.

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model and parameters serialization."""
import copy
import json
import math

View File

@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test get and init GraphCell parameters"""
import numpy as np
import pytest
from mindspore import nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore import export, load, save_checkpoint, load_checkpoint
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.flag = False
self.weight = Parameter(np_param, requires_grad=True)
self.dense = nn.Dense(3, 4)
def construct(self, x, y):
if self.flag:
ret = self.dense(x * self.weight)
else:
ret = x * y * self.weight
self.weight += 1.0
return ret
np_a = np.ones((2, 3), np.float32) + 2
np_b = np.ones((2, 3), np.float32) + 3
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
input_a = Tensor(np_a)
input_b = Tensor(np_b)
mindir_path = "test_net.mindir"
ckpt_path = "test_net.ckpt"
def setup_module():
net = Net()
export(net, input_a, input_b, file_name=mindir_path[:-7], file_format='MINDIR')
def load_mindir_and_update_params():
load_net = nn.GraphCell(graph=load(mindir_path))
ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
save_checkpoint(load_net, ckpt_path)
params_init = load_checkpoint(ckpt_path)
load_net_with_new_params = nn.GraphCell(graph=load(mindir_path), params_init=params_init)
return load_net_with_new_params
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_and_init_graph_cell_parameters():
"""
Description: load mind ir and update parameters.
Expectation: generate a graph with updated parameters.
"""
load_net = load_mindir_and_update_params()
ret = load_net(input_a, input_b)
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_type():
"""
Description: load mind ir and update parameters with wrong type.
Expectation: raise a ValueError indicating the params type error.
"""
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
with pytest.raises(TypeError) as err:
graph = load(mindir_path)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_shape():
"""
Description: load mind ir and update parameters with wrong tensor shape.
Expectation: raise a ValueError indicating the tensor shape error.
"""
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
with pytest.raises(ValueError) as err:
graph = load(mindir_path)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_init_graph_cell_parameters_with_wrong_dtype():
"""
Description: load mind ir and update parameters with wrong tensor dtype.
Expectation: raise a ValueError indicating the tensor dtype error.
"""
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
with pytest.raises(ValueError) as err:
graph = load(mindir_path)
load_net = nn.GraphCell(graph, params_init=new_params)
load_net(input_a, input_b)
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)