modify lite exportMindir parameter

This commit is contained in:
zhou_chao1993 2022-07-18 16:58:10 +08:00
parent 058f6a2577
commit 44c9dac785
12 changed files with 93 additions and 65 deletions

View File

@ -20,7 +20,7 @@ mindspore_lite.Converter
- **input_format** (Format,可选) - 指定导出模型的输入格式。仅对四维输入有效。选项:Format.NHWC | Format.NCHW。默认值:Format.NHWC。
- **input_data_type** (DataType,可选) - 输入张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
- **output_data_type** (DataType,可选) - 输出张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
- **export_mindir** (bool,可选) - 是否导出MindIR pb。默认值:False。
- **export_mindir** (ModelType,可选) - 导出模型文件的类型。默认值:ModelType.MINDIR_LITE。
- **decrypt_key** (str,可选) - 用于解密文件的密钥,以十六进制字符表示。仅当fmk_type为FmkType.MINDIR时有效。默认值:""。
- **decrypt_mode** (str,可选) - MindIR文件的解密方法。仅在设置decrypt_key时有效。选项:"AES-GCM" | "AES-CBC"。默认值:"AES-GCM"。
- **enable_encryption** (bool,可选) - 是否导出加密模型。默认值:False。
@ -45,7 +45,7 @@ mindspore_lite.Converter
- **TypeError** - `input_format` 不是Format类型。
- **TypeError** - `input_data_type` 不是DataType类型。
- **TypeError** - `output_data_type` 不是DataType类型。
- **TypeError** - `export_mindir` 不是bool类型。
- **TypeError** - `export_mindir` 不是ModelType类型。
- **TypeError** - `decrypt_key` 不是str类型。
- **TypeError** - `decrypt_mode` 不是str类型。
- **TypeError** - `enable_encryption` 不是bool类型。

View File

@ -53,8 +53,8 @@ class MS_API Converter {
void SetOutputDataType(DataType data_type);
DataType GetOutputDataType();
void SetExportMindIR(bool export_mindir);
bool GetExportMindIR() const;
void SetExportMindIR(ModelType export_mindir);
ModelType GetExportMindIR() const;
void SetDecryptKey(const std::string &key);
std::string GetDecryptKey() const;

View File

@ -22,6 +22,7 @@ from enum import Enum
from ._checkparam import check_isinstance, check_input_shape, check_config_info
from .lib import _c_lite_wrapper
from .tensor import DataType, Format, data_type_py_cxx_map, data_type_cxx_py_map, format_py_cxx_map, format_cxx_py_map
from .model import ModelType
__all__ = ['FmkType', 'Converter']
@ -69,7 +70,7 @@ class Converter:
defined in model. Default: DataType.FLOAT32.
output_data_type (DataType, optional): Data type of output tensors.
The default type is same with the type defined in model. Default: DataType.FLOAT32.
export_mindir (bool, optional): Whether to export MindIR pb. Default: False.
export_mindir (ModelType, optional): Which model type needs to be exported. Default: ModelType.MINDIR_LITE.
decrypt_key (str, optional): The key used to decrypt the file, expressed in hexadecimal characters.
Only valid when fmk_type is FmkType.MINDIR. Default: "".
decrypt_mode (str, optional): Decryption method for the MindIR file. Only valid when dec_key is set.
@ -96,7 +97,7 @@ class Converter:
TypeError: `input_format` is not a Format.
TypeError: `input_data_type` is not a DataType.
TypeError: `output_data_type` is not a DataType.
TypeError: `export_mindir` is not a bool.
TypeError: `export_mindir` is not a ModelType.
TypeError: `decrypt_key` is not a str.
TypeError: `decrypt_mode` is not a str.
TypeError: `enable_encryption` is not a bool.
@ -121,7 +122,7 @@ class Converter:
input_format: Format.NHWC,
input_data_type: DataType.FLOAT32,
output_data_type: DataType.FLOAT32,
export_mindir: False,
export_mindir: MINDIR_LITE,
decrypt_key: ,
decrypt_mode: ,
enable_encryption: False,
@ -133,8 +134,9 @@ class Converter:
def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False,
input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM",
enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False):
output_data_type=DataType.FLOAT32, export_mindir=ModelType.MINDIR_LITE, decrypt_key="",
decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False,
no_fusion=False):
check_isinstance("fmk_type", fmk_type, FmkType)
check_isinstance("model_file", model_file, str)
check_isinstance("output_file", output_file, str)
@ -145,7 +147,7 @@ class Converter:
check_isinstance("input_format", input_format, Format)
check_isinstance("input_data_type", input_data_type, DataType)
check_isinstance("output_data_type", output_data_type, DataType)
check_isinstance("export_mindir", export_mindir, bool)
check_isinstance("export_mindir", export_mindir, ModelType)
check_isinstance("decrypt_key", decrypt_key, str)
check_isinstance("decrypt_mode", decrypt_mode, str)
check_isinstance("enable_encryption", enable_encryption, bool)
@ -189,7 +191,7 @@ class Converter:
self._converter.set_input_data_type(data_type_py_cxx_map.get(input_data_type))
if output_data_type != DataType.FLOAT32:
self._converter.set_output_data_type(data_type_py_cxx_map.get(output_data_type))
if export_mindir:
if export_mindir != ModelType.MINDIR_LITE:
self._converter.set_export_mindir(export_mindir)
if decrypt_key != "":
self._converter.set_decrypt_key(decrypt_key)

