modify lite exportMindir para
This commit is contained in:
parent
058f6a2577
commit
44c9dac785
|
@ -20,7 +20,7 @@ mindspore_lite.Converter
|
|||
- **input_format** (Format,可选) - 指定导出模型的输入格式。仅对四维输入有效。选项:Format.NHWC | Format.NCHW。默认值:Format.NHWC。
|
||||
- **input_data_type** (DataType,可选) - 输入张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
|
||||
- **output_data_type** (DataType,可选) - 输出张量的数据类型,默认与模型中定义的类型相同。默认值:DataType.FLOAT32。
|
||||
- **export_mindir** (bool,可选) - 是否导出MindIR pb。默认值:False。
|
||||
- **export_mindir** (ModelType,可选) - 导出模型文件的类型。默认值:ModelType.MINDIR_LITE。
|
||||
- **decrypt_key** (str,可选) - 用于解密文件的密钥,以十六进制字符表示。仅当fmk_type为FmkType.MINDIR时有效。默认值:""。
|
||||
- **decrypt_mode** (str,可选) - MindIR文件的解密方法。仅在设置decrypt_key时有效。选项:"AES-GCM" | "AES-CBC"。默认值:"AES-GCM"。
|
||||
- **enable_encryption** (bool,可选) - 是否导出加密模型。默认值:False。
|
||||
|
@ -45,7 +45,7 @@ mindspore_lite.Converter
|
|||
- **TypeError** - `input_format` 不是Format类型。
|
||||
- **TypeError** - `input_data_type` 不是DataType类型。
|
||||
- **TypeError** - `output_data_type` 不是DataType类型。
|
||||
- **TypeError** - `export_mindir` 不是bool类型。
|
||||
- **TypeError** - `export_mindir` 不是ModelType类型。
|
||||
- **TypeError** - `decrypt_key` 不是str类型。
|
||||
- **TypeError** - `decrypt_mode` 不是str类型。
|
||||
- **TypeError** - `enable_encryption` 不是bool类型。
|
||||
|
|
|
@ -53,8 +53,8 @@ class MS_API Converter {
|
|||
void SetOutputDataType(DataType data_type);
|
||||
DataType GetOutputDataType();
|
||||
|
||||
void SetExportMindIR(bool export_mindir);
|
||||
bool GetExportMindIR() const;
|
||||
void SetExportMindIR(ModelType export_mindir);
|
||||
ModelType GetExportMindIR() const;
|
||||
|
||||
void SetDecryptKey(const std::string &key);
|
||||
std::string GetDecryptKey() const;
|
||||
|
|
|
@ -22,6 +22,7 @@ from enum import Enum
|
|||
from ._checkparam import check_isinstance, check_input_shape, check_config_info
|
||||
from .lib import _c_lite_wrapper
|
||||
from .tensor import DataType, Format, data_type_py_cxx_map, data_type_cxx_py_map, format_py_cxx_map, format_cxx_py_map
|
||||
from .model import ModelType
|
||||
|
||||
__all__ = ['FmkType', 'Converter']
|
||||
|
||||
|
@ -69,7 +70,7 @@ class Converter:
|
|||
defined in model. Default: DataType.FLOAT32.
|
||||
output_data_type (DataType, optional): Data type of output tensors.
|
||||
The default type is same with the type defined in model. Default: DataType.FLOAT32.
|
||||
export_mindir (bool, optional): Whether to export MindIR pb. Default: False.
|
||||
export_mindir (ModelType, optional): Which model type need to be export. Default: ModelType.MINDIR_LITE.
|
||||
decrypt_key (str, optional): The key used to decrypt the file, expressed in hexadecimal characters.
|
||||
Only valid when fmk_type is FmkType.MINDIR. Default: "".
|
||||
decrypt_mode (str, optional): Decryption method for the MindIR file. Only valid when dec_key is set.
|
||||
|
@ -96,7 +97,7 @@ class Converter:
|
|||
TypeError: `input_format` is not a Format.
|
||||
TypeError: `input_data_type` is not a DataType.
|
||||
TypeError: `output_data_type` is not a DataType.
|
||||
TypeError: `export_mindir` is not a bool.
|
||||
TypeError: `export_mindir` is not a ModelType.
|
||||
TypeError: `decrypt_key` is not a str.
|
||||
TypeError: `decrypt_mode` is not a str.
|
||||
TypeError: `enable_encryption` is not a bool.
|
||||
|
@ -121,7 +122,7 @@ class Converter:
|
|||
input_format: Format.NHWC,
|
||||
input_data_type: DataType.FLOAT32,
|
||||
output_data_type: DataType.FLOAT32,
|
||||
export_mindir: False,
|
||||
export_mindir: MINDIR_LITE,
|
||||
decrypt_key: ,
|
||||
decrypt_mode: ,
|
||||
enable_encryption: False,
|
||||
|
@ -133,8 +134,9 @@ class Converter:
|
|||
|
||||
def __init__(self, fmk_type, model_file, output_file, weight_file="", config_file="", weight_fp16=False,
|
||||
input_shape=None, input_format=Format.NHWC, input_data_type=DataType.FLOAT32,
|
||||
output_data_type=DataType.FLOAT32, export_mindir=False, decrypt_key="", decrypt_mode="AES-GCM",
|
||||
enable_encryption=False, encrypt_key="", infer=False, train_model=False, no_fusion=False):
|
||||
output_data_type=DataType.FLOAT32, export_mindir=ModelType.MINDIR_LITE, decrypt_key="",
|
||||
decrypt_mode="AES-GCM", enable_encryption=False, encrypt_key="", infer=False, train_model=False,
|
||||
no_fusion=False):
|
||||
check_isinstance("fmk_type", fmk_type, FmkType)
|
||||
check_isinstance("model_file", model_file, str)
|
||||
check_isinstance("output_file", output_file, str)
|
||||
|
@ -145,7 +147,7 @@ class Converter:
|
|||
check_isinstance("input_format", input_format, Format)
|
||||
check_isinstance("input_data_type", input_data_type, DataType)
|
||||
check_isinstance("output_data_type", output_data_type, DataType)
|
||||
check_isinstance("export_mindir", export_mindir, bool)
|
||||
check_isinstance("export_mindir", export_mindir, ModelType)
|
||||
check_isinstance("decrypt_key", decrypt_key, str)
|
||||
check_isinstance("decrypt_mode", decrypt_mode, str)
|
||||
check_isinstance("enable_encryption", enable_encryption, bool)
|
||||
|
@ -189,7 +191,7 @@ class Converter:
|
|||
self._converter.set_input_data_type(data_type_py_cxx_map.get(input_data_type))
|
||||
if output_data_type != DataType.FLOAT32:
|
||||
self._converter.set_output_data_type(data_type_py_cxx_map.get(output_data_type))
|
||||
if export_mindir:
|
||||
if export_mindir != ModelType.MINDIR_LITE:
|
||||
self._converter.set_export_mindir(export_mindir)
|
||||
if decrypt_key != "":
|
||||
self._converter.set_decrypt_key(decrypt_key)
|
||||
|
|
|
@ -70,15 +70,19 @@ char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size
|
|||
}
|
||||
|
||||
ConverterImpl cvt;
|
||||
auto meta_graph = cvt.Convert(param, model_buf, buf_size);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert failed.";
|
||||
schema::MetaGraphT *meta_graph = nullptr;
|
||||
auto status = cvt.Convert(param, &meta_graph, model_buf, buf_size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert model failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "meta graph is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *lite_buf = nullptr;
|
||||
meta_graph->version = Version();
|
||||
auto status = TransferMetaGraph(*meta_graph, &lite_buf, size);
|
||||
status = TransferMetaGraph(*meta_graph, &lite_buf, size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transfer model failed.";
|
||||
delete meta_graph;
|
||||
|
@ -103,9 +107,14 @@ char *RuntimeConvert(const std::string &file_path, size_t *size) {
|
|||
param->train_model = false;
|
||||
|
||||
ConverterImpl cvt;
|
||||
auto meta_graph = cvt.Convert(param);
|
||||
MS_LOG(ERROR) << "Convert failed.";
|
||||
schema::MetaGraphT *meta_graph = nullptr;
|
||||
auto status = cvt.Convert(param, &meta_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert model failed";
|
||||
return nullptr;
|
||||
}
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "meta graph is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -178,13 +178,13 @@ def test_converter_26():
|
|||
with pytest.raises(TypeError) as raise_info:
|
||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||
output_file="mobilenetv2.tflite", export_mindir=1)
|
||||
assert "export_mindir must be bool" in str(raise_info.value)
|
||||
assert "export_mindir must be ModelType" in str(raise_info.value)
|
||||
|
||||
|
||||
def test_converter_27():
|
||||
converter = mslite.Converter(fmk_type=mslite.FmkType.TFLITE, model_file="mobilenetv2.tflite",
|
||||
output_file="mobilenetv2.tflite", export_mindir=True)
|
||||
assert "export_mindir: True" in str(converter)
|
||||
output_file="mobilenetv2.tflite", export_mindir=mslite.ModelType.MINDIR_LITE)
|
||||
assert "export_mindir: ModelType.kMindIR_Lite" in str(converter)
|
||||
|
||||
|
||||
def test_converter_28():
|
||||
|
|
|
@ -119,33 +119,40 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara>
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, const void *buf,
|
||||
const size_t &size) {
|
||||
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph,
|
||||
const void *buf, const size_t &size) {
|
||||
if (param == nullptr || buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Input param is nullptr";
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto graph = BuildFuncGraph(param, buf, size);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed.");
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed.");
|
||||
// funcgraph_transform
|
||||
graph = funcgraph_transform_->Transform(graph, param);
|
||||
MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr.");
|
||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
|
||||
// export protobuf
|
||||
auto status = MindIRSerialize(param, graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
|
||||
if (param->export_mindir == kMindIR) {
|
||||
auto status = MindIRSerialize(param, graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Export to mindir proto failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Export to mindir success";
|
||||
return RET_OK;
|
||||
}
|
||||
}
|
||||
return TransferFuncGraph(param, graph);
|
||||
*meta_graph = TransferFuncGraph(param, graph);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m) {
|
||||
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph) {
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "Input param is nullptr";
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
param->aclModelOptionCfgParam.om_file_path = param->output_file;
|
||||
|
@ -153,7 +160,7 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
|
|||
auto ret = InitConfigParam(param);
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init config file failed." << std::endl;
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,11 +169,11 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
|
|||
if (!param->plugins_path.empty()) {
|
||||
for (auto &path : param->plugins_path) {
|
||||
auto dl_loader = std::make_shared<DynamicLibraryLoader>();
|
||||
MS_CHECK_TRUE_RET(dl_loader != nullptr, nullptr);
|
||||
MS_CHECK_TRUE_RET(dl_loader != nullptr, RET_ERROR);
|
||||
auto status = dl_loader->Open(path);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "open dynamic library failed. " << path;
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
dl_loaders.emplace_back(dl_loader);
|
||||
}
|
||||
|
@ -175,24 +182,30 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
|
|||
auto graph = BuildFuncGraph(param);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed");
|
||||
// funcgraph transform
|
||||
graph = funcgraph_transform_->Transform(graph, param);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
||||
return nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// export protobuf
|
||||
auto status = MindIRSerialize(param, graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
|
||||
if (param->export_mindir == kMindIR) {
|
||||
auto status = MindIRSerialize(param, graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Export to mindir failed";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Export to mindir success";
|
||||
return RET_OK;
|
||||
}
|
||||
}
|
||||
|
||||
return TransferFuncGraph(param, graph);
|
||||
*meta_graph = TransferFuncGraph(param, graph);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> ¶m,
|
||||
|
@ -531,9 +544,18 @@ int RunConverter(const std::shared_ptr<ConverterPara> ¶m, void **model_data,
|
|||
|
||||
param->aclModelOptionCfgParam.offline = !not_save;
|
||||
ConverterImpl converter_impl;
|
||||
auto meta_graph = converter_impl.Convert(param);
|
||||
schema::MetaGraphT *meta_graph = nullptr;
|
||||
int status = converter_impl.Convert(param, &meta_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert model failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param->export_mindir == kMindIR) {
|
||||
MS_LOG(DEBUG) << "No need to export mindir to fb";
|
||||
return RET_OK;
|
||||
}
|
||||
NotSupportOp::GetInstance()->PrintOps();
|
||||
int status = ReturnCode::GetSingleReturnCode()->status_code();
|
||||
status = ReturnCode::GetSingleReturnCode()->status_code();
|
||||
std::ostringstream oss;
|
||||
if (meta_graph == nullptr) {
|
||||
oss.clear();
|
||||
|
|
|
@ -52,8 +52,9 @@ class ConverterImpl {
|
|||
delete model_parser_;
|
||||
this->model_parser_ = nullptr;
|
||||
}
|
||||
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> ¶m);
|
||||
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> ¶m, const void *buf, const size_t &size);
|
||||
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph);
|
||||
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph, const void *buf,
|
||||
const size_t &size);
|
||||
|
||||
private:
|
||||
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> ¶m);
|
||||
|
|
|
@ -32,7 +32,7 @@ Flags::Flags() {
|
|||
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
|
||||
AddFlag(&Flags::modelFile, "modelFile",
|
||||
"Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
|
||||
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
|
||||
AddFlag(&Flags::outputFile, "outputFile", "Output model file path.", "");
|
||||
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
|
||||
"");
|
||||
AddFlag(&Flags::inputDataTypeStr, "inputDataType",
|
||||
|
@ -82,10 +82,7 @@ Flags::Flags() {
|
|||
"Whether to do pre-inference after convert. "
|
||||
"true | false",
|
||||
"false");
|
||||
AddFlag(&Flags::exportMindIR, "exportMindIR",
|
||||
"Whether to export MindIR pb. "
|
||||
"true | false",
|
||||
"false");
|
||||
AddFlag(&Flags::exportMindIR, "exportMindIR", "MINDIR | MINDIR_LITE", "MINDIR_LITE");
|
||||
AddFlag(&Flags::noFusionStr, "NoFusion", "Avoid fusion optimization true|false", "false");
|
||||
}
|
||||
|
||||
|
@ -279,10 +276,10 @@ int Flags::InitNoFusion() {
|
|||
}
|
||||
|
||||
int Flags::InitExportMindIR() {
|
||||
if (this->exportMindIR == "true") {
|
||||
this->export_mindir = true;
|
||||
} else if (this->exportMindIR == "false") {
|
||||
this->export_mindir = false;
|
||||
if (this->exportMindIR == "MINDIR") {
|
||||
this->export_mindir = kMindIR;
|
||||
} else if (this->exportMindIR == "MINDIR_LITE") {
|
||||
this->export_mindir = kMindIR_Lite;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: exportMindIR must be true|false " << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
|
|
|
@ -71,7 +71,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
std::string inferStr;
|
||||
bool infer = false;
|
||||
std::string exportMindIR;
|
||||
bool export_mindir = false;
|
||||
ModelType export_mindir = kMindIR_Lite;
|
||||
#ifdef ENABLE_OPENSSL
|
||||
std::string encryptionStr = "true";
|
||||
bool encryption = true;
|
||||
|
|
|
@ -151,17 +151,17 @@ DataType Converter::GetOutputDataType() {
|
|||
}
|
||||
}
|
||||
|
||||
void Converter::SetExportMindIR(bool export_mindir) {
|
||||
void Converter::SetExportMindIR(ModelType export_mindir) {
|
||||
if (data_ != nullptr) {
|
||||
data_->export_mindir = export_mindir;
|
||||
}
|
||||
}
|
||||
|
||||
bool Converter::GetExportMindIR() const {
|
||||
ModelType Converter::GetExportMindIR() const {
|
||||
if (data_ != nullptr) {
|
||||
return data_->export_mindir;
|
||||
} else {
|
||||
return false;
|
||||
return kMindIR_Lite;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ struct ConverterPara {
|
|||
Format input_format = NHWC;
|
||||
DataType input_data_type = DataType::kNumberTypeFloat32;
|
||||
DataType output_data_type = DataType::kNumberTypeFloat32;
|
||||
bool export_mindir = false;
|
||||
ModelType export_mindir = kMindIR_Lite;
|
||||
std::string decrypt_key;
|
||||
std::string decrypt_mode;
|
||||
std::string encrypt_key;
|
||||
|
|
|
@ -403,9 +403,6 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
|
|||
}
|
||||
|
||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph) {
|
||||
if (!param->export_mindir) {
|
||||
return RET_OK;
|
||||
}
|
||||
mindspore::lite::MindIRSerializer serializer;
|
||||
return serializer.Save(param, func_graph);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue