!26280 fix mindir export's error when using _encrypt

Merge pull request !26280 from lianliguang/master
This commit is contained in:
i-robot 2021-11-24 01:48:32 +00:00 committed by Gitee
commit 9a5fd32bd2
6 changed files with 264 additions and 85 deletions

View File

@ -1018,7 +1018,9 @@ bool SetMindIRGraphAction(const ResourcePtr &res) {
});
if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
<< " Please check the args is same with export.\n"
<< "Please check the args is same with export.\n"
<< "The export input argument size : " << func_args.size() << "\n"
<< "The load input argument size : " << broaded_args.size() << "\n"
<< "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
<< "The input args info:" << abstract::ArgsToString(broaded_args);
}

View File

@ -69,11 +69,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction {
std::string ToString() const override;
AbstractFunctionPtr GetUnique() override {
MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion";
AbstractFunctionPtr result;
return result;
}
AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
/// \brief Check whether the input AbstractFunction is in AbstractFuncUnion.
///
@ -90,11 +86,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction {
std::size_t hash() const override;
AbstractFunctionPtr Copy() const override {
MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion";
AbstractFunctionPtr result;
return result;
}
AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; }
private:
AbstractFuncAtomPtrList func_list_;

View File

@ -64,8 +64,6 @@ AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -202,8 +200,7 @@ AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@ -246,8 +243,6 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -56,16 +56,16 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl,
const InferValueImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
const InferValueImpl &infer_value_impl, const bool is_white_list = true) {
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_white_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterStandardPrimitiveEvalHelper() = default;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_white_list) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
auto out = std::make_shared<name>(); \
return out; \

View File

