forked from mindspore-Ecosystem/mindspore
[MSLITE] model Type
This commit is contained in:
parent
0f143d65b2
commit
aaca4be28b
|
@ -40,6 +40,7 @@ enum ModelType : uint32_t {
|
|||
kOM = 2,
|
||||
kONNX = 3,
|
||||
kFlatBuffer = 4,
|
||||
kMindIR_Opt = 5,
|
||||
// insert new data type here
|
||||
kUnknownType = 0xFFFFFFFF
|
||||
};
|
||||
|
|
|
@ -71,7 +71,8 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
|
|||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
auto ret = lite::LiteSession::CreateSessionByBuf(static_cast<const char *>(model_data), data_size, session.get());
|
||||
auto ret =
|
||||
lite::LiteSession::CreateSessionByBuf(static_cast<const char *>(model_data), model_type, data_size, session.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init session failed";
|
||||
return kLiteError;
|
||||
|
@ -99,7 +100,7 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
|
|||
return kLiteNullptr;
|
||||
}
|
||||
|
||||
auto ret = lite::LiteSession::CreateSessionByPath(model_path, session.get());
|
||||
auto ret = lite::LiteSession::CreateSessionByPath(model_path, model_type, session.get());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init session failed";
|
||||
return kLiteError;
|
||||
|
|
|
@ -192,7 +192,7 @@ int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
|
||||
int LiteModel::VersionVerify(flatbuffers::Verifier *verify) {
|
||||
if (verify == nullptr) {
|
||||
MS_LOG(ERROR) << "verify is null.";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -59,6 +59,8 @@ class LiteModel : public Model {
|
|||
|
||||
SchemaTensorWrapper *GetSchemaTensor(const size_t &tensor_index) const;
|
||||
|
||||
static int VersionVerify(flatbuffers::Verifier *verify);
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_V0
|
||||
int ConvertAttrs(Model::Node *node, std::vector<schema::Tensor *> *dst_tensor);
|
||||
|
@ -265,8 +267,6 @@ class LiteModel : public Model {
|
|||
void SetNodeDeviceType(Node *node, const schema::v0::CNode &c_node) { node->device_type_ = -1; }
|
||||
#endif
|
||||
|
||||
int VersionVerify(flatbuffers::Verifier *verify) const;
|
||||
|
||||
int GenerateModelByVersion();
|
||||
|
||||
int ConvertSubGraph(const schema::SubGraph &sub_graph);
|
||||
|
|
|
@ -1294,7 +1294,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
|
|||
MS_LOG(ERROR) << "Create session failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = lite::LiteSession::CreateSessionByBuf(model_buf, size, session);
|
||||
auto ret = lite::LiteSession::CreateSessionByBuf(model_buf, mindspore::ModelType::kMindIR_Opt, size, session);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init session failed";
|
||||
delete session;
|
||||
|
@ -1309,7 +1309,7 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
|
|||
MS_LOG(ERROR) << "Create session failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = lite::LiteSession::CreateSessionByPath(model_path, session);
|
||||
auto ret = lite::LiteSession::CreateSessionByPath(model_path, mindspore::ModelType::kMindIR_Opt, session);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init session failed";
|
||||
delete session;
|
||||
|
@ -1318,7 +1318,66 @@ session::LiteSession *lite::LiteSession::CreateSession(const std::string &model_
|
|||
return session;
|
||||
}
|
||||
|
||||
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, size_t size, session::LiteSession *session) {
|
||||
const char *lite::LiteSession::LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size) {
|
||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
||||
return model_buf;
|
||||
}
|
||||
|
||||
if (model_type == mindspore::ModelType::kMindIR) {
|
||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||
if (version_verify != SCHEMA_INVALID) {
|
||||
MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
|
||||
return model_buf;
|
||||
}
|
||||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
return RuntimeConvert(model_buf, size);
|
||||
#endif
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Invalid Model Type";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size) {
|
||||
if (model_type == mindspore::ModelType::kMindIR_Opt) {
|
||||
return lite::ReadFile(file.c_str(), size);
|
||||
}
|
||||
|
||||
if (model_type == mindspore::ModelType::kMindIR) {
|
||||
size_t tmp_size;
|
||||
auto tmp_buf = lite::ReadFile(file.c_str(), &tmp_size);
|
||||
flatbuffers::Verifier verify((const uint8_t *)tmp_buf, tmp_size);
|
||||
auto version_verify = lite::LiteModel::VersionVerify(&verify);
|
||||
if (version_verify != SCHEMA_INVALID) {
|
||||
MS_LOG(DEBUG) << "The kMindIR type model path is valid mslite model";
|
||||
return tmp_buf;
|
||||
}
|
||||
free(tmp_buf);
|
||||
tmp_buf = nullptr;
|
||||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
return RuntimeConvert(file, size);
|
||||
#endif
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Invalid Model Type";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
||||
session::LiteSession *session) {
|
||||
auto lite_buf = LoadModelByBuff(model_buf, model_type, size);
|
||||
if (lite_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid model_buf";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto *model = lite::ImportFromBuffer(model_buf, size, true);
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Import model failed";
|
||||
|
@ -1336,21 +1395,10 @@ int lite::LiteSession::CreateSessionByBuf(const char *model_buf, size_t size, se
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
char *lite::LiteSession::LoadModelByPath(const char *file, size_t *size) {
|
||||
if (IsCharEndWith(file, MINDIR_POSTFIX)) {
|
||||
#ifdef RUNTIME_CONVERT
|
||||
return RuntimeConvert(file, size);
|
||||
#endif
|
||||
MS_LOG(ERROR) << "Please enable runtime convert.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return lite::ReadFile(file, size);
|
||||
}
|
||||
|
||||
int lite::LiteSession::CreateSessionByPath(const std::string &model_path, session::LiteSession *session) {
|
||||
int lite::LiteSession::CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
|
||||
session::LiteSession *session) {
|
||||
size_t model_size;
|
||||
auto model_buf = LoadModelByPath(model_path.c_str(), &model_size);
|
||||
auto model_buf = LoadModelByPath(model_path, model_type, &model_size);
|
||||
if (model_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read model file failed";
|
||||
return RET_ERROR;
|
||||
|
@ -1360,6 +1408,7 @@ int lite::LiteSession::CreateSessionByPath(const std::string &model_path, sessio
|
|||
MS_LOG(ERROR) << "Import model failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
(reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
|
||||
auto ret = session->CompileGraph(model);
|
||||
if (ret != lite::RET_OK) {
|
||||
|
|
|
@ -53,8 +53,10 @@ class LiteSession : public session::LiteSession {
|
|||
|
||||
static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
|
||||
|
||||
static int CreateSessionByBuf(const char *model_buf, size_t size, session::LiteSession *session);
|
||||
static int CreateSessionByPath(const std::string &model_path, session::LiteSession *session);
|
||||
static int CreateSessionByBuf(const char *model_buf, mindspore::ModelType model_type, size_t size,
|
||||
session::LiteSession *session);
|
||||
static int CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
|
||||
session::LiteSession *session);
|
||||
|
||||
virtual int Init(InnerContext *context);
|
||||
|
||||
|
@ -126,7 +128,8 @@ class LiteSession : public session::LiteSession {
|
|||
|
||||
static void FreePackOpWeight(const std::vector<kernel::LiteKernel *> &kernels);
|
||||
|
||||
static char *LoadModelByPath(const char *file, size_t *size);
|
||||
static const char *LoadModelByBuff(const char *model_buf, mindspore::ModelType model_type, size_t size);
|
||||
static char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
|
||||
|
||||
private:
|
||||
int PreCheck(Model *model);
|
||||
|
|
|
@ -24,7 +24,16 @@
|
|||
#include "tools/converter/import/mindspore_importer.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
char *RuntimeConvert(const char *file_path, size_t *size) {
|
||||
char *RuntimeConvert(const char *model_buf, size_t) {
|
||||
if (model_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid input model buffer.";
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(ERROR) << "Invalid Now. Use model path for model build.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
char *RuntimeConvert(const std::string &file_path, size_t *size) {
|
||||
void *model_buf = nullptr;
|
||||
converter::Flags flag;
|
||||
flag.fmk = converter::kFmkTypeMs;
|
||||
|
|
|
@ -19,9 +19,11 @@
|
|||
|
||||
#ifdef RUNTIME_CONVERT
|
||||
#include <stdio.h>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore::lite {
|
||||
char *RuntimeConvert(const char *file_path, size_t *size);
|
||||
char *RuntimeConvert(const char *model_buf, size_t size);
|
||||
char *RuntimeConvert(const std::string &file_path, size_t *size);
|
||||
} // namespace mindspore::lite
|
||||
#endif // RUNTIME_CONVERT
|
||||
|
||||
|
|
|
@ -207,7 +207,7 @@ TEST_F(MixDataTypeTest, mix1) {
|
|||
auto status = impl->LoadConfig("MixDataTypeTestConfig");
|
||||
ASSERT_EQ(status, kSuccess);
|
||||
|
||||
status = impl->Build(flat_model, size, kFlatBuffer, context);
|
||||
status = impl->Build(flat_model, size, kMindIR_Opt, context);
|
||||
ASSERT_EQ(status, kSuccess);
|
||||
|
||||
/* check */
|
||||
|
|
|
@ -371,7 +371,7 @@ TEST_F(MultipleDeviceTest, NewApi1) {
|
|||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
|
||||
|
||||
mindspore::Model *model = new mindspore::Model();
|
||||
auto ret = model->Build(content, size, mindspore::kFlatBuffer, context);
|
||||
auto ret = model->Build(content, size, mindspore::kMindIR_Opt, context);
|
||||
ASSERT_EQ(false, ret.IsOk());
|
||||
|
||||
delete model;
|
||||
|
@ -436,7 +436,7 @@ TEST_F(MultipleDeviceTest, NewApi5) {
|
|||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
|
||||
|
||||
auto model_impl = std::make_shared<mindspore::ModelImpl>();
|
||||
auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
|
||||
auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
|
||||
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
|
||||
|
||||
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
|
||||
|
@ -481,7 +481,7 @@ TEST_F(MultipleDeviceTest, NewApi6) {
|
|||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
|
||||
|
||||
auto model_impl = std::make_shared<mindspore::ModelImpl>();
|
||||
auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
|
||||
auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
|
||||
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
|
||||
|
||||
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
|
||||
|
@ -525,7 +525,7 @@ TEST_F(MultipleDeviceTest, NewApi7) {
|
|||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::GPUDeviceInfo>());
|
||||
|
||||
auto model_impl = std::make_shared<mindspore::ModelImpl>();
|
||||
auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
|
||||
auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
|
||||
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
|
||||
|
||||
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
|
||||
|
@ -550,7 +550,7 @@ TEST_F(MultipleDeviceTest, NewApi8) {
|
|||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::KirinNPUDeviceInfo>());
|
||||
|
||||
auto model_impl = std::make_shared<mindspore::ModelImpl>();
|
||||
auto ret = model_impl->Build(content, size, mindspore::kFlatBuffer, context);
|
||||
auto ret = model_impl->Build(content, size, mindspore::kMindIR_Opt, context);
|
||||
ASSERT_EQ(mindspore::kSuccess, ret.StatusCode());
|
||||
|
||||
CheckResult(reinterpret_cast<const mindspore::lite::LiteSession *>(model_impl->GetSession())->get_kernels(),
|
||||
|
|
|
@ -203,7 +203,7 @@ TEST_F(TestRegistryCustomOp, TestCustomAdd) {
|
|||
|
||||
// build a model
|
||||
auto model = std::make_shared<mindspore::Model>();
|
||||
auto ret = model->Build(content, size, kFlatBuffer, context);
|
||||
auto ret = model->Build(content, size, kMindIR_Opt, context);
|
||||
ASSERT_EQ(kSuccess, ret.StatusCode());
|
||||
auto inputs = model->GetInputs();
|
||||
ASSERT_EQ(inputs.size(), 2);
|
||||
|
|
|
@ -498,7 +498,7 @@ TEST_F(TestGPURegistryCustomOp, TestGPUCustomAdd) {
|
|||
|
||||
// build a model
|
||||
auto model = std::make_shared<mindspore::Model>();
|
||||
auto ret = model->Build(content, size, kFlatBuffer, context);
|
||||
auto ret = model->Build(content, size, kMindIR_Opt, context);
|
||||
ASSERT_EQ(kSuccess, ret.StatusCode());
|
||||
auto inputs = model->GetInputs();
|
||||
ASSERT_EQ(inputs.size(), 2);
|
||||
|
|
|
@ -166,7 +166,7 @@ TEST_F(TestRegistry, TestAdd) {
|
|||
|
||||
// build a model
|
||||
auto model = std::make_shared<mindspore::Model>();
|
||||
auto ret = model->Build(content, size, kFlatBuffer, context);
|
||||
auto ret = model->Build(content, size, kMindIR_Opt, context);
|
||||
ASSERT_EQ(kSuccess, ret.StatusCode());
|
||||
auto inputs = model->GetInputs();
|
||||
ASSERT_EQ(inputs.size(), 2);
|
||||
|
|
|
@ -48,4 +48,42 @@ TEST_F(RuntimeConvert, relu1) {
|
|||
ASSERT_LE(fp32_data[2], 3.0);
|
||||
ASSERT_LE(fp32_data[3], 4.0);
|
||||
}
|
||||
|
||||
TEST_F(RuntimeConvert, relu2) {
|
||||
Model model;
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
|
||||
Status build_ret = model.Build("./relu.mindir", mindspore::kMindIR_Opt, context);
|
||||
ASSERT_NE(build_ret, Status::OK());
|
||||
}
|
||||
|
||||
TEST_F(RuntimeConvert, relu3) {
|
||||
size_t size;
|
||||
char *mindir_buf = lite::ReadFile("./relu.mindir", &size);
|
||||
ASSERT_NE(mindir_buf, nullptr);
|
||||
|
||||
Model model;
|
||||
auto context = std::make_shared<mindspore::Context>();
|
||||
context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
|
||||
Status build_ret = model.Build(mindir_buf, size, mindspore::kMindIR, context);
|
||||
ASSERT_EQ(build_ret, Status::OK());
|
||||
|
||||
auto inputs = model.GetInputs();
|
||||
auto in = inputs[0];
|
||||
std::vector<float> in_float = {1.0, 2.0, -3.0, -4.0};
|
||||
memcpy(inputs[0].MutableData(), in_float.data(), in.DataSize());
|
||||
auto outputs = model.GetOutputs();
|
||||
|
||||
auto predict_ret = model.Predict(inputs, &outputs);
|
||||
ASSERT_EQ(predict_ret, Status::OK());
|
||||
|
||||
/* checkout output */
|
||||
auto out = outputs[0];
|
||||
void *out_data = out.MutableData();
|
||||
float *fp32_data = reinterpret_cast<float *>(out_data);
|
||||
ASSERT_LE(fp32_data[0], 1.0);
|
||||
ASSERT_LE(fp32_data[1], 2.0);
|
||||
ASSERT_LE(fp32_data[2], 3.0);
|
||||
ASSERT_LE(fp32_data[3], 4.0);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -222,6 +222,33 @@ void BenchmarkFlags::InitResizeDimsList() {
|
|||
}
|
||||
}
|
||||
|
||||
int BenchmarkBase::CheckModelValid() {
|
||||
this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
|
||||
|
||||
if (!flags_->benchmark_data_type_.empty()) {
|
||||
if (data_type_map_.find(flags_->benchmark_data_type_) == data_type_map_.end()) {
|
||||
MS_LOG(ERROR) << "CalibDataType not supported: " << flags_->benchmark_data_type_.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
msCalibDataType = data_type_map_.at(flags_->benchmark_data_type_);
|
||||
MS_LOG(INFO) << "CalibDataType = " << flags_->benchmark_data_type_.c_str();
|
||||
std::cout << "CalibDataType = " << flags_->benchmark_data_type_.c_str() << std::endl;
|
||||
}
|
||||
|
||||
if (flags_->model_file_.empty()) {
|
||||
MS_LOG(ERROR) << "modelPath is required";
|
||||
std::cerr << "modelPath is required" << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (ModelTypeMap.find(flags_->model_type_) == ModelTypeMap.end()) {
|
||||
MS_LOG(ERROR) << "Invalid model type: " << flags_->model_type_;
|
||||
std::cerr << "Invalid model type: " << flags_->model_type_ << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int BenchmarkBase::CheckThreadNumValid() {
|
||||
if (this->flags_->num_threads_ < kThreadNumMin) {
|
||||
MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
|
||||
|
@ -333,6 +360,7 @@ int BenchmarkBase::Init() {
|
|||
return 1;
|
||||
}
|
||||
MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
|
||||
MS_LOG(INFO) << "ModelType = " << this->flags_->model_type_;
|
||||
MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
|
||||
MS_LOG(INFO) << "ConfigFilePath = " << this->flags_->config_file_;
|
||||
MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
|
||||
|
@ -345,6 +373,7 @@ int BenchmarkBase::Init() {
|
|||
MS_LOG(INFO) << "EnableParallel = " << this->flags_->enable_parallel_;
|
||||
MS_LOG(INFO) << "calibDataPath = " << this->flags_->benchmark_data_file_;
|
||||
std::cout << "ModelPath = " << this->flags_->model_file_ << std::endl;
|
||||
std::cout << "ModelType = " << this->flags_->model_type_ << std::endl;
|
||||
std::cout << "InDataPath = " << this->flags_->in_data_file_ << std::endl;
|
||||
std::cout << "ConfigFilePath = " << this->flags_->config_file_ << std::endl;
|
||||
std::cout << "InDataType = " << this->flags_->in_data_type_in_ << std::endl;
|
||||
|
@ -368,6 +397,7 @@ int BenchmarkBase::Init() {
|
|||
std::cerr << "Invalid numThreads." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
static std::vector<std::string> CPU_BIND_MODE_MAP = {"NO_BIND", "HIGHER_CPU", "MID_CPU"};
|
||||
if (this->flags_->cpu_bind_mode_ >= 1) {
|
||||
MS_LOG(INFO) << "cpuBindMode = " << CPU_BIND_MODE_MAP[this->flags_->cpu_bind_mode_];
|
||||
|
@ -377,23 +407,13 @@ int BenchmarkBase::Init() {
|
|||
std::cout << "cpuBindMode = NO_BIND" << std::endl;
|
||||
}
|
||||
|
||||
this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
|
||||
|
||||
if (!flags_->benchmark_data_type_.empty()) {
|
||||
if (data_type_map_.find(flags_->benchmark_data_type_) == data_type_map_.end()) {
|
||||
MS_LOG(ERROR) << "CalibDataType not supported: " << flags_->benchmark_data_type_.c_str();
|
||||
return RET_ERROR;
|
||||
}
|
||||
msCalibDataType = data_type_map_.at(flags_->benchmark_data_type_);
|
||||
MS_LOG(INFO) << "CalibDataType = " << flags_->benchmark_data_type_.c_str();
|
||||
std::cout << "CalibDataType = " << flags_->benchmark_data_type_.c_str() << std::endl;
|
||||
auto model_ret = CheckModelValid();
|
||||
if (model_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Invalid Model File.";
|
||||
std::cerr << "Invalid Model File." << std::endl;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
if (flags_->model_file_.empty()) {
|
||||
MS_LOG(ERROR) << "modelPath is required";
|
||||
std::cerr << "modelPath is required" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
flags_->InitInputDataList();
|
||||
flags_->InitResizeDimsList();
|
||||
if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include <utility>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include "include/model.h"
|
||||
#include "include/api/types.h"
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
|
@ -62,7 +63,9 @@ constexpr const char *DELIM_SLASH = "/";
|
|||
extern const std::unordered_map<int, std::string> kTypeIdMap;
|
||||
extern const std::unordered_map<schema::Format, std::string> kTensorFormatMap;
|
||||
|
||||
//
|
||||
const std::unordered_map<std::string, mindspore::ModelType> ModelTypeMap{
|
||||
{"MindIR_Opt", mindspore::ModelType::kMindIR_Opt}, {"MindIR", mindspore::ModelType::kMindIR}};
|
||||
|
||||
namespace dump {
|
||||
constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
|
||||
constexpr auto kSettings = "common_dump_settings";
|
||||
|
@ -103,11 +106,11 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
BenchmarkFlags() {
|
||||
// common
|
||||
AddFlag(&BenchmarkFlags::model_file_, "modelFile", "Input model file", "");
|
||||
AddFlag(&BenchmarkFlags::model_type_, "modelType", "Input model type. MindIR | MindIR_Opt", "MindIR");
|
||||
AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
|
||||
AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", "");
|
||||
AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310", "CPU");
|
||||
AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode",
|
||||
"Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU, default value: 1", 1);
|
||||
AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1);
|
||||
// MarkPerformance
|
||||
AddFlag(&BenchmarkFlags::loop_count_, "loopCount", "Run loop count", 10);
|
||||
AddFlag(&BenchmarkFlags::num_threads_, "numThreads", "Run threads number", 2);
|
||||
|
@ -138,6 +141,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
std::string model_file_;
|
||||
std::string in_data_file_;
|
||||
std::string config_file_;
|
||||
std::string model_type_;
|
||||
std::vector<std::string> input_data_list_;
|
||||
InDataType in_data_type_ = kBinary;
|
||||
std::string in_data_type_in_ = "bin";
|
||||
|
@ -298,6 +302,8 @@ class MS_API BenchmarkBase {
|
|||
|
||||
int CheckThreadNumValid();
|
||||
|
||||
int CheckModelValid();
|
||||
|
||||
int CheckDeviceTypeValid();
|
||||
|
||||
protected:
|
||||
|
|
|
@ -445,6 +445,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
|
|||
auto start_prepare_time = GetTimeUs();
|
||||
// Load graph
|
||||
std::string model_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
|
||||
mindspore::ModelType model_type = ModelTypeMap.at(flags_->model_type_);
|
||||
|
||||
MS_LOG(INFO) << "start unified benchmark run";
|
||||
std::cout << "start unified benchmark run" << std::endl;
|
||||
|
@ -466,7 +467,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
|
|||
}
|
||||
}
|
||||
|
||||
auto ret = ms_model_.Build(flags_->model_file_, kMindIR, context);
|
||||
auto ret = ms_model_.Build(flags_->model_file_, model_type, context);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "ms_model_.Build failed while running ", model_name.c_str();
|
||||
std::cout << "ms_model_.Build failed while running ", model_name.c_str();
|
||||
|
|
|
@ -39,7 +39,6 @@
|
|||
#include "include/api/model.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
||||
class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
|
||||
public:
|
||||
explicit BenchmarkUnifiedApi(BenchmarkFlags *flags) : BenchmarkBase(flags) {}
|
||||
|
|
Loading…
Reference in New Issue