[MSLITE] model Type

ling 2021-11-05 11:34:34 +08:00
parent 0f143d65b2
commit aaca4be28b
18 changed files with 185 additions and 56 deletions

View File

@@ -40,6 +40,7 @@ enum ModelType : uint32_t {
kOM = 2,
kONNX = 3,
kFlatBuffer = 4,
+ kMindIR_Opt = 5,
// insert new data type here
kUnknownType = 0xFFFFFFFF
};
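
For orientation, callers select the parser through this enum when building. A minimal sketch of the distinction this commit introduces (file name and context setup are illustrative, not taken from the commit):

    #include <memory>
    #include "include/api/context.h"
    #include "include/api/model.h"

    int BuildExample() {
      auto context = std::make_shared<mindspore::Context>();
      context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
      mindspore::Model model;
      // kMindIR_Opt: the file is already an optimized mslite model (.ms) and is
      // imported directly. kMindIR: a plain MindIR export, which may go through
      // runtime convert (see the lite_session.cc changes below).
      auto ret = model.Build("./model.ms", mindspore::kMindIR_Opt, context);
      return ret == mindspore::kSuccess ? 0 : -1;
    }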

View File

@@ -71,7 +71,8 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
return kLiteNullptr;
}
- auto ret = lite::LiteSession::CreateSessionByBuf(static_cast<const char *>(model_data), data_size, session.get());
+ auto ret =
+ lite::LiteSession::CreateSessionByBuf(static_cast<const char *>(model_data), model_type, data_size, session.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
return kLiteError;
@@ -99,7 +100,7 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
return kLiteNullptr;
}
- auto ret = lite::LiteSession::CreateSessionByPath(model_path, session.get());
+ auto ret = lite::LiteSession::CreateSessionByPath(model_path, model_type, session.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
return kLiteError;

View File

@@ -192,7 +192,7 @@ int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
return RET_OK;
}
- int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
+ int LiteModel::VersionVerify(flatbuffers::Verifier *verify) {
if (verify == nullptr) {
MS_LOG(ERROR) << "verify is null.";
return RET_ERROR;

View File

@@ -59,6 +59,8 @@ class LiteModel : public Model {
SchemaTensorWrapper *GetSchemaTensor(const size_t &tensor_index) const;
+ static int VersionVerify(flatbuffers::Verifier *verify);
private:
#ifdef ENABLE_V0
int ConvertAttrs(Model::Node *node, std::vector<schema::Tensor *> *dst_tensor);
@@ -265,8 +267,6 @@ class LiteModel : public Model {
void SetNodeDeviceType(Node *node, const schema::v0::CNode &c_node) { node->device_type_ = -1; }
#endif
- int VersionVerify(flatbuffers::Verifier *verify) const;
int GenerateModelByVersion();
int ConvertSubGraph(const schema::SubGraph &sub_graph);
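
Since VersionVerify is now static, a raw buffer's schema version can be checked before any model object exists, which is what the new LoadModelByBuff path below relies on. A minimal sketch (buf and buf_size are placeholders):

    flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(buf), buf_size);
    if (lite::LiteModel::VersionVerify(&verifier) == SCHEMA_INVALID) {
      // Not an optimized mslite flatbuffer; a kMindIR build would fall back
      // to runtime convert at this point.
    }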

View File

@@ -1294,7 +1294,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
MS_LOG(ERROR) << "Create session failed";
return nullptr;
}
- auto ret = lite::LiteSession::CreateSessionByBuf(model_buf, size, session);
+ auto ret = lite::LiteSession::CreateSessionByBuf(model_buf, mindspore::ModelType::kMindIR_Opt, size, session);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
delete session;
@@ -1309,7 +1309,7 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
MS_LOG(ERROR) << "Create session failed";
return nullptr;
}
- auto ret = lite::LiteSession::CreateSessionByPath(model_path, session);
+ auto ret = lite::LiteSession::CreateSessionByPath(model_path, mindspore::ModelType::kMindIR_Opt, session);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init session failed";
delete session;
@@ -1318,7 +1318,66 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
return session;
}
- int lite::LiteSession::CreateSessionByBuf(const char *model_buf, size_t size, session::LiteSession *session) {
+ const char *lite::LiteSession::LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size) {
+ if (model_type == mindspore::ModelType::kMindIR_Opt) {
+ return model_buf;
+ }
+ if (model_type == mindspore::ModelType::kMindIR) {
+ flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
+ auto version_verify = lite::LiteModel::VersionVerify(&verify);
+ if (version_verify != SCHEMA_INVALID) {
+ MS_LOG(DEBUG) << "The kMindIR model buffer is a valid mslite model buffer";
+ return model_buf;
+ }
+ #ifdef RUNTIME_CONVERT
+ return RuntimeConvert(model_buf, size);
+ #endif
+ MS_LOG(ERROR) << "Please enable runtime convert.";
+ return nullptr;
+ }
+ MS_LOG(ERROR) << "Invalid Model Type";
+ return nullptr;
+ }
+ char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
+ if (model_type == mindspore::ModelType::kMindIR_Opt) {
+ return lite::ReadFile(file.c_str(), size);
+ }
+ if (model_type == mindspore::ModelType::kMindIR) {
+ size_t tmp_size;
+ auto tmp_buf = lite::ReadFile(file.c_str(), &tmp_size);
+ flatbuffers::Verifier verify((const uint8_t *)tmp_buf, tmp_size);
+ auto version_verify = lite::LiteModel::VersionVerify(&verify);
+ if (version_verify != SCHEMA_INVALID) {
+ MS_LOG(DEBUG) << "The kMindIR model file is a valid mslite model";
+ *size = tmp_size;  // report the real buffer size to the caller
+ return tmp_buf;
+ }
+ free(tmp_buf);
+ tmp_buf = nullptr;
+ #ifdef RUNTIME_CONVERT
+ return RuntimeConvert(file, size);
+ #endif
+ MS_LOG(ERROR) << "Please enable runtime convert.";
+ return nullptr;
+ }
+ MS_LOG(ERROR) << "Invalid Model Type";
+ return nullptr;
+ }
+ int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
+ session::LiteSession *session) {
+ auto lite_buf = LoadModelByBuff(model_buf, model_type, size);
+ if (lite_buf == nullptr) {
+ MS_LOG(ERROR) << "Invalid model_buf";
+ return RET_ERROR;
+ }
auto *model = lite::ImportFromBuffer(lite_buf, size, true);  // import the verified buffer, not the raw input
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
@@ -1336,21 +1395,10 @@ int lite::LiteSession::CreateSessionByBuf(const char *model_buf, size_t size, se
return RET_OK;
}
- char *lite::LiteSession::LoadModelByPath(const char *file, size_t *size) {
- if (IsCharEndWith(file, MINDIR_POSTFIX)) {
- #ifdef RUNTIME_CONVERT
- return RuntimeConvert(file, size);
- #endif
- MS_LOG(ERROR) << "Please enable runtime convert.";
- return nullptr;
- }
- return lite::ReadFile(file, size);
- }
- int lite::LiteSession::CreateSessionByPath(const std::string &model_path, session::LiteSession *session) {
+ int lite::LiteSession::CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
+ session::LiteSession *session) {
size_t model_size;
- auto model_buf = LoadModelByPath(model_path.c_str(), &model_size);
+ auto model_buf = LoadModelByPath(model_path, model_type, &model_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "Read model file failed";
return RET_ERROR;
@@ -1360,6 +1408,7 @@ int lite::LiteSession::CreateSessionByPath(const std::string &model_path, sessio
MS_LOG(ERROR) << "Import model failed";
return RET_ERROR;
}
+ (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
auto ret = session->CompileGraph(model);
if (ret != lite::RET_OK) {

View File

@@ -53,8 +53,10 @@ class LiteSession : public session::LiteSession {
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
- static int CreateSessionByBuf(const char *model_buf, size_t size, session::LiteSession *session);
- static int CreateSessionByPath(const std::string &model_path, session::LiteSession *session);
+ static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
+ session::LiteSession *session);
+ static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
+ session::LiteSession *session);
virtual int Init(InnerContext *context);
@@ -126,7 +128,8 @@ class LiteSession : public session::LiteSession {
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
- static char *LoadModelByPath(const char *file, size_t *size);
+ static const char *LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size);
+ static char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
private:
int PreCheck(Model *model);
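
Taken together, the entry points declared here dispatch on the model type: kMindIR_Opt inputs are imported directly, while kMindIR inputs are schema-verified first and, only in builds with RUNTIME_CONVERT defined, handed to RuntimeConvert on SCHEMA_INVALID. A usage sketch of the new path-based API (session creation elided; file name illustrative):

    // session is a previously created session::LiteSession *.
    auto ret = lite::LiteSession::CreateSessionByPath(
        "./model.ms", mindspore::ModelType::kMindIR_Opt, session);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Create session by path failed";
    }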

View File

@@ -24,7 +24,16 @@
#include "tools/converter/import/mindspore_importer.h"
namespace mindspore::lite {
- char *RuntimeConvert(const char *file_path, size_t *size) {
+ char *RuntimeConvert(const char *model_buf, size_t) {
+ if (model_buf == nullptr) {
+ MS_LOG(ERROR) << "Invalid input model buffer.";
+ return nullptr;
+ }
+ MS_LOG(ERROR) << "Buffer conversion is not supported yet. Use a model path for the model build.";
+ return nullptr;
+ }
+ char *RuntimeConvert(const std::string &file_path, size_t *size) {
void *model_buf = nullptr;
converter::Flags flag;
flag.fmk = converter::kFmkTypeMs;

View File

@@ -19,9 +19,11 @@
#ifdef RUNTIME_CONVERT
#include <stdio.h>
+ #include <string>
namespace mindspore::lite {
- char *RuntimeConvert(const char *file_path, size_t *size);
+ char *RuntimeConvert(const char *model_buf, size_t size);
+ char *RuntimeConvert(const std::string &file_path, size_t *size);
} // namespace mindspore::lite
#endif // RUNTIME_CONVERT
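
Callers are expected to guard these calls with the same RUNTIME_CONVERT flag that guards the declarations, as LoadModelByPath above does. A minimal caller sketch (variable names illustrative); note that in this commit only the path overload actually converts, since the buffer overload always returns nullptr:

    #ifdef RUNTIME_CONVERT
      char *model_buf = RuntimeConvert(model_path, &model_size);  // path overload
    #else
      char *model_buf = nullptr;  // conversion unavailable in this build
    #endif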

View File

@@ -207,7 +207,7 @@ TEST_F(MixDataTypeTest, mix1) {
auto status = impl->LoadConfig("MixDataTypeTestConfig");
ASSERT_EQ(status, kSuccess);
- status = impl->Build(flat_model, size, kFlatBuffer, context);
+ status = impl->Build(flat_model, size, kMindIR_Opt, context);
ASSERT_EQ(status, kSuccess);
/* check */

View File

@@ -371,7 +371,7 @@ TEST_F(MultipleDeviceTest, NewApi1) {
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
mindspore::Model *model = new mindspore::Model();
- auto ret = model->Build(content, size, mindspore::kFlatBuffer, context);
+ auto ret = model->Build(content, size, mindspore::kMindIR_Opt, context);
ASSERT_EQ(false, ret.IsOk());
delete model;
@@ -436,7 +436,7 @@ TEST_F(MultipleDeviceTest, NewApi5) {
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
auto model_impl = std::make_shared<mindspore::ModelImpl>();
- auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
+ auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
@@ -481,7 +481,7 @@ TEST_F(MultipleDeviceTest, NewApi6) {
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
auto model_impl = std::make_shared<mindspore::ModelImpl>();
- auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
+ auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
@@ -525,7 +525,7 @@ TEST_F(MultipleDeviceTest, NewApi7) {
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
auto model_impl = std::make_shared<mindspore::ModelImpl>();
- auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
+ auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
@@ -550,7 +550,7 @@ TEST_F(MultipleDeviceTest, NewApi8) {
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::KirinNPUDeviceInfo>());
auto model_impl = std::make_shared<mindspore::ModelImpl>();
- auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
+ auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),

View File

@@ -203,7 +203,7 @@ TEST_F(TestRegistryCustomOp, TestCustomAdd) {
// build a model
auto model = std::make_shared<mindspore::Model>();
- auto ret = model->Build(content, size, kFlatBuffer, context);
+ auto ret = model->Build(content, size, kMindIR_Opt, context);
ASSERT_EQ(kSuccess, ret.StatusCode());
auto inputs = model->GetInputs();
ASSERT_EQ(inputs.size(), 2);

View File

@@ -498,7 +498,7 @@ TEST_F(TestGPURegistryCustomOp, TestGPUCustomAdd) {
// build a model
auto model = std::make_shared<mindspore::Model>();
- auto ret = model->Build(content, size, kFlatBuffer, context);
+ auto ret = model->Build(content, size, kMindIR_Opt, context);
ASSERT_EQ(kSuccess, ret.StatusCode());
auto inputs = model->GetInputs();
ASSERT_EQ(inputs.size(), 2);

View File

@@ -166,7 +166,7 @@ TEST_F(TestRegistry, TestAdd) {
// build a model
auto model = std::make_shared<mindspore::Model>();
- auto ret = model->Build(content, size, kFlatBuffer, context);
+ auto ret = model->Build(content, size, kMindIR_Opt, context);
ASSERT_EQ(kSuccess, ret.StatusCode());
auto inputs = model->GetInputs();
ASSERT_EQ(inputs.size(), 2);

View File

@@ -48,4 +48,42 @@ TEST_F(RuntimeConvert, relu1) {
ASSERT_LE(fp32_data[2], 3.0);
ASSERT_LE(fp32_data[3], 4.0);
}
+ TEST_F(RuntimeConvert, relu2) {
+ Model model;
+ auto context = std::make_shared<mindspore::Context>();
+ context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
+ Status build_ret = model.Build("./relu.mindir", mindspore::kMindIR_Opt, context);
+ ASSERT_NE(build_ret, Status::OK());
+ }
+ TEST_F(RuntimeConvert, relu3) {
+ size_t size;
+ char *mindir_buf = lite::ReadFile("./relu.mindir", &size);
+ ASSERT_NE(mindir_buf, nullptr);
+ Model model;
+ auto context = std::make_shared<mindspore::Context>();
+ context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
+ Status build_ret = model.Build(mindir_buf, size, mindspore::kMindIR, context);
+ ASSERT_EQ(build_ret, Status::OK());
+ auto inputs = model.GetInputs();
+ auto in = inputs[0];
+ std::vector<float> in_float = {1.0, 2.0, -3.0, -4.0};
+ memcpy(inputs[0].MutableData(), in_float.data(), in.DataSize());
+ auto outputs = model.GetOutputs();
+ auto predict_ret = model.Predict(inputs, &outputs);
+ ASSERT_EQ(predict_ret, Status::OK());
+ /* check output */
+ auto out = outputs[0];
+ void *out_data = out.MutableData();
+ float *fp32_data = reinterpret_cast<float *>(out_data);
+ ASSERT_LE(fp32_data[0], 1.0);
+ ASSERT_LE(fp32_data[1], 2.0);
+ ASSERT_LE(fp32_data[2], 3.0);
+ ASSERT_LE(fp32_data[3], 4.0);
+ }
} // namespace mindspore

View File

@@ -222,6 +222,33 @@ void BenchmarkFlags::InitResizeDimsList() {
}
}
+ int BenchmarkBase::CheckModelValid() {
+ this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
+ if (!flags_->benchmark_data_type_.empty()) {
+ if (data_type_map_.find(flags_->benchmark_data_type_) == data_type_map_.end()) {
+ MS_LOG(ERROR) << "CalibDataType not supported: " << flags_->benchmark_data_type_.c_str();
+ return RET_ERROR;
+ }
+ msCalibDataType = data_type_map_.at(flags_->benchmark_data_type_);
+ MS_LOG(INFO) << "CalibDataType = " << flags_->benchmark_data_type_.c_str();
+ std::cout << "CalibDataType = " << flags_->benchmark_data_type_.c_str() << std::endl;
+ }
+ if (flags_->model_file_.empty()) {
+ MS_LOG(ERROR) << "modelPath is required";
+ std::cerr << "modelPath is required" << std::endl;
+ return RET_ERROR;
+ }
+ if (ModelTypeMap.find(flags_->model_type_) == ModelTypeMap.end()) {
+ MS_LOG(ERROR) << "Invalid model type: " << flags_->model_type_;
+ std::cerr << "Invalid model type: " << flags_->model_type_ << std::endl;
+ return RET_ERROR;
+ }
+ return RET_OK;
+ }
int BenchmarkBase::CheckThreadNumValid() {
if (this->flags_->num_threads_ < kThreadNumMin) {
MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
@@ -333,6 +360,7 @@ int BenchmarkBase::Init() {
return 1;
}
MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
+ MS_LOG(INFO) << "ModelType = " << this->flags_->model_type_;
MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
MS_LOG(INFO) << "ConfigFilePath = " << this->flags_->config_file_;
MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
@@ -345,6 +373,7 @@ int BenchmarkBase::Init() {
MS_LOG(INFO) << "EnableParallel = " << this->flags_->enable_parallel_;
MS_LOG(INFO) << "calibDataPath = " << this->flags_->benchmark_data_file_;
std::cout << "ModelPath = " << this->flags_->model_file_ << std::endl;
+ std::cout << "ModelType = " << this->flags_->model_type_ << std::endl;
std::cout << "InDataPath = " << this->flags_->in_data_file_ << std::endl;
std::cout << "ConfigFilePath = " << this->flags_->config_file_ << std::endl;
std::cout << "InDataType = " << this->flags_->in_data_type_in_ << std::endl;
@@ -368,6 +397,7 @@ int BenchmarkBase::Init() {
std::cerr << "Invalid numThreads." << std::endl;
return RET_ERROR;
}
static std::vector<std::string> CPU_BIND_MODE_MAP = {"NO_BIND", "HIGHER_CPU", "MID_CPU"};
if (this->flags_->cpu_bind_mode_ >= 1) {
MS_LOG(INFO) << "cpuBindMode = " << CPU_BIND_MODE_MAP[this->flags_->cpu_bind_mode_];
@@ -377,23 +407,13 @@ int BenchmarkBase::Init() {
std::cout << "cpuBindMode = NO_BIND" << std::endl;
}
- this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
- if (!flags_->benchmark_data_type_.empty()) {
- if (data_type_map_.find(flags_->benchmark_data_type_) == data_type_map_.end()) {
- MS_LOG(ERROR) << "CalibDataType not supported: " << flags_->benchmark_data_type_.c_str();
- return RET_ERROR;
- }
- msCalibDataType = data_type_map_.at(flags_->benchmark_data_type_);
- MS_LOG(INFO) << "CalibDataType = " << flags_->benchmark_data_type_.c_str();
- std::cout << "CalibDataType = " << flags_->benchmark_data_type_.c_str() << std::endl;
- }
+ auto model_ret = CheckModelValid();
+ if (model_ret != RET_OK) {
+ MS_LOG(ERROR) << "Invalid Model File.";
+ std::cerr << "Invalid Model File." << std::endl;
+ return RET_ERROR;
+ }
- if (flags_->model_file_.empty()) {
- MS_LOG(ERROR) << "modelPath is required";
- std::cerr << "modelPath is required" << std::endl;
- return 1;
- }
flags_->InitInputDataList();
flags_->InitResizeDimsList();
if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&

View File

@@ -31,6 +31,7 @@
#include <utility>
#include <nlohmann/json.hpp>
#include "include/model.h"
+ #include "include/api/types.h"
#include "tools/common/flag_parser.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"
@@ -62,7 +63,9 @@ constexpr const char *DELIM_SLASH = "/";
extern const std::unordered_map<int, std::string> kTypeIdMap;
extern const std::unordered_map<schema::Format, std::string> kTensorFormatMap;
//
+ const std::unordered_map<std::string, mindspore::ModelType> ModelTypeMap{
+ {"MindIR_Opt", mindspore::ModelType::kMindIR_Opt}, {"MindIR", mindspore::ModelType::kMindIR}};
namespace dump {
constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
constexpr auto kSettings = "common_dump_settings";
@@ -103,11 +106,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
BenchmarkFlags() {
// common
AddFlag(&BenchmarkFlags::model_file_, "modelFile", "Input model file", "");
+ AddFlag(&BenchmarkFlags::model_type_, "modelType", "Input model type. MindIR | MindIR_Opt", "MindIR");
AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", "");
AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310", "CPU");
- AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode",
- "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU, default value: 1", 1);
+ AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1);
// MarkPerformance
AddFlag(&BenchmarkFlags::loop_count_, "loopCount", "Run loop count", 10);
AddFlag(&BenchmarkFlags::num_threads_, "numThreads", "Run threads number", 2);
@@ -138,6 +141,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
std::string model_file_;
std::string in_data_file_;
std::string config_file_;
+ std::string model_type_;
std::vector<std::string> input_data_list_;
InDataType in_data_type_ = kBinary;
std::string in_data_type_in_ = "bin";
@@ -298,6 +302,8 @@ class MS_API BenchmarkBase {
int CheckThreadNumValid();
+ int CheckModelValid();
int CheckDeviceTypeValid();
protected:

View File

@@ -445,6 +445,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
auto start_prepare_time = GetTimeUs();
// Load graph
std::string model_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
+ mindspore::ModelType model_type = ModelTypeMap.at(flags_->model_type_);
MS_LOG(INFO) << "start unified benchmark run";
std::cout << "start unified benchmark run" << std::endl;
@@ -466,7 +467,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
}
}
- auto ret = ms_model_.Build(flags_->model_file_, kMindIR, context);
+ auto ret = ms_model_.Build(flags_->model_file_, model_type, context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "ms_model_.Build failed while running " << model_name.c_str();
std::cout << "ms_model_.Build failed while running " << model_name.c_str() << std::endl;
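
Note the ordering this refactor sets up: CheckModelValid() (benchmark.cc above) rejects any modelType flag value not present in ModelTypeMap before RunBenchmark() reaches the unchecked ModelTypeMap.at() call, so the at() lookup cannot throw. The pattern, sketched:

    // Validate once with find(); after that, at() is safe.
    if (ModelTypeMap.find(flags_->model_type_) == ModelTypeMap.end()) {
      return RET_ERROR;  // reported by CheckModelValid()
    }
    mindspore::ModelType model_type = ModelTypeMap.at(flags_->model_type_);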

View File

@@ -39,7 +39,6 @@
#include "include/api/model.h"
namespace mindspore::lite {
class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
public:
explicit BenchmarkUnifiedApi(BenchmarkFlags *flags) : BenchmarkBase(flags) {}