[MSLITE] runtime convert support input buffer
This commit is contained in:
parent
a304352179
commit
022f10e639
|
@ -39,8 +39,7 @@ enum ModelType : uint32_t {
|
||||||
kAIR = 1,
|
kAIR = 1,
|
||||||
kOM = 2,
|
kOM = 2,
|
||||||
kONNX = 3,
|
kONNX = 3,
|
||||||
kFlatBuffer = 4,
|
kMindIR_Opt = 4,
|
||||||
kMindIR_Opt = 5,
|
|
||||||
// insert new data type here
|
// insert new data type here
|
||||||
kUnknownType = 0xFFFFFFFF
|
kUnknownType = 0xFFFFFFFF
|
||||||
};
|
};
|
||||||
|
|
|
@ -133,6 +133,23 @@ std::vector<std::shared_ptr<FuncGraph>> MindIRLoader::LoadMindIRs(std::vector<st
|
||||||
return funcgraph_vec;
|
return funcgraph_vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
|
||||||
|
/* mindir -> func_graph
|
||||||
|
* only support lite */
|
||||||
|
mind_ir::ModelProto model;
|
||||||
|
auto ret = model.ParseFromArray(buffer, size);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(ERROR) << "ParseFromArray failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
MSANFModelParser model_parser;
|
||||||
|
model_parser.SetLite();
|
||||||
|
FuncGraphPtr func_graph = model_parser.Parse(model);
|
||||||
|
|
||||||
|
return func_graph;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
|
std::shared_ptr<FuncGraph> MindIRLoader::LoadMindIR(const std::string &file_name) {
|
||||||
if (file_name.length() > PATH_MAX) {
|
if (file_name.length() > PATH_MAX) {
|
||||||
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
|
||||||
|
|
|
@ -34,6 +34,7 @@ class MindIRLoader {
|
||||||
|
|
||||||
bool get_need_renormalize() const { return need_renormalize_; }
|
bool get_need_renormalize() const { return need_renormalize_; }
|
||||||
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
|
void set_need_renormalize(bool need_renormalize) { need_renormalize_ = need_renormalize; }
|
||||||
|
std::shared_ptr<FuncGraph> LoadMindIR(const void *buffer, const size_t &size);
|
||||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
|
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
|
||||||
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);
|
std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(const std::vector<std::string> file_names);
|
||||||
|
|
||||||
|
|
|
@ -228,9 +228,9 @@ int NetRunner::Main() {
|
||||||
|
|
||||||
if (epochs_ > 0) {
|
if (epochs_ > 0) {
|
||||||
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
|
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms";
|
||||||
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, false);
|
mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, false);
|
||||||
trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms";
|
trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_infer.ms";
|
||||||
mindspore::Serialization::ExportModel(*model_, mindspore::kFlatBuffer, trained_fn, mindspore::kNoQuant, true);
|
mindspore::Serialization::ExportModel(*model_, mindspore::kMindIR, trained_fn, mindspore::kNoQuant, true);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,7 +136,7 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
|
||||||
MS_LOG(ERROR) << "Model is not TrainModel.";
|
MS_LOG(ERROR) << "Model is not TrainModel.";
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
if (model_type != kFlatBuffer) {
|
if (model_type != kMindIR && model_type != kMindIR_Opt) {
|
||||||
MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
|
MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
|
||||||
return kLiteParamInvalid;
|
return kLiteParamInvalid;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1318,67 +1318,66 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
|
||||||
return session;
|
return session;
|
||||||
}
|
}
|
||||||
|
|
||||||
const char *lite::LiteSession::LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size) {
|
mindspore::ModelType lite::LiteSession::LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
|
||||||
|
size_t *size, mindspore::ModelType model_type) {
|
||||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
||||||
return model_buf;
|
*size = buf_size;
|
||||||
|
*lite_buf = const_cast<char *>(model_buf);
|
||||||
|
return mindspore::ModelType::kMindIR_Opt;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type == mindspore::ModelType::kMindIR) {
|
if (model_type != mindspore::ModelType::kMindIR) {
|
||||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
return mindspore::ModelType::kUnknownType;
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Verifier verify((const uint8_t *)model_buf, buf_size);
|
||||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||||
if (version_verify != SCHEMA_INVALID) {
|
if (version_verify != SCHEMA_INVALID) {
|
||||||
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
|
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
|
||||||
|
*size = buf_size;
|
||||||
|
*lite_buf = const_cast<char *>(model_buf);
|
||||||
|
return mindspore::ModelType::kMindIR_Opt;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef RUNTIME_CONVERT
|
||||||
|
*lite_buf = RuntimeConvert(model_buf, buf_size, size);
|
||||||
|
#else
|
||||||
|
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||||
|
#endif
|
||||||
|
return mindspore::ModelType::kMindIR;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
|
||||||
|
size_t buf_size;
|
||||||
|
auto model_buf = lite::ReadFile(file.c_str(), &buf_size);
|
||||||
|
if (model_buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The model path is invalid";
|
||||||
return model_buf;
|
return model_buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef RUNTIME_CONVERT
|
char *lite_buf = nullptr;
|
||||||
return RuntimeConvert(model_buf, size);
|
auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
|
||||||
#endif
|
if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
|
||||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (buf_model_type == mindspore::ModelType::kMindIR) {
|
||||||
MS_LOG(ERROR) << "Invalid Model Type";
|
free(model_buf);
|
||||||
return nullptr;
|
model_buf = nullptr;
|
||||||
|
}
|
||||||
|
return lite_buf;
|
||||||
}
|
}
|
||||||
|
|
||||||
char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
|
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type,
|
||||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
const size_t &buf_size, session::LiteSession *session) {
|
||||||
return lite::ReadFile(file.c_str(), size);
|
size_t lite_buf_size = 0;
|
||||||
}
|
char *lite_buf = nullptr;
|
||||||
|
auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, &lite_buf_size, model_type);
|
||||||
if (model_type == mindspore::ModelType::kMindIR) {
|
if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
|
||||||
size_t tmp_size;
|
|
||||||
auto tmp_buf = lite::ReadFile(file.c_str(), &tmp_size);
|
|
||||||
flatbuffers::Verifier verify((const uint8_t *)tmp_buf, tmp_size);
|
|
||||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
|
||||||
if (version_verify != SCHEMA_INVALID) {
|
|
||||||
MS_LOG(DEBUG) << "The kMindIR type model path is valid mslite model";
|
|
||||||
return tmp_buf;
|
|
||||||
}
|
|
||||||
free(tmp_buf);
|
|
||||||
tmp_buf = nullptr;
|
|
||||||
|
|
||||||
#ifdef RUNTIME_CONVERT
|
|
||||||
return RuntimeConvert(file, size);
|
|
||||||
#endif
|
|
||||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(ERROR) << "Invalid Model Type";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
|
||||||
session::LiteSession *session) {
|
|
||||||
auto lite_buf = LoadModelByBuff(model_buf, model_type, size);
|
|
||||||
if (lite_buf == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Invalid model_buf";
|
MS_LOG(ERROR) << "Invalid model_buf";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto *model = lite::ImportFromBuffer(model_buf, size, true);
|
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Import model failed";
|
MS_LOG(ERROR) << "Import model failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -53,7 +53,7 @@ class LiteSession : public session::LiteSession {
|
||||||
|
|
||||||
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
|
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
|
||||||
|
|
||||||
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
|
||||||
session::LiteSession *session);
|
session::LiteSession *session);
|
||||||
static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
|
static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
|
||||||
session::LiteSession *session);
|
session::LiteSession *session);
|
||||||
|
@ -128,8 +128,9 @@ class LiteSession : public session::LiteSession {
|
||||||
|
|
||||||
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
|
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
|
||||||
|
|
||||||
static const char *LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size);
|
static mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
|
||||||
static char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
|
size_t *size, mindspore::ModelType model_type);
|
||||||
|
static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int PreCheck(Model *model);
|
int PreCheck(Model *model);
|
||||||
|
|
|
@ -16,78 +16,63 @@
|
||||||
|
|
||||||
#ifdef RUNTIME_CONVERT
|
#ifdef RUNTIME_CONVERT
|
||||||
#include "src/runtime/runtime_convert.h"
|
#include "src/runtime/runtime_convert.h"
|
||||||
#include "tools/common/graph_util.h"
|
#include "include/version.h"
|
||||||
|
#include "tools/converter/converter.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "tools/converter/anf_transform.h"
|
|
||||||
#include "tools/anf_exporter/anf_exporter.h"
|
|
||||||
#include "tools/converter/graphdef_transform.h"
|
|
||||||
#include "tools/converter/import/mindspore_importer.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
char *RuntimeConvert(const char *model_buf, size_t) {
|
char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size) {
|
||||||
if (model_buf == nullptr) {
|
if (model_buf == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid input model buffer.";
|
MS_LOG(ERROR) << "Invalid input model buffer.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_LOG(ERROR) << "Invalid Now. Use model path for model build.";
|
|
||||||
|
auto flag = std::make_unique<converter::Flags>();
|
||||||
|
flag->fmk = converter::kFmkTypeMs;
|
||||||
|
flag->inputDataType = kTypeUnknown;
|
||||||
|
flag->outputDataType = kTypeUnknown;
|
||||||
|
flag->saveFP16 = false;
|
||||||
|
flag->trainModel = false;
|
||||||
|
|
||||||
|
Converter cvt;
|
||||||
|
auto meta_graph = cvt.Convert(flag, model_buf, buf_size);
|
||||||
|
if (meta_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Convert failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void *lite_buf = nullptr;
|
||||||
|
meta_graph->version = Version();
|
||||||
|
auto status = TransferMetaGraph(*meta_graph, &lite_buf, size);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Transfer model failed.";
|
||||||
|
delete meta_graph;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
delete meta_graph;
|
||||||
|
return reinterpret_cast<char *>(lite_buf);
|
||||||
|
}
|
||||||
|
|
||||||
char *RuntimeConvert(const std::string &file_path, size_t *size) {
|
char *RuntimeConvert(const std::string &file_path, size_t *size) {
|
||||||
void *model_buf = nullptr;
|
auto flag = std::make_unique<converter::Flags>();
|
||||||
converter::Flags flag;
|
flag->fmk = converter::kFmkTypeMs;
|
||||||
flag.fmk = converter::kFmkTypeMs;
|
flag->modelFile = file_path;
|
||||||
flag.modelFile = file_path;
|
flag->inputDataType = kTypeUnknown;
|
||||||
flag.inputDataType = kTypeUnknown;
|
flag->outputDataType = kTypeUnknown;
|
||||||
flag.outputDataType = kTypeUnknown;
|
flag->saveFP16 = false;
|
||||||
flag.saveFP16 = false;
|
flag->trainModel = false;
|
||||||
flag.trainModel = false;
|
|
||||||
|
|
||||||
MindsporeImporter ms_import;
|
Converter cvt;
|
||||||
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag);
|
auto meta_graph = cvt.Convert(flag);
|
||||||
if (func_graph == nullptr) {
|
MS_LOG(ERROR) << "Convert failed.";
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// funcgraph compile
|
|
||||||
AnfTransform funcgraph_transform;
|
|
||||||
func_graph = funcgraph_transform.Transform(func_graph, &flag);
|
|
||||||
if (func_graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// protobuf -> flatbuffer
|
|
||||||
auto meta_graph = Export(func_graph, false, false, false);
|
|
||||||
if (meta_graph == nullptr) {
|
if (meta_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// metagraph compile
|
|
||||||
GraphDefTransform metagraph_transform;
|
|
||||||
metagraph_transform.SetGraphDef(meta_graph);
|
|
||||||
auto status = metagraph_transform.Transform(flag);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Transform meta graph failed " << status;
|
|
||||||
delete meta_graph;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
status = UpdateGraphOutputName(meta_graph);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
|
|
||||||
delete meta_graph;
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void *model_buf = nullptr;
|
||||||
meta_graph->version = Version();
|
meta_graph->version = Version();
|
||||||
status = TransferMetaGraph(*meta_graph, &model_buf, size);
|
auto status = TransferMetaGraph(*meta_graph, &model_buf, size);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Transfer model failed.";
|
MS_LOG(ERROR) << "Transfer model failed.";
|
||||||
delete meta_graph;
|
delete meta_graph;
|
||||||
|
|
|
@ -20,9 +20,10 @@
|
||||||
#ifdef RUNTIME_CONVERT
|
#ifdef RUNTIME_CONVERT
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
char *RuntimeConvert(const char *model_buf, size_t size);
|
char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size);
|
||||||
char *RuntimeConvert(const std::string &file_path, size_t *size);
|
char *RuntimeConvert(const std::string &file_path, size_t *size);
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
#endif // RUNTIME_CONVERT
|
#endif // RUNTIME_CONVERT
|
||||||
|
|
|
@ -24,7 +24,6 @@ file(GLOB_RECURSE TEST_UT_SRC
|
||||||
${TEST_DIR}/ut/src/scheduler_test.cc
|
${TEST_DIR}/ut/src/scheduler_test.cc
|
||||||
${TEST_DIR}/ut/src/registry/registry_test.cc
|
${TEST_DIR}/ut/src/registry/registry_test.cc
|
||||||
${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc
|
${TEST_DIR}/ut/src/registry/registry_custom_op_test.cc
|
||||||
${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc
|
|
||||||
${TEST_DIR}/st/multiple_device_test.cc
|
${TEST_DIR}/st/multiple_device_test.cc
|
||||||
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
|
${TEST_DIR}/st/mindrt_parallel_runtime_test.cc
|
||||||
${TEST_DIR}/st/mix_data_type_test.cc
|
${TEST_DIR}/st/mix_data_type_test.cc
|
||||||
|
@ -40,6 +39,10 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT)
|
||||||
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc)
|
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_RUNTIME_PASS)
|
||||||
|
list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_pass_tests.cc)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MSLITE_ENABLE_TRAIN)
|
if(MSLITE_ENABLE_TRAIN)
|
||||||
file(GLOB_RECURSE TEST_TRAIN_UT_SRC
|
file(GLOB_RECURSE TEST_TRAIN_UT_SRC
|
||||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||||
|
@ -160,13 +163,14 @@ if(MSLITE_ENABLE_MINDRT)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MSLITE_ENABLE_CONVERTER)
|
if(MSLITE_ENABLE_CONVERTER)
|
||||||
|
target_link_libraries(lite-test tflite_parser_mid caffe_parser_mid
|
||||||
|
onnx_parser_mid tf_parser_mid)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT))
|
||||||
target_link_libraries(lite-test
|
target_link_libraries(lite-test
|
||||||
anf_exporter_mid
|
anf_exporter_mid
|
||||||
mslite_converter_plugin
|
mslite_converter_plugin
|
||||||
tflite_parser_mid
|
|
||||||
caffe_parser_mid
|
|
||||||
onnx_parser_mid
|
|
||||||
tf_parser_mid
|
|
||||||
graph_pass_mid
|
graph_pass_mid
|
||||||
fusion_mid
|
fusion_mid
|
||||||
quantizer_mid
|
quantizer_mid
|
||||||
|
|
|
@ -37,13 +37,13 @@ TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_FAILED) {
|
||||||
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) {
|
TEST_F(TestCxxApiLiteSerialization, test_load_file_not_exist_x2_FAILED) {
|
||||||
std::vector<Graph> graphs;
|
std::vector<Graph> graphs;
|
||||||
auto status =
|
auto status =
|
||||||
Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kFlatBuffer, &graphs);
|
Serialization::Load(std::vector<std::string>(2, "./nets/file_not_exist.mindir"), ModelType::kMindIR, &graphs);
|
||||||
ASSERT_TRUE(status != kSuccess);
|
ASSERT_TRUE(status != kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
|
TEST_F(TestCxxApiLiteSerialization, test_export_uninitialized_FAILED) {
|
||||||
Model model;
|
Model model;
|
||||||
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kFlatBuffer, "./nets/export.ms") != kSuccess);
|
ASSERT_TRUE(Serialization::ExportModel(model, ModelType::kMindIR, "./nets/export.ms") != kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -77,13 +77,41 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size) {
|
||||||
|
MindsporeImporter ms_import;
|
||||||
|
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag, buf, size);
|
||||||
|
if (func_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Get funcGraph failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return func_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf,
|
||||||
|
const size_t &size) {
|
||||||
|
if (flag == nullptr || buf == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Input flag is nullptr";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto graph = BuildFuncGraph(*flag, buf, size);
|
||||||
|
if (graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return TransferFuncGraph(flag, graph);
|
||||||
|
}
|
||||||
|
|
||||||
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
|
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
|
||||||
if (flag == nullptr) {
|
if (flag == nullptr) {
|
||||||
MS_LOG(ERROR) << "Input flag is nullptr";
|
MS_LOG(ERROR) << "Input flag is nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
|
|
||||||
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
|
|
||||||
|
|
||||||
// load plugin
|
// load plugin
|
||||||
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
|
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
|
||||||
|
@ -106,15 +134,23 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return TransferFuncGraph(flag, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag,
|
||||||
|
FuncGraphPtr func_graph) {
|
||||||
|
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
|
||||||
|
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
|
||||||
|
|
||||||
// funcgraph compile
|
// funcgraph compile
|
||||||
graph = funcgraph_transform_->Transform(graph, flag.get());
|
func_graph = funcgraph_transform_->Transform(func_graph, flag.get());
|
||||||
if (graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// protobuf -> flatbuffer
|
// protobuf -> flatbuffer
|
||||||
auto meta_graph = Export(graph, false, false, flag->trainModel);
|
auto meta_graph = Export(func_graph, false, false, flag->trainModel);
|
||||||
if (meta_graph == nullptr) {
|
if (meta_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -39,7 +39,12 @@ class Converter {
|
||||||
this->model_parser_ = nullptr;
|
this->model_parser_ = nullptr;
|
||||||
}
|
}
|
||||||
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
|
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
|
||||||
|
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf, const size_t &size);
|
||||||
|
|
||||||
|
private:
|
||||||
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
|
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
|
||||||
|
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size);
|
||||||
|
schema::MetaGraphT *TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag, FuncGraphPtr func_graph);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
converter::ModelParser *model_parser_ = nullptr;
|
converter::ModelParser *model_parser_ = nullptr;
|
||||||
|
|
|
@ -246,6 +246,12 @@ void MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size) {
|
||||||
|
MindIRLoader mindir_loader;
|
||||||
|
auto func_graph = mindir_loader.LoadMindIR(buff, size);
|
||||||
|
return CheckAndUpdateFuncGraph(flag, func_graph);
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||||
FuncGraphPtr func_graph;
|
FuncGraphPtr func_graph;
|
||||||
if (!flag.dec_key.empty()) {
|
if (!flag.dec_key.empty()) {
|
||||||
|
@ -264,6 +270,11 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
||||||
MindIRLoader mindir_loader;
|
MindIRLoader mindir_loader;
|
||||||
func_graph = mindir_loader.LoadMindIR(flag.modelFile);
|
func_graph = mindir_loader.LoadMindIR(flag.modelFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return CheckAndUpdateFuncGraph(flag, func_graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
||||||
MS_LOG(ERROR)
|
MS_LOG(ERROR)
|
||||||
|
|
|
@ -29,11 +29,13 @@ class MindsporeImporter {
|
||||||
MindsporeImporter() = default;
|
MindsporeImporter() = default;
|
||||||
~MindsporeImporter() = default;
|
~MindsporeImporter() = default;
|
||||||
FuncGraphPtr ImportMindIR(const converter::Flags &flag);
|
FuncGraphPtr ImportMindIR(const converter::Flags &flag);
|
||||||
|
FuncGraphPtr ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
|
static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
|
||||||
STATUS ProcessDependCnode(const CNodePtr &cnode);
|
STATUS ProcessDependCnode(const CNodePtr &cnode);
|
||||||
STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
|
STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
|
||||||
|
FuncGraphPtr CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph);
|
||||||
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
|
||||||
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len);
|
||||||
std::vector<std::string> output_tensor_name_;
|
std::vector<std::string> output_tensor_name_;
|
||||||
|
|
Loading…
Reference in New Issue