@ -22,10 +22,11 @@ import os
import shutil
import stat
import sys
import time
from collections import defaultdict
import threading
from threading import Thread, Lock
import time
from collections import defaultdict
import numpy as np
from mindspore.train.checkpoint_pb2 import Checkpoint
@ -67,6 +68,7 @@ _ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024
PROTO_LIMIT_SIZE = 1024 * 1024 * 2
TOTAL_SAVE = 1024 * 1024
PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
def _special_process_par(par, new_par):
@ -837,6 +839,108 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
net.set_train(mode=True)
def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
front_info = bytes()
check_code = sys.byteorder == "little"
front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
front_info += bytes(63)
if is_encrypt():
front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
return front_info
def _change_file(ori_data_file_name, dirname, external_local):
# The parameter has been not written in the file
if os.path.getsize(ori_data_file_name) == 64:
raise RuntimeError("The parameter size is exceed 1T,cannot export to the file")
data_file_name = os.path.join(dirname, external_local)
if os.path.exists(data_file_name):
os.chmod(data_file_name, stat.S_IWUSR)
return data_file_name
def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
'''
The function to save parameter data
'''
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
file_prefix = file_name.split("/")[-1]
if file_prefix.endswith(".mindir"):
file_prefix = file_prefix[:-7]
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
data_path = os.path.join(dirname, file_prefix + "_variables")
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
# Reserves 4096 bytes as spare information such as check data
offset = 64
index = 0
parameter_size = (offset / 1024)
external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
data_file_name = os.path.join(dirname, external_local)
if os.path.exists(data_file_name):
os.chmod(data_file_name, stat.S_IWUSR)
f = open(data_file_name, "wb")
f.write(bytes(offset))
try:
for param_proto in model.graph.parameter:
name = param_proto.name[param_proto.name.find(":") + 1:]
param = net_dict[name]
raw_data = param.data.asnumpy().tobytes()
data_length = len(raw_data)
append_size = 0
if data_length % 64 != 0:
append_size = 64 - (data_length % 64)
parameter_size += ((append_size + data_length) / 1024)
if parameter_size > PARAMETER_SPLIT_SIZE:
front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
f.seek(0, 0)
f.write(front_info)
f.close()
os.chmod(data_file_name, stat.S_IRUSR)
offset = 64
index += 1
parameter_size = (offset + append_size + data_length) / 1024
external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
data_file_name = _change_file(data_file_name, dirname, external_local)
f = open(data_file_name, "wb")
f.write(bytes(offset))
param_proto.external_data.location = external_local
param_proto.external_data.length = data_length
param_proto.external_data.offset = offset
write_data = raw_data + bytes(append_size)
offset += (data_length + append_size)
if is_encrypt():
write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
f.write(write_data)
# save graph
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
if os.path.exists(graph_file_name):
os.chmod(graph_file_name, stat.S_IWUSR)
with open(graph_file_name, 'wb') as model_file:
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
len(kwargs['enc_key']),
kwargs['enc_mode'])
model_file.write(model_string)
os.chmod(graph_file_name, stat.S_IRUSR)
front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs)
f.seek(0, 0)
f.write(front_info)
finally:
f.close()
os.chmod(data_file_name, stat.S_IRUSR)
def _save_mindir(net, file_name, *inputs, **kwargs):
"""Save MindIR format file."""
model = mindir_model()
@ -859,67 +963,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
if save_together:
_save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
else:
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
file_prefix = file_name.split("/")[-1]
if file_prefix.endswith(".mindir"):
file_prefix = file_prefix[:-7]
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
data_path = os.path.join(dirname, file_prefix + "_variables")
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
# Reserves 4096 bytes as spare information such as check data
offset = 64
data_file_name = os.path.join(data_path, "veriables.data")
if os.path.exists(data_file_name):
os.chmod(data_file_name, stat.S_IWUSR)
with open(data_file_name, "wb") as f:
f.write(bytes(offset))
for name, param in net_dict.items():
for param_proto in model.graph.parameter:
if name == param_proto.name[param_proto.name.find(":") + 1:]:
data_file = os.path.join(file_prefix + "_variables", "veriables.data")
param_proto.external_data.location = data_file
raw_data = param.data.asnumpy().tobytes()
data_length = len(raw_data)
param_proto.external_data.length = data_length
param_proto.external_data.offset = offset
write_data = raw_data
offset += data_length
if data_length % 64 != 0:
append_size = 64 - (data_length % 64)
write_data += (bytes(append_size))
offset += append_size
if is_encrypt():
write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
f.write(write_data)
# save graph
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
if os.path.exists(graph_file_name):
os.chmod(graph_file_name, stat.S_IWUSR)
with open(graph_file_name, 'wb') as model_file:
os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR)
model_string = model.SerializeToString()
if is_encrypt():
model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'],
len(kwargs['enc_key']),
kwargs['enc_mode'])
model_file.write(model_string)
os.chmod(graph_file_name, stat.S_IRUSR)
front_info = bytearray()
check_code = sys.byteorder == "little"
front_info += check_code.to_bytes(1, byteorder=sys.byteorder)
f.seek(0, 0)
if is_encrypt():
front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'],
len(kwargs['enc_key']), kwargs['enc_mode'])
f.write(front_info)
_spilt_save(net_dict, model, file_name, is_encrypt, **kwargs)
def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
@ -1242,6 +1286,7 @@ def restore_group_info_list(group_info_file_name):
restore_rank_list = [rank for rank in restore_list.dim]
return restore_rank_list
def build_searched_strategy(strategy_filename):
"""
Build strategy of every parameter in network. Used in the case of distributed inference.
@ -1528,7 +1573,6 @@ def async_ckpt_thread_status():
def _check_predict_strategy(predict_strategy):
"""Check predict strategy."""
def _check_int_list(arg):
if not isinstance(arg, list):
return False

View File

@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test mindir export larger than 1G """
import os
import sys
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Parameter
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import export, load
def get_front_info():
correct_data = bytes()
check_code = sys.byteorder == "little"
correct_data += check_code.to_bytes(1, byteorder=sys.byteorder)
correct_data += bytes(63)
return correct_data
def get_correct_data(parameter):
correct_data = bytes()
data = parameter.data.asnumpy().tobytes()
data_size = len(data)
if data_size % 64 != 0:
data += bytes((64 - data_size % 64))
correct_data += data
return correct_data
def get_data(mindir_name):
data_path = mindir_name + "_variables"
data = bytes()
for dirpath, _, filenames in os.walk(data_path):
for filename in filenames:
with open(os.path.join(dirpath, filename), "rb") as f:
data += f.readline()
return data
def test_mindir_export_split():
"""
Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0)
Description: MindIR Export model is exceed TOTAL_SAVE should be split save as model file and data file
Expectation: No exception.
"""
ms.train.serialization.TOTAL_SAVE = 0
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.addn = ops.AddN()
self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")
def construct(self, x):
return self.addn((x, self.y, self.z))
x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
add_net = Net()
export(add_net, x, file_name="mindir_export_split", file_format="MINDIR")
graph = load("mindir_export_split_graph.mindir")
assert graph is not None
correct_data = get_front_info()
correct_data += get_correct_data(add_net.y)
correct_data += get_correct_data(add_net.z)
export_data = get_data("mindir_export_split")
assert export_data == correct_data
assert oct(os.stat(os.path.join("mindir_export_split_variables", "data_0")).st_mode)[-3:] == "400"
assert oct(os.stat("mindir_export_split_graph.mindir").st_mode)[-3:] == "400"
def test_mindir_export_larger_error():
"""
Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0)
Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file
and data file if the model has a parameter which exceed PARAMETER_SPLIT_SIZE(1T but mocked as 0)
the exception should be reported.
Expectation: Parameter is exceed PARAMETER_SPLIT_SIZE
"""
ms.train.serialization.TOTAL_SAVE = 0
ms.train.serialization.PARAMETER_SPLIT_SIZE = 0
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = ops.Add()
self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
def construct(self, x):
return self.add(x, self.y)
x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
add = Net()
with pytest.raises(RuntimeError) as e:
export(add, x, file_name="net", file_format="MINDIR")
assert e.message == "The parameter size is exceed 1T,cannot export to the file"
def test_mindir_export_larger_parameter_exceed_1t_mock():
"""
Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0)
Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file
and data file if the parameter data file exceed PARAMETER_SPLIT_SIZE(1T but mocked as 129Bytes) limit,
it will be split to another file named data_0,data_1,data_2...
Expectation: No exception.
"""
ms.train.serialization.TOTAL_SAVE = 0
ms.train.serialization.PARAMETER_SPLIT_SIZE = 129 / 1024
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.addn = ops.AddN()
self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w")
self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z")
def construct(self, x):
return self.addn((x, self.y, self.z))
x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32))
add_net = Net()
export(add_net, x, file_name="larger_parameter_exceed_1T_mock", file_format="MINDIR")
graph = load("larger_parameter_exceed_1T_mock_graph.mindir")
assert graph is not None
correct_data = get_front_info()
correct_data += get_correct_data(add_net.y)
correct_data += get_front_info()
correct_data += get_correct_data(add_net.z)
export_data = get_data("larger_parameter_exceed_1T_mock")
assert export_data == correct_data