diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index d46dadb0289..ccbf5d15b4e 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -1018,7 +1018,9 @@ bool SetMindIRGraphAction(const ResourcePtr &res) { }); if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) { MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before." - << " Please check the args is same with export.\n" + << "Please check the args is same with export.\n" + << "The export input argument size : " << func_args.size() << "\n" + << "The load input argument size : " << broaded_args.size() << "\n" << "Export input args info:" << abstract::ArgsToString(func_args) << "\n" << "The input args info:" << abstract::ArgsToString(broaded_args); } diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index 0754d4bc8ee..7d3a00c7f03 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -69,11 +69,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction { std::string ToString() const override; - AbstractFunctionPtr GetUnique() override { - MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; - AbstractFunctionPtr result; - return result; - } + AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } /// \brief Check whether the input AbstractFunction is in AbstractFuncUnion. /// @@ -90,11 +86,7 @@ class MS_CORE_API AbstractFuncUnion final : public AbstractFunction { std::size_t hash() const override; - AbstractFunctionPtr Copy() const override { - MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; - AbstractFunctionPtr result; - return result; - } + AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } private: AbstractFuncAtomPtrList func_list_; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 366d4aafbb1..2e7d6e7e1ea 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -64,8 +64,6 @@ AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr & AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -202,8 +200,7 @@ AbstractBasePtr InferImplScatterElements(const AnalysisEnginePtr &, const Primit const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); + AbstractBasePtr InferImplSubAndFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -246,8 +243,6 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index fe10017d980..15a34c32b0c 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -56,16 +56,16 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard class RegisterStandardPrimitiveEvalHelper { public: RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl, - const InferValueImpl &infer_value_impl, const bool is_wight_list = true) { - const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; + const InferValueImpl &infer_value_impl, const bool is_white_list = true) { + const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_white_list}; RegisterStandardPrimitiveImpl(primitive, impl_reg); } ~RegisterStandardPrimitiveEvalHelper() = default; }; -#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \ +#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_white_list) \ static auto helper_##name = \ - abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \ + abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_white_list); \ std::shared_ptr GetDefaultPrimC##name() { \ auto out = std::make_shared(); \ return out; \ diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index bcba5142752..2c53aff5b3a 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -22,10 +22,11 @@ import os import shutil import stat import sys -import time -from collections import defaultdict import threading from threading import Thread, Lock +import time +from collections import defaultdict + import numpy as np from mindspore.train.checkpoint_pb2 import Checkpoint @@ -67,6 +68,7 @@ _ckpt_mutex = Lock() SLICE_SIZE = 512 * 1024 PROTO_LIMIT_SIZE = 1024 * 1024 * 2 TOTAL_SAVE = 1024 * 1024 +PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024 def _special_process_par(par, new_par): @@ -837,6 +839,108 @@ def _export(net, file_name, file_format, *inputs, **kwargs): net.set_train(mode=True) +def _generate_front_info_for_param_data_file(is_encrypt, kwargs): + front_info = bytes() + check_code = sys.byteorder == "little" + front_info += check_code.to_bytes(1, byteorder=sys.byteorder) + front_info += bytes(63) + if is_encrypt(): + front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'], + len(kwargs['enc_key']), kwargs['enc_mode']) + return front_info + + +def _change_file(ori_data_file_name, dirname, external_local): + # The parameter has been not written in the file + if os.path.getsize(ori_data_file_name) == 64: + raise RuntimeError("The parameter size is exceed 1T,cannot export to the file") + data_file_name = os.path.join(dirname, external_local) + if os.path.exists(data_file_name): + os.chmod(data_file_name, stat.S_IWUSR) + return data_file_name + + +def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs): + ''' + The function to save parameter data + ''' + logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") + # save parameter + file_prefix = file_name.split("/")[-1] + if file_prefix.endswith(".mindir"): + file_prefix = file_prefix[:-7] + current_path = os.path.abspath(file_name) + dirname = os.path.dirname(current_path) + data_path = os.path.join(dirname, file_prefix + "_variables") + if os.path.exists(data_path): + shutil.rmtree(data_path) + os.makedirs(data_path, exist_ok=True) + os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) + # Reserves 4096 bytes as spare information such as check data + offset = 64 + index = 0 + parameter_size = (offset / 1024) + external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) + data_file_name = os.path.join(dirname, external_local) + if os.path.exists(data_file_name): + os.chmod(data_file_name, stat.S_IWUSR) + f = open(data_file_name, "wb") + f.write(bytes(offset)) + try: + for param_proto in model.graph.parameter: + name = param_proto.name[param_proto.name.find(":") + 1:] + param = net_dict[name] + raw_data = param.data.asnumpy().tobytes() + data_length = len(raw_data) + append_size = 0 + if data_length % 64 != 0: + append_size = 64 - (data_length % 64) + parameter_size += ((append_size + data_length) / 1024) + if parameter_size > PARAMETER_SPLIT_SIZE: + front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) + f.seek(0, 0) + f.write(front_info) + f.close() + os.chmod(data_file_name, stat.S_IRUSR) + offset = 64 + index += 1 + parameter_size = (offset + append_size + data_length) / 1024 + external_local = os.path.join(file_prefix + "_variables", "data_" + str(index)) + data_file_name = _change_file(data_file_name, dirname, external_local) + f = open(data_file_name, "wb") + f.write(bytes(offset)) + param_proto.external_data.location = external_local + param_proto.external_data.length = data_length + param_proto.external_data.offset = offset + write_data = raw_data + bytes(append_size) + offset += (data_length + append_size) + if is_encrypt(): + write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'], + len(kwargs['enc_key']), kwargs['enc_mode']) + f.write(write_data) + + # save graph + graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir") + if os.path.exists(graph_file_name): + os.chmod(graph_file_name, stat.S_IWUSR) + with open(graph_file_name, 'wb') as model_file: + os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR) + model_string = model.SerializeToString() + if is_encrypt(): + model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], + len(kwargs['enc_key']), + kwargs['enc_mode']) + model_file.write(model_string) + os.chmod(graph_file_name, stat.S_IRUSR) + + front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) + f.seek(0, 0) + f.write(front_info) + finally: + f.close() + os.chmod(data_file_name, stat.S_IRUSR) + + def _save_mindir(net, file_name, *inputs, **kwargs): """Save MindIR format file.""" model = mindir_model() @@ -859,67 +963,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs): if save_together: _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs) else: - logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") - # save parameter - file_prefix = file_name.split("/")[-1] - if file_prefix.endswith(".mindir"): - file_prefix = file_prefix[:-7] - current_path = os.path.abspath(file_name) - dirname = os.path.dirname(current_path) - data_path = os.path.join(dirname, file_prefix + "_variables") - if os.path.exists(data_path): - shutil.rmtree(data_path) - os.makedirs(data_path, exist_ok=True) - os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) - # Reserves 4096 bytes as spare information such as check data - offset = 64 - data_file_name = os.path.join(data_path, "veriables.data") - if os.path.exists(data_file_name): - os.chmod(data_file_name, stat.S_IWUSR) - with open(data_file_name, "wb") as f: - f.write(bytes(offset)) - for name, param in net_dict.items(): - for param_proto in model.graph.parameter: - if name == param_proto.name[param_proto.name.find(":") + 1:]: - data_file = os.path.join(file_prefix + "_variables", "veriables.data") - param_proto.external_data.location = data_file - raw_data = param.data.asnumpy().tobytes() - data_length = len(raw_data) - param_proto.external_data.length = data_length - param_proto.external_data.offset = offset - write_data = raw_data - offset += data_length - if data_length % 64 != 0: - append_size = 64 - (data_length % 64) - write_data += (bytes(append_size)) - offset += append_size - if is_encrypt(): - write_data = _encrypt(write_data, len(write_data), kwargs['enc_key'], - len(kwargs['enc_key']), kwargs['enc_mode']) - f.write(write_data) - - # save graph - graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir") - if os.path.exists(graph_file_name): - os.chmod(graph_file_name, stat.S_IWUSR) - with open(graph_file_name, 'wb') as model_file: - os.chmod(graph_file_name, stat.S_IRUSR | stat.S_IWUSR) - model_string = model.SerializeToString() - if is_encrypt(): - model_string = _encrypt(model_string, len(model_string), kwargs['enc_key'], - len(kwargs['enc_key']), - kwargs['enc_mode']) - model_file.write(model_string) - os.chmod(graph_file_name, stat.S_IRUSR) - - front_info = bytearray() - check_code = sys.byteorder == "little" - front_info += check_code.to_bytes(1, byteorder=sys.byteorder) - f.seek(0, 0) - if is_encrypt(): - front_info = _encrypt(front_info, len(front_info), kwargs['enc_key'], - len(kwargs['enc_key']), kwargs['enc_mode']) - f.write(front_info) + _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs) def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs): @@ -1242,6 +1286,7 @@ def restore_group_info_list(group_info_file_name): restore_rank_list = [rank for rank in restore_list.dim] return restore_rank_list + def build_searched_strategy(strategy_filename): """ Build strategy of every parameter in network. Used in the case of distributed inference. @@ -1528,7 +1573,6 @@ def async_ckpt_thread_status(): def _check_predict_strategy(predict_strategy): """Check predict strategy.""" - def _check_int_list(arg): if not isinstance(arg, list): return False diff --git a/tests/ut/python/mindir/test_mindir_export_larger_than_2g.py b/tests/ut/python/mindir/test_mindir_export_larger_than_2g.py new file mode 100644 index 00000000000..cd916006f00 --- /dev/null +++ b/tests/ut/python/mindir/test_mindir_export_larger_than_2g.py @@ -0,0 +1,146 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test mindir export larger than 1G """ +import os +import sys + +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter +from mindspore.common.tensor import Tensor +from mindspore.train.serialization import export, load + +def get_front_info(): + correct_data = bytes() + check_code = sys.byteorder == "little" + correct_data += check_code.to_bytes(1, byteorder=sys.byteorder) + correct_data += bytes(63) + return correct_data + +def get_correct_data(parameter): + correct_data = bytes() + data = parameter.data.asnumpy().tobytes() + data_size = len(data) + if data_size % 64 != 0: + data += bytes((64 - data_size % 64)) + correct_data += data + return correct_data + + +def get_data(mindir_name): + data_path = mindir_name + "_variables" + data = bytes() + for dirpath, _, filenames in os.walk(data_path): + for filename in filenames: + with open(os.path.join(dirpath, filename), "rb") as f: + data += f.readline() + return data + + +def test_mindir_export_split(): + """ + Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) + Description: MindIR Export model is exceed TOTAL_SAVE should be split save as model file and data file + Expectation: No exception. + """ + ms.train.serialization.TOTAL_SAVE = 0 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.addn = ops.AddN() + self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") + self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z") + + def construct(self, x): + return self.addn((x, self.y, self.z)) + + x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) + add_net = Net() + export(add_net, x, file_name="mindir_export_split", file_format="MINDIR") + graph = load("mindir_export_split_graph.mindir") + assert graph is not None + correct_data = get_front_info() + correct_data += get_correct_data(add_net.y) + correct_data += get_correct_data(add_net.z) + export_data = get_data("mindir_export_split") + assert export_data == correct_data + assert oct(os.stat(os.path.join("mindir_export_split_variables", "data_0")).st_mode)[-3:] == "400" + assert oct(os.stat("mindir_export_split_graph.mindir").st_mode)[-3:] == "400" + + +def test_mindir_export_larger_error(): + """ + Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) + Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file + and data file if the model has a parameter which exceed PARAMETER_SPLIT_SIZE(1T but mocked as 0) + the exception should be reported. + Expectation: Parameter is exceed PARAMETER_SPLIT_SIZE + """ + ms.train.serialization.TOTAL_SAVE = 0 + ms.train.serialization.PARAMETER_SPLIT_SIZE = 0 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.add = ops.Add() + self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") + + def construct(self, x): + return self.add(x, self.y) + + x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) + add = Net() + with pytest.raises(RuntimeError) as e: + export(add, x, file_name="net", file_format="MINDIR") + assert e.message == "The parameter size is exceed 1T,cannot export to the file" + + +def test_mindir_export_larger_parameter_exceed_1t_mock(): + """ + Feature: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) + Description: MindIR Export model is exceed TOTAL_SAVE(1G but mocked as 0) should be split save as model file + and data file if the parameter data file exceed PARAMETER_SPLIT_SIZE(1T but mocked as 129Bytes) limit, + it will be split to another file named data_0,data_1,data_2... + Expectation: No exception. + """ + ms.train.serialization.TOTAL_SAVE = 0 + ms.train.serialization.PARAMETER_SPLIT_SIZE = 129 / 1024 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.addn = ops.AddN() + self.y = Parameter(Tensor(np.array([2, 3, 3, 4]).astype(np.float32)), name="w") + self.z = Parameter(Tensor(np.array([2, 3, 3, 4])).astype(np.float32), name="z") + + def construct(self, x): + return self.addn((x, self.y, self.z)) + + x = Tensor(np.array([2, 3, 3, 4]).astype(np.float32)) + add_net = Net() + export(add_net, x, file_name="larger_parameter_exceed_1T_mock", file_format="MINDIR") + graph = load("larger_parameter_exceed_1T_mock_graph.mindir") + assert graph is not None + correct_data = get_front_info() + correct_data += get_correct_data(add_net.y) + correct_data += get_front_info() + correct_data += get_correct_data(add_net.z) + export_data = get_data("larger_parameter_exceed_1T_mock") + assert export_data == correct_data