[MSLITE] runtime convert support input buffer
This commit is contained in:
parent
a304352179
commit
022f10e639
|
@ -39,8 +39,7 @@ enum ModelType : uint32_t {
|
|||
kAIR = 1,
|
||||
kOM = 2,
|
||||
kONNX = 3,
|
||||
kFlatBuffer = 4,
|
||||
kMindIR_Opt = 5,
|
||||
kMindIR_Opt = 4,
|
||||
// insert new data type here
|
||||
kUnknownType = 0xFFFFFFFF
|
||||
};
|
||||
|
|
|
@ -133,6 +133,23 @@ std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<st
|
|||
return funcgraph_vec;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
|
||||
/* mindir -> func_graph
|
||||
* only support lite */
|
||||
mind_ir::ModelProto model;
|
||||
auto ret = model.ParseFromArray(buffer, size);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "ParseFromArray failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MSANFModelParser model_parser;
|
||||
model_parser.SetLite();
|
||||
FuncGraphPtr func_graph = model_parser.Parse(model);
|
||||
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
|
||||
if (file_name.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||
|
|
|
@ -34,6 +34,7 @@ class MindIRLoader {
|
|||
|
||||
bool get_need_renormalize() const { return need_renormalize_; }
|
||||
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const void *buffer, const size_t &size);
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
|
||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);
|
||||
|
||||
|
|
|
@ -228,9 +228,9 @@ int NetRunner::Main() {
|
|||
|
||||
if (epochs_ > 0) {
|
||||
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
|
||||
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, false);
|
||||
mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, false);
|
||||
trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms";
|
||||
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, true);
|
||||
mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, true);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -136,7 +136,7 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
|
|||
MS_LOG(ERROR) << "Model is not TrainModel.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (model_type != kFlatBuffer) {
|
||||
if (model_type != kMindIR && model_type != kMindIR_Opt) {
|
||||
MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
|
|
|
@ -1318,67 +1318,66 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
|
|||
return session;
|
||||
}
|
||||
|
||||
const char *lite::LiteSession::LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size) {
|
||||
mindspore::ModelType lite::LiteSession::LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
|
||||
size_t *size, mindspore::ModelType model_type) {
|
||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
||||
*size = buf_size;
|
||||
*lite_buf = const_cast<char *>(model_buf);
|
||||
return mindspore::ModelType::kMindIR_Opt;
|
||||
}
|
||||
|
||||
if (model_type != mindspore::ModelType::kMindIR) {
|
||||
return mindspore::ModelType::kUnknownType;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, buf_size);
|
||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||
if (version_verify != SCHEMA_INVALID) {
|
||||
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
|
||||
*size = buf_size;
|
||||
*lite_buf = const_cast<char *>(model_buf);
|
||||
return mindspore::ModelType::kMindIR_Opt;
|
||||
}
|
||||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
*lite_buf = RuntimeConvert(model_buf, buf_size, size);
|
||||
#else
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
#endif
|
||||
return mindspore::ModelType::kMindIR;
|
||||
}
|
||||
|
||||
const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
|
||||
size_t buf_size;
|
||||
auto model_buf = lite::ReadFile(file.c_str(), &buf_size);
|
||||
if (model_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "The model path is invalid";
|
||||
return model_buf;
|
||||
}
|
||||
|
||||
if (model_type == mindspore::ModelType::kMindIR) {
|
||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||
if (version_verify != SCHEMA_INVALID) {
|
||||
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
|
||||
return model_buf;
|
||||
}
|
||||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
return RuntimeConvert(model_buf, size);
|
||||
#endif
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
char *lite_buf = nullptr;
|
||||
auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
|
||||
if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Invalid Model Type";
|
||||
return nullptr;
|
||||
if (buf_model_type == mindspore::ModelType::kMindIR) {
|
||||
free(model_buf);
|
||||
model_buf = nullptr;
|
||||
}
|
||||
return lite_buf;
|
||||
}
|
||||
|
||||
char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
|
||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
||||
return lite::ReadFile(file.c_str(), size);
|
||||
}
|
||||
|
||||
if (model_type == mindspore::ModelType::kMindIR) {
|
||||
size_t tmp_size;
|
||||
auto tmp_buf = lite::ReadFile(file.c_str(), &tmp_size);
|
||||
flatbuffers::Verifier verify((const uint8_t *)tmp_buf, tmp_size);
|
||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||
if (version_verify != SCHEMA_INVALID) {
|
||||
MS_LOG(DEBUG) << "The kMindIR type model path is valid mslite model";
|
||||
return tmp_buf;
|
||||
}
|
||||
free(tmp_buf);
|
||||
tmp_buf = nullptr;
|
||||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
return RuntimeConvert(file, size);
|
||||
#endif
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Invalid Model Type";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
||||
session::LiteSession *session) {
|
||||
auto lite_buf = LoadModelByBuff(model_buf, model_type, size);
|
||||
if (lite_buf == nullptr) {
|
||||
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type,
|
||||
const size_t &buf_size, session::LiteSession *session) {
|
||||
size_t lite_buf_size = 0;
|
||||
char *lite_buf = nullptr;
|
||||
auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, &lite_buf_size, model_type);
|
||||
if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid model_buf";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto *model = lite::ImportFromBuffer(model_buf, size, true);
|
||||
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Import model failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -53,7 +53,7 @@ class LiteSession : public session::LiteSession {
|
|||
|
||||
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
|
||||
|
||||
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
||||
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
|
||||
session::LiteSession *session);
|
||||
static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
|
||||
session::LiteSession *session);
|
||||
|
@ -128,8 +128,9 @@ class LiteSession : public session::LiteSession {
|
|||
|
||||
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
|
||||
|
||||
static const char *LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size);
|
||||
static char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
|
||||
static mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
|
||||
size_t *size, mindspore::ModelType model_type);
|
||||
static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
|
||||
|
||||
private:
|
||||
int PreCheck(Model *model);
|
||||
|
|
|
@ -16,78 +16,63 @@
|
|||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
#include "src/runtime/runtime_convert.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/version.h"
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/import/mindspore_importer.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
char *RuntimeConvert(const char *model_buf, size_t) {
|
||||
char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size) {
|
||||
if (model_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid input model buffer.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(ERROR) << "Invalid Now. Use model path for model build.";
|
||||
return nullptr;
|
||||
|
||||
auto flag = std::make_unique<converter::Flags>();
|
||||
flag->fmk = converter::kFmkTypeMs;
|
||||
flag->inputDataType = kTypeUnknown;
|
||||
flag->outputDataType = kTypeUnknown;
|
||||
flag->saveFP16 = false;
|
||||
flag->trainModel = false;
|
||||
|
||||
Converter cvt;
|
||||
auto meta_graph = cvt.Convert(flag, model_buf, buf_size);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *lite_buf = nullptr;
|
||||
meta_graph->version = Version();
|
||||
auto status = TransferMetaGraph(*meta_graph, &lite_buf, size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transfer model failed.";
|
||||
delete meta_graph;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
delete meta_graph;
|
||||
return reinterpret_cast<char *>(lite_buf);
|
||||
}
|
||||
|
||||
char *RuntimeConvert(const std::string &file_path, size_t *size) {
|
||||
void *model_buf = nullptr;
|
||||
converter::Flags flag;
|
||||
flag.fmk = converter::kFmkTypeMs;
|
||||
flag.modelFile = file_path;
|
||||
flag.inputDataType = kTypeUnknown;
|
||||
flag.outputDataType = kTypeUnknown;
|
||||
flag.saveFP16 = false;
|
||||
flag.trainModel = false;
|
||||
auto flag = std::make_unique<converter::Flags>();
|
||||
flag->fmk = converter::kFmkTypeMs;
|
||||
flag->modelFile = file_path;
|
||||
flag->inputDataType = kTypeUnknown;
|
||||
flag->outputDataType = kTypeUnknown;
|
||||
flag->saveFP16 = false;
|
||||
flag->trainModel = false;
|
||||
|
||||
MindsporeImporter ms_import;
|
||||
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag);
|
||||
if (func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// funcgraph compile
|
||||
AnfTransform funcgraph_transform;
|
||||
func_graph = funcgraph_transform.Transform(func_graph, &flag);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// protobuf -> flatbuffer
|
||||
auto meta_graph = Export(func_graph, false, false, false);
|
||||
Converter cvt;
|
||||
auto meta_graph = cvt.Convert(flag);
|
||||
MS_LOG(ERROR) << "Convert failed.";
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// metagraph compile
|
||||
GraphDefTransform metagraph_transform;
|
||||
metagraph_transform.SetGraphDef(meta_graph);
|
||||
auto status = metagraph_transform.Transform(flag);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transform meta graph failed " << status;
|
||||
delete meta_graph;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = UpdateGraphOutputName(meta_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
|
||||
delete meta_graph;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *model_buf = nullptr;
|
||||
meta_graph->version = Version();
|
||||
status = TransferMetaGraph(*meta_graph, &model_buf, size);
|
||||
auto status = TransferMetaGraph(*meta_graph, &model_buf, size);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Transfer model failed.";
|
||||
delete meta_graph;
|
||||
|
|
|
@ -20,9 +20,10 @@
|
|||
#ifdef RUNTIME_CONVERT
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore::lite {
|
||||
char *RuntimeConvert(const char *model_buf, size_t size);
|
||||
char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size);
|
||||
char *RuntimeConvert(const std::string &file_path, size_t *size);
|
||||
} // namespace mindspore::lite
|
||||
#endif // RUNTIME_CONVERT
|
||||
|
|
|
@ -24,7 +24,6 @@ file(GLOB_RECURSE TEST_UT_SRC
|
|||
${TEST_DIR}/ut/src/scheduler_test.cc
|
||||
${TEST_DIR}/ut/src/registry/registry_test.cc
|
||||
${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc
|
||||
${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc
|
||||
${TEST_DIR}/st/multiple_device_test.cc
|
||||
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
|
||||
${TEST_DIR}/st/mix_data_type_test.cc
|
||||
|
@ -40,6 +39,10 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT)
|
|||
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_RUNTIME_PASS)
|
||||
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_TRAIN)
|
||||
file(GLOB_RECURSE TEST_TRAIN_UT_SRC
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||
|
@ -160,13 +163,14 @@ if(MSLITE_ENABLE_MINDRT)
|
|||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
target_link_libraries(lite-test tflite_parser_mid caffe_parser_mid
|
||||
onnx_parser_mid tf_parser_mid)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT))
|
||||
target_link_libraries(lite-test
|
||||
anf_exporter_mid
|
||||
mslite_converter_plugin
|
||||
tflite_parser_mid
|
||||
caffe_parser_mid
|
||||
onnx_parser_mid
|
||||
tf_parser_mid
|
||||
graph_pass_mid
|
||||
fusion_mid
|
||||
quantizer_mid
|
||||
|
|
|
@ -37,13 +37,13 @@ TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_FAILED) {
|
|||
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) {
|
||||
std::vector<Graph> graphs;
|
||||
auto status =
|
||||
Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kFlatBuffer, &graphs);
|
||||
Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kMindIR, &graphs);
|
||||
ASSERT_TRUE(status != kSuccess);
|
||||
}
|
||||
|
||||
TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
|
||||
Model model;
|
||||
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kFlatBuffer, "./nets/export.ms") != kSuccess);
|
||||
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kMindIR, "./nets/export.ms") != kSuccess);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -77,13 +77,41 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
return func_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size) {
|
||||
MindsporeImporter ms_import;
|
||||
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag, buf, size);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Get funcGraph failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf,
|
||||
const size_t &size) {
|
||||
if (flag == nullptr || buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Input flag is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto graph = BuildFuncGraph(*flag, buf, size);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
return TransferFuncGraph(flag, graph);
|
||||
}
|
||||
|
||||
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
|
||||
if (flag == nullptr) {
|
||||
MS_LOG(ERROR) << "Input flag is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
|
||||
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
|
||||
|
||||
// load plugin
|
||||
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
|
||||
|
@ -106,15 +134,23 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return TransferFuncGraph(flag, graph);
|
||||
}
|
||||
|
||||
schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag,
|
||||
FuncGraphPtr func_graph) {
|
||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
|
||||
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
|
||||
|
||||
// funcgraph compile
|
||||
graph = funcgraph_transform_->Transform(graph, flag.get());
|
||||
if (graph == nullptr) {
|
||||
func_graph = funcgraph_transform_->Transform(func_graph, flag.get());
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// protobuf -> flatbuffer
|
||||
auto meta_graph = Export(graph, false, false, flag->trainModel);
|
||||
auto meta_graph = Export(func_graph, false, false, flag->trainModel);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||
return nullptr;
|
||||
|
|
|
@ -39,7 +39,12 @@ class Converter {
|
|||
this->model_parser_ = nullptr;
|
||||
}
|
||||
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
|
||||
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf, const size_t &size);
|
||||
|
||||
private:
|
||||
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
|
||||
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size);
|
||||
schema::MetaGraphT *TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag, FuncGraphPtr func_graph);
|
||||
|
||||
protected:
|
||||
converter::ModelParser *model_parser_ = nullptr;
|
||||
|
|
|
@ -246,6 +246,12 @@ void MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size) {
|
||||
MindIRLoader mindir_loader;
|
||||
auto func_graph = mindir_loader.LoadMindIR(buff, size);
|
||||
return CheckAndUpdateFuncGraph(flag, func_graph);
|
||||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||
FuncGraphPtr func_graph;
|
||||
if (!flag.dec_key.empty()) {
|
||||
|
@ -264,6 +270,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
MindIRLoader mindir_loader;
|
||||
func_graph = mindir_loader.LoadMindIR(flag.modelFile);
|
||||
}
|
||||
|
||||
return CheckAndUpdateFuncGraph(flag, func_graph);
|
||||
}
|
||||
|
||||
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
||||
MS_LOG(ERROR)
|
||||
|
|
|
@ -29,11 +29,13 @@ class MindsporeImporter {
|
|||
MindsporeImporter() = default;
|
||||
~MindsporeImporter() = default;
|
||||
FuncGraphPtr ImportMindIR(const converter::Flags &flag);
|
||||
FuncGraphPtr ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size);
|
||||
|
||||
private:
|
||||
static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
|
||||
STATUS ProcessDependCnode(const CNodePtr &cnode);
|
||||
STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
|
||||
FuncGraphPtr CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph);
|
||||
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
||||
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||
std::vector<std::string> output_tensor_name_;
|
||||
|
|
Loading…
Reference in New Issue