View File

@ -70,15 +70,19 @@ char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size
}
ConverterImpl cvt;
auto meta_graph = cvt.Convert(param, model_buf, buf_size);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Convert failed.";
schema::MetaGraphT *meta_graph = nullptr;
auto status = cvt.Convert(param, &meta_graph, model_buf, buf_size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert model failed.";
return nullptr;
}
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "meta graph is nullptr.";
return nullptr;
}
void *lite_buf = nullptr;
meta_graph->version = Version();
auto status = TransferMetaGraph(*meta_graph, &lite_buf, size);
status = TransferMetaGraph(*meta_graph, &lite_buf, size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transfer model failed.";
delete meta_graph;
@ -103,9 +107,14 @@ char *RuntimeConvert(const std::string &file_path, size_t *size) {
param->train_model = false;
ConverterImpl cvt;
auto meta_graph = cvt.Convert(param);
MS_LOG(ERROR) << "Convert failed.";
schema::MetaGraphT *meta_graph = nullptr;
auto status = cvt.Convert(param, &meta_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert model failed";
return nullptr;
}
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "meta graph is nullptr.";
return nullptr;
}

View File

@ -178,13 +178,13 @@ def test_converter_26():
with pytest.raises(TypeError) as raise_info:
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite", export_mindir=1)
assert "export_mindir must be bool" in str(raise_info.value)
assert "export_mindir must be ModelType" in str(raise_info.value)
def test_converter_27():
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
output_file="mobilenetv2.tflite", export_mindir=True)
assert "export_mindir: True" in str(converter)
output_file="mobilenetv2.tflite", export_mindir=mslite.ModelType.MINDIR_LITE)
assert "export_mindir: ModelType.kMindIR_Lite" in str(converter)
def test_converter_28():

View File

