forked from mindspore-Ecosystem/mindspore
support upodate parameters after load mindir
This commit is contained in:
parent
2b2194b0f4
commit
45e0a4b9bf
|
@ -124,9 +124,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
(void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
|
||||
|
||||
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
|
||||
(py::object)
|
||||
m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
|
||||
(void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
|
||||
py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
|
||||
|
||||
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
|
||||
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
|
||||
|
|
|
@ -28,6 +28,8 @@ REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
|
|||
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
|
||||
.def(py::init())
|
||||
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
|
||||
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
|
||||
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph")
|
||||
.def("update_hyper_params", &FuncGraph::UpdateHyperParams, py::arg("params_init"),
|
||||
"Update FuncGraph hyper parameters, and return the updated parameters.");
|
||||
}));
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -556,6 +556,38 @@ size_t FuncGraph::GetDefaultValueCount() {
|
|||
return parameter_default_value_.size() - LongToSize(null_count);
|
||||
}
|
||||
|
||||
std::map<std::string, ValuePtr> FuncGraph::UpdateHyperParams(
|
||||
const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init) {
|
||||
std::map<std::string, ValuePtr> hyper_params;
|
||||
for (const auto ¶ : parameters_) {
|
||||
auto param_node = para->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
const std::string ¶m_name = param_node->name();
|
||||
|
||||
if (param_node->has_default()) {
|
||||
if (params_init.find(param_name) != params_init.end()) {
|
||||
const auto &old_value = param_node->default_param()->cast<tensor::TensorPtr>();
|
||||
const auto &new_value = params_init.at(param_name);
|
||||
MS_EXCEPTION_IF_NULL(old_value);
|
||||
MS_EXCEPTION_IF_NULL(new_value);
|
||||
if (new_value->shape() != old_value->shape() || new_value->data_type() != old_value->data_type()) {
|
||||
MS_EXCEPTION(ValueError) << "Only support update parameter by Tensor with same shape and dtype as it. "
|
||||
"The parameter '"
|
||||
<< param_name << "' has shape " << old_value->shape() << " and dtype "
|
||||
<< TypeIdLabel(old_value->data_type()) << ", but got the update Tensor with shape "
|
||||
<< new_value->shape() << " and dtype " << TypeIdLabel(new_value->data_type()) << ".";
|
||||
}
|
||||
|
||||
auto new_default_param = std::make_shared<tensor::Tensor>(*new_value);
|
||||
new_default_param->set_param_info(old_value->param_info());
|
||||
param_node->set_default_param(new_default_param);
|
||||
}
|
||||
hyper_params[param_name] = param_node->default_param();
|
||||
}
|
||||
}
|
||||
return hyper_params;
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraph::GetVariableArgParameter() {
|
||||
if (!has_vararg_) {
|
||||
return nullptr;
|
||||
|
|
|
@ -214,6 +214,8 @@ class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfo
|
|||
void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
|
||||
void ClearDefaultValues();
|
||||
size_t GetDefaultValueCount();
|
||||
std::map<std::string, ValuePtr> UpdateHyperParams(
|
||||
const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_init);
|
||||
std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; }
|
||||
void set_has_vararg(bool has_) { has_vararg_ = has_; }
|
||||
bool has_vararg() const { return has_vararg_; }
|
||||
|
|
|
@ -123,8 +123,8 @@ bool MindIRLoader::ParseGraphProto(mind_ir::GraphProto *graph, const std::string
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<std::string> file_names) {
|
||||
std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
|
||||
std::vector<FuncGraphPtr> MindIRLoader::LoadMindIRs(std::vector<std::string> file_names) {
|
||||
std::vector<FuncGraphPtr> funcgraph_vec;
|
||||
MS_LOG(DEBUG) << "Load multiple MindIR files.";
|
||||
for (const auto &file_name : file_names) {
|
||||
MS_LOG(DEBUG) << "Load " << file_name;
|
||||
|
@ -133,7 +133,7 @@ std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<st
|
|||
return funcgraph_vec;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
|
||||
FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
|
||||
/* mindir -> func_graph
|
||||
* only support lite */
|
||||
mind_ir::ModelProto model;
|
||||
|
@ -150,7 +150,7 @@ std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const si
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
|
||||
FuncGraphPtr MindIRLoader::LoadMindIR(const std::string &file_name) {
|
||||
if (file_name.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||
return nullptr;
|
||||
|
@ -305,7 +305,7 @@ std::string LoadPreprocess(const std::string &file_name) {
|
|||
return origin_model.preprocessor();
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
|
||||
FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
std::string str(buf, buf_size);
|
||||
mind_ir::ModelProto model_;
|
||||
|
|
|
@ -34,9 +34,10 @@ class MindIRLoader {
|
|||
|
||||
bool get_need_renormalize() const { return need_renormalize_; }
|
||||
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const void *buffer, const size_t &size);
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
|
||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);
|
||||
|
||||
FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
|
||||
FuncGraphPtr LoadMindIR(const std::string &file_name);
|
||||
std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
|
||||
|
||||
private:
|
||||
bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path);
|
||||
|
@ -51,6 +52,6 @@ class MindIRLoader {
|
|||
|
||||
std::string LoadPreprocess(const std::string &file_name);
|
||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||
FuncGraphPtr ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_LOAD_MODEL_H
|
||||
|
|
|
@ -54,7 +54,9 @@ class Cell(Cell_):
|
|||
currently. The bprop method must contain the self parameter.
|
||||
|
||||
Args:
|
||||
auto_prefix (bool): Recursively generate namespaces. Default: True.
|
||||
auto_prefix (bool): Recursively generate namespaces. It will affect the name of the parameter in the network.
|
||||
If set to True, the network parameter name will be prefixed, otherwise it will not.
|
||||
Default: True.
|
||||
flags (dict): Network configuration information, currently it is used for the binding of network and dataset.
|
||||
Users can also customize network attributes by this parameter. Default: None.
|
||||
|
||||
|
@ -65,13 +67,23 @@ class Cell(Cell_):
|
|||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops as ops
|
||||
>>> class MyCell(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(MyCell, self).__init__()
|
||||
... self.relu = ops.ReLU()
|
||||
... def __init__(self, forward_net):
|
||||
... super(MyCell, self).__init__(auto_prefix=False)
|
||||
... self.net = forward_net
|
||||
... self.relu = ops.ReLU()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... return self.relu(x)
|
||||
... def construct(self, x):
|
||||
... y = self.net(x)
|
||||
... return self.relu(y)
|
||||
>>>
|
||||
>>> inner_net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
||||
>>> my_net = MyCell(inner_net)
|
||||
>>> print(my_net.trainable_params())
|
||||
... # If the 'auto_prefix' set to True or not set when call the '__init__' method of the parent class,
|
||||
... # the parameter's name will be 'net.weight'.
|
||||
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
|
||||
"""
|
||||
|
||||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
|
||||
|
@ -1640,7 +1652,17 @@ class GraphCell(Cell):
|
|||
diagram, and can only use data that shape and type are the same as the input when exporting the MindIR.
|
||||
|
||||
Args:
|
||||
graph (object): A compiled graph loaded from MindIR.
|
||||
graph (FuncGraph): A compiled graph loaded from MindIR.
|
||||
params_init (dict): Parameters need to be inited in the graph.
|
||||
The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
|
||||
If the parameter exists in the graph according to the name, update it's value.
|
||||
If the parameter does not exist, ignore it. Default: None.
|
||||
Raises:
|
||||
TypeError: If the `graph` is not a FuncGraph.
|
||||
TypeError: If the `params_init` is not a dict.
|
||||
TypeError: If the key of the `params_init` is not a str.
|
||||
TypeError: If the value of the `params_init` is neither a Tensor nor a Parameter.
|
||||
ValueError: If the initial value's dtype and shape are not consistent with the parameter would be inited.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
@ -1661,12 +1683,26 @@ class GraphCell(Cell):
|
|||
[6. 9. 6.]
|
||||
[4. 6. 4.]]]]
|
||||
"""
|
||||
def __init__(self, graph):
|
||||
def __init__(self, graph, params_init=None):
|
||||
super(GraphCell, self).__init__(auto_prefix=True)
|
||||
if not isinstance(graph, FuncGraph):
|
||||
raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
|
||||
raise TypeError(f"The 'graph' must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
|
||||
self.graph = graph
|
||||
|
||||
params_init = {} if params_init is None else params_init
|
||||
if not isinstance(params_init, dict):
|
||||
raise TypeError(f"The 'params_init' must be a dict, but got {type(params_init)}.")
|
||||
for name, value in params_init.items():
|
||||
if not isinstance(name, str) or not isinstance(value, Tensor):
|
||||
raise TypeError("The key of the 'params_init' must be str, and the value must be Tensor or Parameter, "
|
||||
f"but got the key type: {type(name)}, and the value type: {type(value)}")
|
||||
|
||||
params_dict = self.graph.update_hyper_params(params_init)
|
||||
for name, value in params_dict.items():
|
||||
param = Parameter(value)
|
||||
param.param_info = value.param_info
|
||||
self._params[name] = param
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.graph(*inputs)
|
||||
|
||||
|
|
|
@ -64,7 +64,6 @@ def _tensors_cast_datatype(datatype, param):
|
|||
return F.cast(param, datatype)
|
||||
|
||||
|
||||
|
||||
class WithLossCell(Cell):
|
||||
r"""
|
||||
Cell with loss function.
|
||||
|
|
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Model and parameters serialization."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""test get and init GraphCell parameters"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import export, load, save_checkpoint, load_checkpoint
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.flag = False
|
||||
self.weight = Parameter(np_param, requires_grad=True)
|
||||
self.dense = nn.Dense(3, 4)
|
||||
|
||||
def construct(self, x, y):
|
||||
if self.flag:
|
||||
ret = self.dense(x * self.weight)
|
||||
else:
|
||||
ret = x * y * self.weight
|
||||
self.weight += 1.0
|
||||
return ret
|
||||
|
||||
|
||||
np_a = np.ones((2, 3), np.float32) + 2
|
||||
np_b = np.ones((2, 3), np.float32) + 3
|
||||
np_param = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
|
||||
input_a = Tensor(np_a)
|
||||
input_b = Tensor(np_b)
|
||||
mindir_path = "test_net.mindir"
|
||||
ckpt_path = "test_net.ckpt"
|
||||
|
||||
|
||||
def setup_module():
|
||||
net = Net()
|
||||
export(net, input_a, input_b, file_name=mindir_path[:-7], file_format='MINDIR')
|
||||
|
||||
|
||||
def load_mindir_and_update_params():
|
||||
load_net = nn.GraphCell(graph=load(mindir_path))
|
||||
ret = load_net(input_a, input_b)
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * np_param)
|
||||
|
||||
save_checkpoint(load_net, ckpt_path)
|
||||
params_init = load_checkpoint(ckpt_path)
|
||||
load_net_with_new_params = nn.GraphCell(graph=load(mindir_path), params_init=params_init)
|
||||
return load_net_with_new_params
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_get_and_init_graph_cell_parameters():
|
||||
"""
|
||||
Description: load mind ir and update parameters.
|
||||
Expectation: generate a graph with updated parameters.
|
||||
"""
|
||||
load_net = load_mindir_and_update_params()
|
||||
ret = load_net(input_a, input_b)
|
||||
|
||||
assert np.array_equal(ret.asnumpy(), np_a * np_b * (np_param + 1.0))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_type():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong type.
|
||||
Expectation: raise a ValueError indicating the params type error.
|
||||
"""
|
||||
new_params = {"weight": np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32)}
|
||||
with pytest.raises(TypeError) as err:
|
||||
graph = load(mindir_path)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "The key of the 'params_init' must be str, and the value must be Tensor or Parameter" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_shape():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor shape.
|
||||
Expectation: raise a ValueError indicating the tensor shape error.
|
||||
"""
|
||||
new_params = {"weight": Parameter(np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.float32))}
|
||||
with pytest.raises(ValueError) as err:
|
||||
graph = load(mindir_path)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_init_graph_cell_parameters_with_wrong_dtype():
|
||||
"""
|
||||
Description: load mind ir and update parameters with wrong tensor dtype.
|
||||
Expectation: raise a ValueError indicating the tensor dtype error.
|
||||
"""
|
||||
new_params = {"weight": Parameter(np.arange(2 * 3).reshape((2, 3)).astype(np.float64))}
|
||||
with pytest.raises(ValueError) as err:
|
||||
graph = load(mindir_path)
|
||||
load_net = nn.GraphCell(graph, params_init=new_params)
|
||||
load_net(input_a, input_b)
|
||||
|
||||
assert "Only support update parameter by Tensor with same shape and dtype as it" in str(err.value)
|
Loading…
Reference in New Issue