forked from mindspore-Ecosystem/mindspore
!48547 complete the param of converter_lite
Merge pull request !48547 from 周超/master5
This commit is contained in:
commit
b9c8db3b80
|
@ -210,8 +210,8 @@ if(DEFINED ENV{MSLITE_ENABLE_MODEL_PRE_INFERENCE})
|
|||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION ON)
|
||||
set(MSLITE_ENABLE_CONVERTER ON)
|
||||
set(MSLITE_ENABLE_TRAIN off)
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_COVERAGE})
|
||||
|
@ -254,7 +254,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
|||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_ACL AND (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE))
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION ON)
|
||||
set(PLATFORM_ARM32 off)
|
||||
endif()
|
||||
|
||||
|
@ -786,7 +785,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
|||
set(MSLITE_DEPS_MKLDNN on)
|
||||
set(MSLITE_DEPS_LIBEVENT on)
|
||||
set(MSLITE_DEPS_PYBIND11 on)
|
||||
set(MSLITE_DEPS_OPENSSL on)
|
||||
if(SUPPORT_TENSORRT)
|
||||
set(MSLITE_DEPS_FAST_TRANSFORMERS on)
|
||||
endif()
|
||||
|
|
|
@ -99,7 +99,7 @@ if(MSLITE_ENABLE_CONTROLFLOW)
|
|||
)
|
||||
endif()
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
if(SUPPORT_TRAIN OR MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
||||
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
|
||||
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc)
|
||||
add_library(train_cpu_kernel_mid OBJECT ${TRAIN_KERNEL_SRC})
|
||||
|
|
|
@ -12,7 +12,7 @@ function Convert() {
|
|||
fail=0
|
||||
local cfg_file_list model_info model_name extra_info model_type cfg_file_name model_file weight_file output_file \
|
||||
quant_type config_file train_model in_dtype out_dtype converter_result cfg_file calib_size save_type \
|
||||
encryption_flag input_format
|
||||
input_format
|
||||
cfg_file_list=$1
|
||||
for cfg_file in ${cfg_file_list[*]}; do
|
||||
while read line; do
|
||||
|
@ -70,7 +70,6 @@ function Convert() {
|
|||
out_dtype="DEFAULT"
|
||||
fp16_weight="off"
|
||||
save_type="MINDIR_LITE"
|
||||
encryption_flag="false"
|
||||
optimize="general"
|
||||
if [[ ${cfg_file_name} =~ "weightquant" ]]; then
|
||||
# models_weightquant_${suffix}.cfg
|
||||
|
@ -91,7 +90,6 @@ function Convert() {
|
|||
fi
|
||||
elif [[ ${cfg_file_name} =~ "_cloud" ]]; then
|
||||
save_type="MINDIR"
|
||||
encryption_flag="false"
|
||||
if [[ ${input_shapes} != "" && ${input_names} != "" ]]; then
|
||||
if [[ ${input_num} == "" ]]; then
|
||||
input_num=1
|
||||
|
@ -139,12 +137,12 @@ function Convert() {
|
|||
if [[ ${cfg_file_name} =~ "_cloud" ]]; then
|
||||
echo "./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
|
||||
--inputDataType=${in_dtype} --outputDataType=${out_dtype} --inputShape=${spec_shapes} --fp16=${fp16_weight}\
|
||||
--configFile=${config_file} --saveType=${save_type} --optimize=${optimize} --encryption=${encryption_flag}\
|
||||
--configFile=${config_file} --saveType=${save_type} --optimize=${optimize} \
|
||||
--trainModel=${train_model} --inputDataFormat=${input_format}"
|
||||
|
||||
./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
|
||||
--inputDataType=${in_dtype} --outputDataType=${out_dtype} --inputShape="${spec_shapes}" --fp16=${fp16_weight}\
|
||||
--configFile=${config_file} --saveType=${save_type} --optimize=${optimize} --encryption=${encryption_flag}\
|
||||
--configFile=${config_file} --saveType=${save_type} --optimize=${optimize} \
|
||||
--trainModel=${train_model} --inputDataFormat=${input_format} >> "$4"
|
||||
else
|
||||
echo "./converter_lite --fmk=${model_fmk} --modelFile=${model_file} --weightFile=${weight_file} --outputFile=${output_file}\
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "ops/core_ops.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/common/string_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -43,6 +44,7 @@ constexpr size_t kTupleGetItemInputSize = 3;
|
|||
constexpr size_t kSecondIndex = 1;
|
||||
constexpr size_t kInvalidSize = SIZE_MAX;
|
||||
constexpr auto kMakeTuple = "MakeTuple";
|
||||
constexpr size_t kEncMaxLen = 16;
|
||||
} // namespace
|
||||
|
||||
static STATUS GetAbstractfromTupleGetItem(const CNodePtr &cnode, AbstractBasePtr *abstract, size_t *idx) {
|
||||
|
@ -685,5 +687,23 @@ int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t
|
|||
(void)memcpy(*model_buf, content, *size);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int InitEncryptKey(const std::shared_ptr<ConverterPara> ¶m, unsigned char *encKey, size_t *keyLen) {
|
||||
if (!param->enable_encryption) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (param->encrypt_key.empty()) {
|
||||
MS_LOG(ERROR) << "param->encrypt_key is empty.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
*keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen);
|
||||
if (*keyLen != kEncMaxLen) {
|
||||
MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters "
|
||||
<< " and only support AES-GCM method and the key length is " << kEncMaxLen;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/cxx_api/converter_para.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -99,6 +100,8 @@ STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector<int64_t>
|
|||
STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr ¶m_node, std::vector<int64_t> *shape_vector);
|
||||
|
||||
STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector<int> *shape);
|
||||
|
||||
int InitEncryptKey(const std::shared_ptr<ConverterPara> ¶m, unsigned char *encKey, size_t *keyLen);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -204,6 +204,8 @@ set(MODEL_LOADER_FRAMEWORK_SRC
|
|||
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
||||
add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
add_compile_definitions(ENABLE_CLOUD_INFERENCE)
|
||||
add_compile_definitions(SUPPORT_TRAIN)
|
||||
set(SUPPORT_TRAIN on)
|
||||
|
||||
# string(REPLACE "-Werror" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
# string(REPLACE "-Werror" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
|
|
|
@ -213,6 +213,70 @@ int PreInference(const schema::MetaGraphT &meta_graph, bool train_model) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int PreInferenceMindIR(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &graph, bool train_model) {
|
||||
if (train_model) {
|
||||
MS_LOG(WARNING) << "train model dont support pre-infer.";
|
||||
return RET_OK;
|
||||
}
|
||||
MindIRSerializer mindir_serializer(false);
|
||||
auto ret = mindir_serializer.Save(param, graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Save funcgraph failed";
|
||||
return ret;
|
||||
}
|
||||
void *data = nullptr;
|
||||
size_t size = 0;
|
||||
ret = mindir_serializer.GetBuffAndSize(&data, &size);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get buffer and size failed";
|
||||
return ret;
|
||||
}
|
||||
|
||||
mindspore::Model model;
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
if (context == nullptr) {
|
||||
MS_LOG(ERROR) << "New context failed while running ";
|
||||
std::cerr << "New context failed while running " << std::endl;
|
||||
free(data);
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
|
||||
auto &device_list = context->MutableDeviceInfo();
|
||||
device_list.push_back(device_info);
|
||||
|
||||
auto status = model.Build(data, size, kMindIR, context);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "Build error ";
|
||||
std::cerr << "Build error " << std::endl;
|
||||
free(data);
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto &tensor : model.GetInputs()) {
|
||||
if (tensor.Shape().empty() || tensor.DataSize() == 0 ||
|
||||
std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
|
||||
MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-infer.";
|
||||
free(data);
|
||||
return RET_OK;
|
||||
}
|
||||
ret = GenerateRandomData(&tensor);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << tensor.Name() << "GenerateRandomData failed.";
|
||||
free(data);
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
std::vector<MSTensor> outputs;
|
||||
status = model.Predict(model.GetInputs(), &outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "Inference error ";
|
||||
std::cerr << "Inference error " << std::endl;
|
||||
free(data);
|
||||
return RET_ERROR;
|
||||
}
|
||||
free(data);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m) {
|
||||
lite::ConfigFileParser config_parser;
|
||||
std::map<std::string, std::map<std::string, std::string>> maps;
|
||||
|
@ -740,9 +804,9 @@ int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<Converter
|
|||
size_t *data_size, bool not_save) {
|
||||
int status = RET_ERROR;
|
||||
if (param->export_mindir == kMindIR) {
|
||||
status = ConverterFuncGraph::Save(param, graph, model_data, data_size);
|
||||
status = SaveMindIRModel(graph, param, model_data, data_size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Export to mindir failed: " << status << " " << GetErrorInfo(status);
|
||||
MS_LOG(ERROR) << "Save mindir model failed :" << status << " " << GetErrorInfo(status);
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -773,13 +837,31 @@ int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<Converter
|
|||
}
|
||||
delete meta_graph;
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Save failed:" << status << " " << GetErrorInfo(status);
|
||||
MS_LOG(ERROR) << "Save failed:" << status << " " << GetErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConverterImpl::SaveMindIRModel(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> ¶m, void **model_data,
|
||||
size_t *data_size) {
|
||||
int status = RET_OK;
|
||||
if (param->pre_infer) {
|
||||
status = PreInferenceMindIR(param, graph, param->train_model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "PreInferenceMindIR failed: " << status << " " << GetErrorInfo(status);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
status = ConverterFuncGraph::Save(param, graph, model_data, data_size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Export to mindir failed: " << status << " " << GetErrorInfo(status);
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int RunConverter(const std::shared_ptr<ConverterPara> ¶m, void **model_data, size_t *data_size, bool not_save) {
|
||||
mindspore::mindspore_log_init();
|
||||
|
||||
|
|
|
@ -59,6 +59,8 @@ class ConverterImpl {
|
|||
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
|
||||
int SaveGraph(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> ¶m, void **model_data, size_t *data_size,
|
||||
bool not_save);
|
||||
int SaveMindIRModel(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> ¶m, void **model_data,
|
||||
size_t *data_size);
|
||||
int LoadPluginLib(const std::shared_ptr<ConverterPara> ¶m);
|
||||
};
|
||||
} // namespace lite
|
||||
|
|
|
@ -88,7 +88,7 @@ Flags::Flags() {
|
|||
AddFlag(&Flags::device, "device",
|
||||
"Set the target device, support Ascend, Ascend310 and Ascend310P will be deprecated.", "");
|
||||
AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
|
||||
AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "");
|
||||
AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "general");
|
||||
AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion true|false", "false");
|
||||
}
|
||||
|
||||
|
@ -306,9 +306,6 @@ int Flags::InitSaveType() {
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (this->saveTypeStr == "MINDIR") {
|
||||
this->disableFusion = true;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,24 +35,6 @@ constexpr size_t kFlatbuffersBuilderInitSize = 1024;
|
|||
constexpr size_t kEncMaxLen = 16;
|
||||
} // namespace
|
||||
|
||||
int InitEncryptKey(const std::shared_ptr<ConverterPara> ¶m, unsigned char *encKey, size_t *keyLen) {
|
||||
if (!param->enable_encryption) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (param->encrypt_key.empty()) {
|
||||
MS_LOG(ERROR) << "param->encrypt_key is empty.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
*keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen);
|
||||
if (*keyLen != kEncMaxLen) {
|
||||
MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters "
|
||||
<< " and only support AES-GCM method and the key length is " << kEncMaxLen;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ConverterToMetaGraph::UpdateMetaGraphOutputName(schema::MetaGraphT *meta_graph,
|
||||
const std::vector<std::string> &output_names) {
|
||||
MS_CHECK_TRUE_MSG(meta_graph != nullptr, RET_NULL_PTR, "meta_graph is nullptr");
|
||||
|
|
|
@ -20,11 +20,13 @@
|
|||
#include <fstream>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "utils/crypto.h"
|
||||
#include "mindspore/ccsrc/include/common/debug/dump_proto.h"
|
||||
#include "mindspore/ccsrc/include/common/utils/utils.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/common.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "mindspore/core/utils/file_utils.h"
|
||||
#include "mindspore/core/ir/quantization_param.h"
|
||||
#include "mindspore/lite/tools/converter/quantizer/quant_param_holder.h"
|
||||
|
@ -32,12 +34,13 @@
|
|||
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
// unit is byte. model size more than 1G need split.
|
||||
constexpr const size_t TOTAL_SAVE = 1024 * 1024 * 1024;
|
||||
constexpr const size_t PARA_ROUND = 1024;
|
||||
constexpr const int64_t OFFSET = 64;
|
||||
constexpr size_t kEncMaxLen = 16;
|
||||
|
||||
namespace {
|
||||
bool DeleteDirRecursively(const std::string &dir_name) {
|
||||
DIR *dir = opendir(dir_name.c_str());
|
||||
dirent *dirent = nullptr;
|
||||
|
@ -203,9 +206,9 @@ int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> ¶m, const Fu
|
|||
}
|
||||
|
||||
if (save_together_) {
|
||||
ret = SaveMindIRTogether();
|
||||
ret = SaveMindIRTogether(param);
|
||||
} else {
|
||||
ret = SplitSave();
|
||||
ret = SplitSave(param);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "save mindir weight failed.";
|
||||
|
@ -308,7 +311,7 @@ std::shared_ptr<mindspore::QuantizationParam> MindIRSerializer::ConvertQuantPara
|
|||
return std::make_shared<mindspore::QuantizationParam>(quantization);
|
||||
}
|
||||
|
||||
int MindIRSerializer::SaveMindIRTogether() {
|
||||
int MindIRSerializer::SaveMindIRTogether(const std::shared_ptr<ConverterPara> ¶m) {
|
||||
for (auto ¶m_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
|
||||
std::string proto_name = param_proto.name();
|
||||
auto para = GetFgParaAccordingToProtoName(proto_name);
|
||||
|
@ -323,7 +326,7 @@ int MindIRSerializer::SaveMindIRTogether() {
|
|||
param_proto.set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
|
||||
}
|
||||
|
||||
return SaveProtoToFile(&model_proto_, save_model_path_);
|
||||
return SaveProtoToFile(&model_proto_, save_model_path_, param);
|
||||
}
|
||||
|
||||
int MindIRSerializer::CreateParameterDir() {
|
||||
|
@ -426,7 +429,7 @@ std::string MindIRSerializer::CreateExternalPath(const std::string &external_fil
|
|||
return external_local_path;
|
||||
}
|
||||
|
||||
int MindIRSerializer::SplitSave() {
|
||||
int MindIRSerializer::SplitSave(const std::shared_ptr<ConverterPara> ¶m) {
|
||||
MS_LOG(DEBUG) << "Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.";
|
||||
int ret = CreateParameterDir();
|
||||
if (ret != RET_OK) {
|
||||
|
@ -504,7 +507,7 @@ int MindIRSerializer::SplitSave() {
|
|||
#else
|
||||
split_model_file_name = save_path_ + "/" + model_name_ + "_graph.mindir";
|
||||
#endif
|
||||
return SaveProtoToFile(&model_proto_, split_model_file_name);
|
||||
return SaveProtoToFile(&model_proto_, split_model_file_name, param);
|
||||
}
|
||||
|
||||
int MindIRSerializer::ParserPath(const std::string &output_path) {
|
||||
|
@ -562,7 +565,12 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
|
||||
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file,
|
||||
const std::shared_ptr<ConverterPara> ¶m) {
|
||||
if (!is_export_model_) {
|
||||
MS_LOG(INFO) << "No need to save proto to file";
|
||||
return RET_OK;
|
||||
}
|
||||
auto realpath = Common::CreatePrefixPath(output_file, true);
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path of file " << output_file << " failed.";
|
||||
|
@ -575,12 +583,48 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
|
|||
MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (!model_proto->SerializeToOstream(&fout)) {
|
||||
MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
|
||||
unsigned char enc_key[kEncMaxLen] = {0};
|
||||
size_t key_len = 0;
|
||||
auto ret = InitEncryptKey(param, enc_key, &key_len);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init encrypt key failed";
|
||||
fout.close();
|
||||
return RET_ERROR;
|
||||
return ret;
|
||||
}
|
||||
if (key_len > 0) {
|
||||
void *buffer = nullptr;
|
||||
size_t size = 0;
|
||||
ret = GetBuffAndSize(&buffer, &size);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get buffer and size failed";
|
||||
fout.close();
|
||||
return ret;
|
||||
}
|
||||
size_t encrypt_len = 0;
|
||||
auto encrypt_content =
|
||||
Encrypt(&encrypt_len, reinterpret_cast<Byte *>(buffer), size, enc_key, key_len, param->encrypt_mode);
|
||||
if (encrypt_content == nullptr || encrypt_len == 0) {
|
||||
MS_LOG(ERROR) << "Encrypt failed.";
|
||||
free(buffer);
|
||||
fout.close();
|
||||
return RET_ERROR;
|
||||
}
|
||||
fout.write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
|
||||
if (fout.bad()) {
|
||||
MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
|
||||
free(buffer);
|
||||
fout.close();
|
||||
return RET_ERROR;
|
||||
}
|
||||
free(buffer);
|
||||
} else {
|
||||
if (!model_proto->SerializeToOstream(&fout)) {
|
||||
MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
|
||||
fout.close();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
fout.close();
|
||||
ChangeFileMode(realpath.value(), S_IRUSR);
|
||||
return RET_OK;
|
||||
|
|
|
@ -33,6 +33,7 @@ namespace mindspore::lite {
|
|||
class MindIRSerializer {
|
||||
public:
|
||||
MindIRSerializer() {}
|
||||
explicit MindIRSerializer(bool is_export_model) { is_export_model_ = is_export_model; }
|
||||
virtual ~MindIRSerializer() {
|
||||
if (data_fs_ != nullptr) {
|
||||
data_fs_->close();
|
||||
|
@ -47,9 +48,10 @@ class MindIRSerializer {
|
|||
private:
|
||||
int ParserPath(const std::string &output_path);
|
||||
int IfSaveTogether(bool *save_together);
|
||||
int SaveMindIRTogether();
|
||||
int SplitSave();
|
||||
int SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file);
|
||||
int SaveMindIRTogether(const std::shared_ptr<ConverterPara> ¶m);
|
||||
int SplitSave(const std::shared_ptr<ConverterPara> ¶m);
|
||||
int SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file,
|
||||
const std::shared_ptr<ConverterPara> ¶m);
|
||||
int ConvertQuantHolderToQuantizationParam(const FuncGraphPtr &func_graph);
|
||||
std::shared_ptr<mindspore::QuantizationParam> ConvertQuantParamTToQuantizationParam(
|
||||
std::vector<schema::QuantParamT> quant_param);
|
||||
|
@ -77,6 +79,7 @@ class MindIRSerializer {
|
|||
std::unordered_map<tensor::TensorPtr, mind_ir::TensorProto *> para_proto_dict_{};
|
||||
std::fstream *data_fs_ = nullptr;
|
||||
std::shared_ptr<system::FileSystem> fs_{};
|
||||
bool is_export_model_ = true;
|
||||
};
|
||||
// export func_graph
|
||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, bool need_buff,
|
||||
|
|
Loading…
Reference in New Issue