[MSLITE] runtime convert support input buffer

This commit is contained in:
ling 2021-11-11 12:53:33 +08:00
parent a304352179
commit 022f10e639
15 changed files with 187 additions and 126 deletions

View File

@ -39,8 +39,7 @@ enum ModelType : uint32_t {
kAIR = 1, kAIR = 1,
kOM = 2, kOM = 2,
kONNX = 3, kONNX = 3,
kFlatBuffer = 4, kMindIR_Opt = 4,
kMindIR_Opt = 5,
// insert new data type here // insert new data type here
kUnknownType = 0xFFFFFFFF kUnknownType = 0xFFFFFFFF
}; };

View File

@ -133,6 +133,23 @@ std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<st
return funcgraph_vec; return funcgraph_vec;
} }
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
/* mindir -> func_graph
* only support lite */
mind_ir::ModelProto model;
auto ret = model.ParseFromArray(buffer, size);
if (!ret) {
MS_LOG(ERROR) << "ParseFromArray failed.";
return nullptr;
}
MSANFModelParser model_parser;
model_parser.SetLite();
FuncGraphPtr func_graph = model_parser.Parse(model);
return func_graph;
}
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) { std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
if (file_name.length() > PATH_MAX) { if (file_name.length() > PATH_MAX) {
MS_LOG(ERROR) << "The length of the file name exceeds the limit."; MS_LOG(ERROR) << "The length of the file name exceeds the limit.";

View File

@ -34,6 +34,7 @@ class MindIRLoader {
bool get_need_renormalize() const { return need_renormalize_; } bool get_need_renormalize() const { return need_renormalize_; }
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; } void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
std::shared_ptr<FuncGraph> LoadMindIR(const void *buffer, const size_t &size);
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name); std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names); std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);

View File

@ -228,9 +228,9 @@ int NetRunner::Main() {
if (epochs_ > 0) { if (epochs_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, false); mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, false);
trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms"; trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms";
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, true); mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, true);
} }
return 0; return 0;
} }

View File

@ -136,7 +136,7 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
MS_LOG(ERROR) << "Model is not TrainModel."; MS_LOG(ERROR) << "Model is not TrainModel.";
return kLiteError; return kLiteError;
} }
if (model_type != kFlatBuffer) { if (model_type != kMindIR && model_type != kMindIR_Opt) {
MS_LOG(ERROR) << "Unsupported Export Format " << model_type; MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
return kLiteParamInvalid; return kLiteParamInvalid;
} }

View File