@ -119,33 +119,40 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara>
return func_graph;
}
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, const void *buf,
const size_t &size) {
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph,
const void *buf, const size_t &size) {
if (param == nullptr || buf == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr";
return nullptr;
return RET_ERROR;
}
auto graph = BuildFuncGraph(param, buf, size);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed.");
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed.");
// funcgraph_transform
graph = funcgraph_transform_->Transform(graph, param);
MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr.");
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
// export protobuf
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
if (param->export_mindir == kMindIR) {
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Export to mindir proto failed";
return RET_ERROR;
} else {
MS_LOG(DEBUG) << "Export to mindir success";
return RET_OK;
}
}
return TransferFuncGraph(param, graph);
*meta_graph = TransferFuncGraph(param, graph);
return RET_OK;
}
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param) {
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph) {
if (param == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr";
return nullptr;
return RET_ERROR;
}
param->aclModelOptionCfgParam.om_file_path = param->output_file;
@ -153,7 +160,7 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
auto ret = InitConfigParam(param);
if (ret != RET_OK) {
std::cerr << "Init config file failed." << std::endl;
return nullptr;
return RET_ERROR;
}
}
@ -162,11 +169,11 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
if (!param->plugins_path.empty()) {
for (auto &path : param->plugins_path) {
auto dl_loader = std::make_shared<DynamicLibraryLoader>();
MS_CHECK_TRUE_RET(dl_loader != nullptr, nullptr);
MS_CHECK_TRUE_RET(dl_loader != nullptr, RET_ERROR);
auto status = dl_loader->Open(path);
if (status != RET_OK) {
MS_LOG(ERROR) << "open dynamic library failed. " << path;
return nullptr;
return RET_ERROR;
}
dl_loaders.emplace_back(dl_loader);
}
@ -175,24 +182,30 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
auto graph = BuildFuncGraph(param);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed");
// funcgraph transform
graph = funcgraph_transform_->Transform(graph, param);
if (graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
return RET_ERROR;
}
// export protobuf
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
if (param->export_mindir == kMindIR) {
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Export to mindir failed";
return RET_ERROR;
} else {
MS_LOG(DEBUG) << "Export to mindir success";
return RET_OK;
}
}
return TransferFuncGraph(param, graph);
*meta_graph = TransferFuncGraph(param, graph);
return RET_OK;
}
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param,
@ -531,9 +544,18 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
param->aclModelOptionCfgParam.offline = !not_save;
ConverterImpl converter_impl;
auto meta_graph = converter_impl.Convert(param);
schema::MetaGraphT *meta_graph = nullptr;
int status = converter_impl.Convert(param, &meta_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert model failed";
return RET_ERROR;
}
if (param->export_mindir == kMindIR) {
MS_LOG(DEBUG) << "No need to export mindir to fb";
return RET_OK;
}
NotSupportOp::GetInstance()->PrintOps();
int status = ReturnCode::GetSingleReturnCode()->status_code();
status = ReturnCode::GetSingleReturnCode()->status_code();
std::ostringstream oss;
if (meta_graph == nullptr) {
oss.clear();

View File

@ -52,8 +52,9 @@ class ConverterImpl {
delete model_parser_;
this->model_parser_ = nullptr;
}
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> &param);
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> &param, const void *buf, const size_t &size);
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph);
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, const void *buf,
const size_t &size);
private:
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);

View File

@ -32,7 +32,7 @@ Flags::Flags() {
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
AddFlag(&Flags::modelFile, "modelFile",
"Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::outputFile, "outputFile", "Output model file path.", "");
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
"");
AddFlag(&Flags::inputDataTypeStr, "inputDataType",
@ -82,10 +82,7 @@ Flags::Flags() {
"Whether to do pre-inference after convert. "
"true | false",
"false");
AddFlag(&Flags::exportMindIR, "exportMindIR",
"Whether to export MindIR pb. "
"true | false",
"false");
AddFlag(&Flags::exportMindIR, "exportMindIR", "MINDIR | MINDIR_LITE", "MINDIR_LITE");
AddFlag(&Flags::noFusionStr, "NoFusion", "Avoid fusion optimization true|false", "false");
}
@ -279,10 +276,10 @@ int Flags::InitNoFusion() {
}
int Flags::InitExportMindIR() {
if (this->exportMindIR == "true") {
this->export_mindir = true;
} else if (this->exportMindIR == "false") {
this->export_mindir = false;
if (this->exportMindIR == "MINDIR") {
this->export_mindir = kMindIR;
} else if (this->exportMindIR == "MINDIR_LITE") {
this->export_mindir = kMindIR_Lite;
} else {
std::cerr << "INPUT ILLEGAL: exportMindIR must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;

View File

@ -71,7 +71,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string inferStr;
bool infer = false;
std::string exportMindIR;
bool export_mindir = false;
ModelType export_mindir = kMindIR_Lite;
#ifdef ENABLE_OPENSSL
std::string encryptionStr = "true";
bool encryption = true;

View File

@ -151,17 +151,17 @@ DataType Converter::GetOutputDataType() {
}
}
void Converter::SetExportMindIR(bool export_mindir) {
void Converter::SetExportMindIR(ModelType export_mindir) {
if (data_ != nullptr) {
data_->export_mindir = export_mindir;
}
}
bool Converter::GetExportMindIR() const {
ModelType Converter::GetExportMindIR() const {
if (data_ != nullptr) {
return data_->export_mindir;
} else {
return false;
return kMindIR_Lite;
}
}

View File

@ -48,7 +48,7 @@ struct ConverterPara {
Format input_format = NHWC;
DataType input_data_type = DataType::kNumberTypeFloat32;
DataType output_data_type = DataType::kNumberTypeFloat32;
bool export_mindir = false;
ModelType export_mindir = kMindIR_Lite;
std::string decrypt_key;
std::string decrypt_mode;
std::string encrypt_key;

View File

@ -403,9 +403,6 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
}
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph) {
if (!param->export_mindir) {
return RET_OK;
}
mindspore::lite::MindIRSerializer serializer;
return serializer.Save(param, func_graph);
}