diff --git a/mindspore/lite/tools/common/meta_graph_serializer.cc b/mindspore/lite/tools/common/meta_graph_serializer.cc
index c8b559cd7f9..e722a1b9b1b 100644
--- a/mindspore/lite/tools/common/meta_graph_serializer.cc
+++ b/mindspore/lite/tools/common/meta_graph_serializer.cc
@@ -178,15 +178,16 @@ bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT
 }
 
 bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
-                                                        const size_t key_len, const std::string &enc_mode) {
+                                                        const size_t key_len, const std::string &enc_mode,
+                                                        size_t *size) {
   // serialize model
   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
   auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
   builder.Finish(offset);
   schema::FinishMetaGraphBuffer(builder, offset);
-  size_t size = builder.GetSize();
+  *size = builder.GetSize();
   auto content = builder.GetBufferPointer();
-  if (!SerializeModel(content, size, key, key_len, enc_mode)) {
+  if (!SerializeModel(content, *size, key, key_len, enc_mode)) {
     MS_LOG(ERROR) << "Serialize graph failed";
     return false;
   }
@@ -221,21 +222,33 @@ uint8_t *MetaGraphSerializer::GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuil
 
 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
                               const size_t key_len, const std::string &enc_mode) {
-  MetaGraphSerializer meta_graph_serializer;
   size_t size = 0;
+  auto ret = MetaGraphSerializer::Save(graph, output_path, &size, key, key_len, enc_mode);
+  return ret;
+}
+
+int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size,
+                              const Byte *key, const size_t key_len, const std::string &enc_mode) {
+  MetaGraphSerializer meta_graph_serializer;
+  *size = 0;
   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
-  auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, &size);
+  auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, size);
   if (!meta_graph_serializer.InitPath(output_path)) {
     MS_LOG(ERROR) << "Init path failed";
     return RET_ERROR;
   }
-  auto save_together = (size < kModelSizeLimit);
+  size_t tensors_size = 0;
+  for (auto &tensor : graph.allTensors) {
+    tensors_size += tensor->data.size();
+  }
+
+  auto save_together = (tensors_size < kModelSizeLimit && *size < kModelSizeLimit);
   if (!meta_graph_serializer.Init(graph, save_together)) {
     MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
     return RET_ERROR;
   }
   if (save_together) {
-    if (!meta_graph_serializer.SerializeModel(buffer, size, key, key_len, enc_mode)) {
+    if (!meta_graph_serializer.SerializeModel(buffer, *size, key, key_len, enc_mode)) {
       MS_LOG(ERROR) << "Serialize graph failed";
       return RET_ERROR;
     }
@@ -244,10 +257,12 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
       MS_LOG(ERROR) << "Serialize graph weight failed";
       return RET_ERROR;
     }
-    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode)) {
+    size_t model_size = 0;
+    if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode, &model_size)) {
      MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
       return RET_ERROR;
     }
+    *size = model_size + tensors_size;
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/common/meta_graph_serializer.h b/mindspore/lite/tools/common/meta_graph_serializer.h
index 5c22d5e3d23..7dd69faaf6b 100644
--- a/mindspore/lite/tools/common/meta_graph_serializer.h
+++ b/mindspore/lite/tools/common/meta_graph_serializer.h
@@ -33,6 +33,8 @@ class MetaGraphSerializer {
   // save serialized fb model
   static int Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key = {},
                   const size_t key_len = 0, const std::string &enc_mode = "");
+  static int Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size, const Byte *key = {},
+                  const size_t key_len = 0, const std::string &enc_mode = "");
 
  private:
   MetaGraphSerializer() = default;
@@ -48,7 +50,7 @@ class MetaGraphSerializer {
   bool ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph);
 
   bool SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key, const size_t key_len,
-                                     const std::string &enc_mode);
+                                     const std::string &enc_mode, size_t *size = 0);
 
   bool SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
                       const std::string &enc_mode);
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
index 71b83d77efc..15b89b56ad7 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
@@ -345,10 +345,9 @@ int ConfigFileParser::ParseMixedBitQuantString(const std::map
   if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
     const auto &map = maps.at(kMixedBitWeightQuantParam);
     std::map<std::string, std::string &> parse_map{
-      {"init_scale", mixed_bit_quant_string_.init_scale},
-      {"auto_tune", mixed_bit_quant_string_.auto_tune},
-      {"use_cv_data", mixed_bit_quant_string_.use_cv_data},
-      {"max_iterations", mixed_bit_quant_string_.max_iterations},
+      {"init_scale", mixed_bit_quant_string_.init_scale},   {"auto_tune", mixed_bit_quant_string_.auto_tune},
+      {"use_cv_data", mixed_bit_quant_string_.use_cv_data}, {"max_iterations", mixed_bit_quant_string_.max_iterations},
+      {"workspace", mixed_bit_quant_string_.workspace},
     };
     return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
   }
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
index 07dd994a1f1..3ec85e4f046 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
@@ -53,6 +53,7 @@ struct MixedBitWeightQuantString {
   std::string auto_tune;
   std::string use_cv_data;
   std::string max_iterations;
+  std::string workspace;
 };
 
 struct WeightQuantString {
diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
index 08ab7ebe5d9..fb51286cae0 100644
--- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
@@ -170,6 +170,9 @@ int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &
     MS_LOG(ERROR) << "INPUT ILLEGAL: auto_tune should be true or false.";
     return RET_INPUT_PARAM_INVALID;
   }
+
+  // required only for models larger than 2G
+  mixed_bit_weight_quant->workspace = mixed_bit_weight_quant_string.workspace;
   return RET_OK;
 }
 
diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h
index 19922f09cf5..a2b4719a8c5 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_params.h
+++ b/mindspore/lite/tools/converter/quantizer/quant_params.h
@@ -129,6 +129,7 @@ struct MixedBitWeightQuantParam {
   bool auto_tune = false;
   bool use_cv_data = false;
   int max_iterations = kMinIterations;
+  std::string workspace;  // only for models larger than 2G
 };
 
 struct FullQuantParam {
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 1e1ee016984..fe4ad8d9cb6 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -47,6 +47,8 @@ namespace mindspore::lite::quant {
 namespace {
 constexpr size_t kGatherAxisIndex = 3;
 constexpr int kDefaultThreadNum = 4;
+constexpr size_t kEncMaxLen = 16;
+constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
 }  // namespace
 
 int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
@@ -204,6 +206,27 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
   return primitive_c->name();
 }
 
+Status LargeModelBuildModel(const schema::MetaGraphT &meta_graph, const std::shared_ptr<ConverterPara> &param,
+                            const std::shared_ptr<mindspore::Model> &model,
+                            const std::shared_ptr<mindspore::Context> &context, size_t *size) {
+  if (param->mixedBitWeightQuantParam.workspace.empty()) {
+    MS_LOG(ERROR) << "The model is larger than 2G, mixedBitWeightQuant config needs to set workspace to save tmp model";
+    return kLiteError;
+  }
+  std::string tmp_save_file_path = param->mixedBitWeightQuantParam.workspace + "/tmp.ms";
+  unsigned char encKey[kEncMaxLen] = {0};
+  size_t keyLen = 0;
+  auto status = MetaGraphSerializer::Save(meta_graph, tmp_save_file_path, size, encKey, keyLen, param->encrypt_mode);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Save Large Model Failed: " << status << " " << GetErrorInfo(status);
+    return kLiteError;
+  }
+
+  mindspore::ModelType model_type = kMindIR_Lite;
+  auto ret = model->Build(tmp_save_file_path, model_type, context);
+  return ret;
+}
+
 Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
                              const std::shared_ptr<ConverterPara> &param, size_t *size) {
   FuncGraphPtr func_graph_clone;
@@ -236,17 +259,6 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
     return kLiteError;
   }
 
-  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
-  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
-  builder.Finish(offset);
-  schema::FinishMetaGraphBuffer(builder, offset);
-  *size = builder.GetSize();
-  auto *content = reinterpret_cast<const void *>(builder.GetBufferPointer());
-  if (content == nullptr) {
-    MS_LOG(ERROR) << "GetBufferPointer return null";
-    delete meta_graph;
-    return kLiteNullptr;
-  }
   auto context = std::make_shared<mindspore::Context>();
   if (context == nullptr) {
     MS_LOG(ERROR) << "New context failed while running.";
@@ -264,6 +276,30 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
   }
   auto &device_list = context->MutableDeviceInfo();
   device_list.push_back(device_info);
+
+  size_t tensors_size = 0;
+  for (auto &tensor : meta_graph->allTensors) {
+    tensors_size += tensor->data.size();
+  }
+
+  if (tensors_size >= kModelSizeLimit) {
+    auto ret = LargeModelBuildModel(*meta_graph, param, model, context, size);
+    delete meta_graph;
+    return ret;
+  }
+
+  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
+  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
+  builder.Finish(offset);
+  schema::FinishMetaGraphBuffer(builder, offset);
+  *size = builder.GetSize();
+  auto *content = reinterpret_cast<const void *>(builder.GetBufferPointer());
+  if (content == nullptr) {
+    MS_LOG(ERROR) << "GetBufferPointer return null";
+    delete meta_graph;
+    return kLiteNullptr;
+  }
+
   auto ret = model->Build(content, *size, kMindIR, context);
   delete meta_graph;
   return ret;
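
Note on configuration (not part of the diff above): the new "workspace" key is read from the mixed-bit weight-quant section of the converter config file, next to the existing init_scale, auto_tune, use_cv_data and max_iterations keys. A minimal sketch of such a config follows; the section header assumes kMixedBitWeightQuantParam expands to "mixed_bit_weight_quant_param" (this diff does not show the literal), and every value is a placeholder rather than something taken from the patch.

    [mixed_bit_weight_quant_param]
    init_scale=0.02
    auto_tune=false
    use_cv_data=false
    max_iterations=10000
    # directory for the temporary .ms file written when the serialized tensors exceed the 2 GiB limit
    workspace=/path/to/large_model_workspace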
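
Likewise, a minimal caller-side sketch of the new size-reporting MetaGraphSerializer::Save overload. The include paths, the enclosing mindspore::lite namespace, and the helper name SaveAndReportSize are assumptions for illustration, not something this diff shows.

    #include <string>

    #include "include/errorcode.h"                   // assumed location of RET_OK
    #include "schema/inner/model_generated.h"        // assumed location of schema::MetaGraphT
    #include "tools/common/meta_graph_serializer.h"  // header modified by this patch

    namespace mindspore::lite {
    // Saves the graph and reports how many bytes were written. When the weights are
    // split out (large models), *size covers the packed graph plus the weight data.
    int SaveAndReportSize(const schema::MetaGraphT &graph, const std::string &output_path) {
      size_t size = 0;
      // Defaults used: no encryption key, key_len = 0, empty enc_mode.
      int ret = MetaGraphSerializer::Save(graph, output_path, &size);
      if (ret != RET_OK) {
        return ret;
      }
      // 'size' can now be compared against the converter's 2 GiB threshold, or simply logged.
      return RET_OK;
    }
    }  // namespace mindspore::lite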