@ -1318,67 +1318,66 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
return session; return session;
} }
const char *lite::LiteSession::LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size) { mindspore::ModelType lite::LiteSession::LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
size_t *size, mindspore::ModelType model_type) {
if (model_type == mindspore::ModelType::kMindIR_Opt) { if (model_type == mindspore::ModelType::kMindIR_Opt) {
*size = buf_size;
*lite_buf = const_cast<char *>(model_buf);
return mindspore::ModelType::kMindIR_Opt;
}
if (model_type != mindspore::ModelType::kMindIR) {
return mindspore::ModelType::kUnknownType;
}
flatbuffers::Verifier verify((const uint8_t *)model_buf, buf_size);
auto version_verify = lite::LiteModel::VersionVerify(&verify);
if (version_verify != SCHEMA_INVALID) {
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
*size = buf_size;
*lite_buf = const_cast<char *>(model_buf);
return mindspore::ModelType::kMindIR_Opt;
}
#ifdef RUNTIME_CONVERT
*lite_buf = RuntimeConvert(model_buf, buf_size, size);
#else
MS_LOG(ERROR) << "Please enable runtime convert.";
#endif
return mindspore::ModelType::kMindIR;
}
const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
size_t buf_size;
auto model_buf = lite::ReadFile(file.c_str(), &buf_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model path is invalid";
return model_buf; return model_buf;
} }
if (model_type == mindspore::ModelType::kMindIR) { char *lite_buf = nullptr;
flatbuffers::Verifier verify((const uint8_t *)model_buf, size); auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
auto version_verify = lite::LiteModel::VersionVerify(&verify); if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
if (version_verify != SCHEMA_INVALID) {
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
return model_buf;
}
#ifdef RUNTIME_CONVERT
return RuntimeConvert(model_buf, size);
#endif
MS_LOG(ERROR) << "Please enable runtime convert.";
return nullptr; return nullptr;
} }
if (buf_model_type == mindspore::ModelType::kMindIR) {
MS_LOG(ERROR) << "Invalid Model Type"; free(model_buf);
return nullptr; model_buf = nullptr;
}
return lite_buf;
} }
char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) { int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type,
if (model_type == mindspore::ModelType::kMindIR_Opt) { const size_t &buf_size, session::LiteSession *session) {
return lite::ReadFile(file.c_str(), size); size_t lite_buf_size = 0;
} char *lite_buf = nullptr;
auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, &lite_buf_size, model_type);
if (model_type == mindspore::ModelType::kMindIR) { if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
size_t tmp_size;
auto tmp_buf = lite::ReadFile(file.c_str(), &tmp_size);
flatbuffers::Verifier verify((const uint8_t *)tmp_buf, tmp_size);
auto version_verify = lite::LiteModel::VersionVerify(&verify);
if (version_verify != SCHEMA_INVALID) {
MS_LOG(DEBUG) << "The kMindIR type model path is valid mslite model";
return tmp_buf;
}
free(tmp_buf);
tmp_buf = nullptr;
#ifdef RUNTIME_CONVERT
return RuntimeConvert(file, size);
#endif
MS_LOG(ERROR) << "Please enable runtime convert.";
return nullptr;
}
MS_LOG(ERROR) << "Invalid Model Type";
return nullptr;
}
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
session::LiteSession *session) {
auto lite_buf = LoadModelByBuff(model_buf, model_type, size);
if (lite_buf == nullptr) {
MS_LOG(ERROR) << "Invalid model_buf"; MS_LOG(ERROR) << "Invalid model_buf";
return RET_ERROR; return RET_ERROR;
} }
auto *model = lite::ImportFromBuffer(model_buf, size, true); auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
if (model == nullptr) { if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed"; MS_LOG(ERROR) << "Import model failed";
return RET_ERROR; return RET_ERROR;

View File

@ -53,7 +53,7 @@ class LiteSession : public session::LiteSession {
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context); static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size, static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
session::LiteSession *session); session::LiteSession *session);
static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type, static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
session::LiteSession *session); session::LiteSession *session);
@ -128,8 +128,9 @@ class LiteSession : public session::LiteSession {
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels); static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
static const char *LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size); static mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
static char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size); size_t *size, mindspore::ModelType model_type);
static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
private: private:
int PreCheck(Model *model); int PreCheck(Model *model);

View File

@ -16,78 +16,63 @@
#ifdef RUNTIME_CONVERT #ifdef RUNTIME_CONVERT
#include "src/runtime/runtime_convert.h" #include "src/runtime/runtime_convert.h"
#include "tools/common/graph_util.h" #include "include/version.h"
#include "tools/converter/converter.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/converter/import/mindspore_importer.h"
namespace mindspore::lite { namespace mindspore::lite {
char *RuntimeConvert(const char *model_buf, size_t) { char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size) {
if (model_buf == nullptr) { if (model_buf == nullptr) {
MS_LOG(ERROR) << "Invalid input model buffer."; MS_LOG(ERROR) << "Invalid input model buffer.";
return nullptr; return nullptr;
} }
MS_LOG(ERROR) << "Invalid Now. Use model path for model build.";
return nullptr; auto flag = std::make_unique<converter::Flags>();
flag->fmk = converter::kFmkTypeMs;
flag->inputDataType = kTypeUnknown;
flag->outputDataType = kTypeUnknown;
flag->saveFP16 = false;
flag->trainModel = false;
Converter cvt;
auto meta_graph = cvt.Convert(flag, model_buf, buf_size);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Convert failed.";
return nullptr;
}
void *lite_buf = nullptr;
meta_graph->version = Version();
auto status = TransferMetaGraph(*meta_graph, &lite_buf, size);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transfer model failed.";
delete meta_graph;
return nullptr;
}
delete meta_graph;
return reinterpret_cast<char *>(lite_buf);
} }
char *RuntimeConvert(const std::string &file_path, size_t *size) { char *RuntimeConvert(const std::string &file_path, size_t *size) {
void *model_buf = nullptr; auto flag = std::make_unique<converter::Flags>();
converter::Flags flag; flag->fmk = converter::kFmkTypeMs;
flag.fmk = converter::kFmkTypeMs; flag->modelFile = file_path;
flag.modelFile = file_path; flag->inputDataType = kTypeUnknown;
flag.inputDataType = kTypeUnknown; flag->outputDataType = kTypeUnknown;
flag.outputDataType = kTypeUnknown; flag->saveFP16 = false;
flag.saveFP16 = false; flag->trainModel = false;
flag.trainModel = false;
MindsporeImporter ms_import; Converter cvt;
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag); auto meta_graph = cvt.Convert(flag);
if (func_graph == nullptr) { MS_LOG(ERROR) << "Convert failed.";
return nullptr;
}
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
return nullptr;
}
// funcgraph compile
AnfTransform funcgraph_transform;
func_graph = funcgraph_transform.Transform(func_graph, &flag);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}
// protobuf -> flatbuffer
auto meta_graph = Export(func_graph, false, false, false);
if (meta_graph == nullptr) { if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
}
// metagraph compile
GraphDefTransform metagraph_transform;
metagraph_transform.SetGraphDef(meta_graph);
auto status = metagraph_transform.Transform(flag);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transform meta graph failed " << status;
delete meta_graph;
return nullptr;
}
status = UpdateGraphOutputName(meta_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
delete meta_graph;
return nullptr; return nullptr;
} }
void *model_buf = nullptr;
meta_graph->version = Version(); meta_graph->version = Version();
status = TransferMetaGraph(*meta_graph, &model_buf, size); auto status = TransferMetaGraph(*meta_graph, &model_buf, size);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Transfer model failed."; MS_LOG(ERROR) << "Transfer model failed.";
delete meta_graph; delete meta_graph;

View File

@ -20,9 +20,10 @@
#ifdef RUNTIME_CONVERT #ifdef RUNTIME_CONVERT
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <memory>
namespace mindspore::lite { namespace mindspore::lite {
char *RuntimeConvert(const char *model_buf, size_t size); char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size);
char *RuntimeConvert(const std::string &file_path, size_t *size); char *RuntimeConvert(const std::string &file_path, size_t *size);
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // RUNTIME_CONVERT #endif // RUNTIME_CONVERT

View File

@ -24,7 +24,6 @@ file(GLOB_RECURSE TEST_UT_SRC
${TEST_DIR}/ut/src/scheduler_test.cc ${TEST_DIR}/ut/src/scheduler_test.cc
${TEST_DIR}/ut/src/registry/registry_test.cc ${TEST_DIR}/ut/src/registry/registry_test.cc
${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc ${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc
${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc
${TEST_DIR}/st/multiple_device_test.cc ${TEST_DIR}/st/multiple_device_test.cc
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc ${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
${TEST_DIR}/st/mix_data_type_test.cc ${TEST_DIR}/st/mix_data_type_test.cc
@ -40,6 +39,10 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT)
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc) list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc)
endif() endif()
if(MSLITE_ENABLE_RUNTIME_PASS)
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc)
endif()
if(MSLITE_ENABLE_TRAIN) if(MSLITE_ENABLE_TRAIN)
file(GLOB_RECURSE TEST_TRAIN_UT_SRC file(GLOB_RECURSE TEST_TRAIN_UT_SRC
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
@ -160,13 +163,14 @@ if(MSLITE_ENABLE_MINDRT)
endif() endif()
if(MSLITE_ENABLE_CONVERTER) if(MSLITE_ENABLE_CONVERTER)
target_link_libraries(lite-test tflite_parser_mid caffe_parser_mid
onnx_parser_mid tf_parser_mid)
endif()
if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT))
target_link_libraries(lite-test target_link_libraries(lite-test
anf_exporter_mid anf_exporter_mid
mslite_converter_plugin mslite_converter_plugin
tflite_parser_mid
caffe_parser_mid
onnx_parser_mid
tf_parser_mid
graph_pass_mid graph_pass_mid
fusion_mid fusion_mid
quantizer_mid quantizer_mid

View File

@ -37,13 +37,13 @@ TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_FAILED) {
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) { TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) {
std::vector<Graph> graphs; std::vector<Graph> graphs;
auto status = auto status =
Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kFlatBuffer, &graphs); Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kMindIR, &graphs);
ASSERT_TRUE(status != kSuccess); ASSERT_TRUE(status != kSuccess);
} }
TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) { TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
Model model; Model model;
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kFlatBuffer, "./nets/export.ms") != kSuccess); ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kMindIR, "./nets/export.ms") != kSuccess);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -77,13 +77,41 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
return func_graph; return func_graph;
} }
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size) {
MindsporeImporter ms_import;
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag, buf, size);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed.";
return nullptr;
}
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
return nullptr;
}
return func_graph;
}
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf,
const size_t &size) {
if (flag == nullptr || buf == nullptr) {
MS_LOG(ERROR) << "Input flag is nullptr";
return nullptr;
}
auto graph = BuildFuncGraph(*flag, buf, size);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
}
return TransferFuncGraph(flag, graph);
}
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) { schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
if (flag == nullptr) { if (flag == nullptr) {
MS_LOG(ERROR) << "Input flag is nullptr"; MS_LOG(ERROR) << "Input flag is nullptr";
return nullptr; return nullptr;
} }
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
// load plugin // load plugin
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders; static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
@ -106,15 +134,23 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
return nullptr; return nullptr;
} }
return TransferFuncGraph(flag, graph);
}
schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag,
FuncGraphPtr func_graph) {
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
// funcgraph compile // funcgraph compile
graph = funcgraph_transform_->Transform(graph, flag.get()); func_graph = funcgraph_transform_->Transform(func_graph, flag.get());
if (graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr"; MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr; return nullptr;
} }
// protobuf -> flatbuffer // protobuf -> flatbuffer
auto meta_graph = Export(graph, false, false, flag->trainModel); auto meta_graph = Export(func_graph, false, false, flag->trainModel);
if (meta_graph == nullptr) { if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr"; MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr; return nullptr;

View File

@ -39,7 +39,12 @@ class Converter {
this->model_parser_ = nullptr; this->model_parser_ = nullptr;
} }
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag); schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf, const size_t &size);
private:
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag); FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size);
schema::MetaGraphT *TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag, FuncGraphPtr func_graph);
protected: protected:
converter::ModelParser *model_parser_ = nullptr; converter::ModelParser *model_parser_ = nullptr;

View File

@ -246,6 +246,12 @@ void MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
} }
} }
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size) {
MindIRLoader mindir_loader;
auto func_graph = mindir_loader.LoadMindIR(buff, size);
return CheckAndUpdateFuncGraph(flag, func_graph);
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) { FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
FuncGraphPtr func_graph; FuncGraphPtr func_graph;
if (!flag.dec_key.empty()) { if (!flag.dec_key.empty()) {
@ -264,6 +270,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
MindIRLoader mindir_loader; MindIRLoader mindir_loader;
func_graph = mindir_loader.LoadMindIR(flag.modelFile); func_graph = mindir_loader.LoadMindIR(flag.modelFile);
} }
return CheckAndUpdateFuncGraph(flag, func_graph);
}
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR"; MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
MS_LOG(ERROR) MS_LOG(ERROR)

View File

@ -29,11 +29,13 @@ class MindsporeImporter {
MindsporeImporter() = default; MindsporeImporter() = default;
~MindsporeImporter() = default; ~MindsporeImporter() = default;
FuncGraphPtr ImportMindIR(const converter::Flags &flag); FuncGraphPtr ImportMindIR(const converter::Flags &flag);
FuncGraphPtr ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size);
private: private:
static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph); static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
STATUS ProcessDependCnode(const CNodePtr &cnode); STATUS ProcessDependCnode(const CNodePtr &cnode);
STATUS GetFuncGraphOutputName(const CNodePtr &cnode); STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
FuncGraphPtr CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph);
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag); STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len); size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
std::vector<std::string> output_tensor_name_; std::vector<std::string> output_tensor_name_;