modify return type of Model::Import from std::shared_ptr<Model> to Model
*
This commit is contained in:
parent
123e43cd02
commit
ca6c84b806
2
build.sh
2
build.sh
|
@ -397,7 +397,7 @@ checkndk() {
|
||||||
if [ "${ANDROID_NDK}" ]; then
|
if [ "${ANDROID_NDK}" ]; then
|
||||||
echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
|
echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
|
||||||
else
|
else
|
||||||
echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
|
echo -e "\e[31mplease set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ class MS_API Model {
|
||||||
/// \param[in] size Define bytes numbers of model buffer.
|
/// \param[in] size Define bytes numbers of model buffer.
|
||||||
///
|
///
|
||||||
/// \return Pointer of MindSpore Lite Model.
|
/// \return Pointer of MindSpore Lite Model.
|
||||||
static std::shared_ptr<Model> Import(const char *model_buf, size_t size);
|
static Model *Import(const char *model_buf, size_t size);
|
||||||
|
|
||||||
/// \brief Constructor of MindSpore Lite Model using default value for parameters.
|
/// \brief Constructor of MindSpore Lite Model using default value for parameters.
|
||||||
///
|
///
|
||||||
|
@ -53,7 +53,7 @@ class MS_API Model {
|
||||||
Model() = default;
|
Model() = default;
|
||||||
|
|
||||||
/// \brief Destructor of MindSpore Lite Model.
|
/// \brief Destructor of MindSpore Lite Model.
|
||||||
virtual ~Model() = default;
|
virtual ~Model();
|
||||||
|
|
||||||
/// \brief Get MindSpore Lite Primitive by name.
|
/// \brief Get MindSpore Lite Primitive by name.
|
||||||
///
|
///
|
||||||
|
@ -70,13 +70,13 @@ class MS_API Model {
|
||||||
/// \brief Get MindSpore Lite ModelImpl.
|
/// \brief Get MindSpore Lite ModelImpl.
|
||||||
///
|
///
|
||||||
/// \return A pointer of MindSpore Lite ModelImpl.
|
/// \return A pointer of MindSpore Lite ModelImpl.
|
||||||
std::shared_ptr<ModelImpl> model_impl();
|
ModelImpl *model_impl();
|
||||||
|
|
||||||
/// \brief Free MetaGraph in MindSpore Lite Model.
|
/// \brief Free MetaGraph in MindSpore Lite Model.
|
||||||
void FreeMetaGraph();
|
void FreeMetaGraph();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<ModelImpl> model_impl_ = nullptr;
|
ModelImpl *model_impl_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief ModelBuilder defined by MindSpore Lite.
|
/// \brief ModelBuilder defined by MindSpore Lite.
|
||||||
|
|
|
@ -24,12 +24,16 @@
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
|
||||||
std::shared_ptr<Model> Model::Import(const char *model_buf, size_t size) {
|
Model *Model::Import(const char *model_buf, size_t size) {
|
||||||
auto model = std::make_shared<Model>();
|
auto model = new Model();
|
||||||
model->model_impl_ = ModelImpl::Import(model_buf, size);
|
model->model_impl_ = ModelImpl::Import(model_buf, size);
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Model::~Model() {
|
||||||
|
delete(this->model_impl_);
|
||||||
|
}
|
||||||
|
|
||||||
lite::Primitive *Model::GetOp(const std::string &name) const {
|
lite::Primitive *Model::GetOp(const std::string &name) const {
|
||||||
MS_EXCEPTION_IF_NULL(model_impl_);
|
MS_EXCEPTION_IF_NULL(model_impl_);
|
||||||
return const_cast<Primitive *>(model_impl_->GetOp(name));
|
return const_cast<Primitive *>(model_impl_->GetOp(name));
|
||||||
|
@ -45,9 +49,8 @@ const schema::MetaGraph *Model::GetMetaGraph() const {
|
||||||
return model_impl_->GetMetaGraph();
|
return model_impl_->GetMetaGraph();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ModelImpl> Model::model_impl() {
|
ModelImpl *Model::model_impl() {
|
||||||
MS_EXCEPTION_IF_NULL(model_impl_);
|
MS_EXCEPTION_IF_NULL(model_impl_);
|
||||||
return this->model_impl_;
|
return this->model_impl_;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
|
ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
|
||||||
MS_EXCEPTION_IF_NULL(model_buf);
|
MS_EXCEPTION_IF_NULL(model_buf);
|
||||||
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
|
||||||
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
if (!schema::VerifyMetaGraphBuffer(verify)) {
|
||||||
|
@ -33,7 +33,7 @@ std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
memcpy(inner_model_buf, model_buf, size);
|
memcpy(inner_model_buf, model_buf, size);
|
||||||
auto model = std::make_shared<ModelImpl>(inner_model_buf, size);
|
auto model = new (std::nothrow) ModelImpl(inner_model_buf, size);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create modelImpl failed";
|
MS_LOG(ERROR) << "Create modelImpl failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
class ModelImpl {
|
class ModelImpl {
|
||||||
public:
|
public:
|
||||||
static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size);
|
static ModelImpl *Import(const char *model_buf, size_t size);
|
||||||
ModelImpl() = default;
|
ModelImpl() = default;
|
||||||
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
|
explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
|
||||||
meta_graph = schema::GetMetaGraph(model_buf);
|
meta_graph = schema::GetMetaGraph(model_buf);
|
||||||
|
|
|
@ -109,7 +109,7 @@ TEST_F(InferTest, TestConvNode) {
|
||||||
context->thread_num_ = 4;
|
context->thread_num_ = 4;
|
||||||
auto session = session::LiteSession::CreateSession(context);
|
auto session = session::LiteSession::CreateSession(context);
|
||||||
ASSERT_NE(nullptr, session);
|
ASSERT_NE(nullptr, session);
|
||||||
auto ret = session->CompileGraph(model.get());
|
auto ret = session->CompileGraph(model);
|
||||||
ASSERT_EQ(lite::RET_OK, ret);
|
ASSERT_EQ(lite::RET_OK, ret);
|
||||||
auto inputs = session->GetInputs();
|
auto inputs = session->GetInputs();
|
||||||
ASSERT_EQ(inputs.size(), 1);
|
ASSERT_EQ(inputs.size(), 1);
|
||||||
|
@ -206,7 +206,7 @@ TEST_F(InferTest, TestAddNode) {
|
||||||
context->thread_num_ = 4;
|
context->thread_num_ = 4;
|
||||||
auto session = session::LiteSession::CreateSession(context);
|
auto session = session::LiteSession::CreateSession(context);
|
||||||
ASSERT_NE(nullptr, session);
|
ASSERT_NE(nullptr, session);
|
||||||
auto ret = session->CompileGraph(model.get());
|
auto ret = session->CompileGraph(model);
|
||||||
ASSERT_EQ(lite::RET_OK, ret);
|
ASSERT_EQ(lite::RET_OK, ret);
|
||||||
auto inputs = session->GetInputs();
|
auto inputs = session->GetInputs();
|
||||||
ASSERT_EQ(inputs.size(), 2);
|
ASSERT_EQ(inputs.size(), 2);
|
||||||
|
@ -257,7 +257,7 @@ TEST_F(InferTest, TestModel) {
|
||||||
context->thread_num_ = 4;
|
context->thread_num_ = 4;
|
||||||
auto session = session::LiteSession::CreateSession(context);
|
auto session = session::LiteSession::CreateSession(context);
|
||||||
ASSERT_NE(nullptr, session);
|
ASSERT_NE(nullptr, session);
|
||||||
auto ret = session->CompileGraph(model.get());
|
auto ret = session->CompileGraph(model);
|
||||||
ASSERT_EQ(lite::RET_OK, ret);
|
ASSERT_EQ(lite::RET_OK, ret);
|
||||||
auto inputs = session->GetInputs();
|
auto inputs = session->GetInputs();
|
||||||
ASSERT_EQ(inputs.size(), 1);
|
ASSERT_EQ(inputs.size(), 1);
|
||||||
|
|
|
@ -371,7 +371,7 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
delete[](graphBuf);
|
delete[](graphBuf);
|
||||||
auto context = new(std::nothrow) lite::Context;
|
auto context = new (std::nothrow) lite::Context;
|
||||||
if (context == nullptr) {
|
if (context == nullptr) {
|
||||||
MS_LOG(ERROR) << "New context failed while running %s", modelName.c_str();
|
MS_LOG(ERROR) << "New context failed while running %s", modelName.c_str();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -393,15 +393,16 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
|
||||||
}
|
}
|
||||||
context->thread_num_ = _flags->numThreads;
|
context->thread_num_ = _flags->numThreads;
|
||||||
session = session::LiteSession::CreateSession(context);
|
session = session::LiteSession::CreateSession(context);
|
||||||
delete(context);
|
delete (context);
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
MS_LOG(ERROR) << "CreateSession failed while running %s", modelName.c_str();
|
MS_LOG(ERROR) << "CreateSession failed while running %s", modelName.c_str();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto ret = session->CompileGraph(model.get());
|
auto ret = session->CompileGraph(model);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "CompileGraph failed while running %s", modelName.c_str();
|
MS_LOG(ERROR) << "CompileGraph failed while running %s", modelName.c_str();
|
||||||
delete(session);
|
delete (session);
|
||||||
|
delete (model);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
msInputs = session->GetInputs();
|
msInputs = session->GetInputs();
|
||||||
|
@ -419,21 +420,24 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
|
||||||
auto status = LoadInput();
|
auto status = LoadInput();
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
MS_LOG(ERROR) << "Generate input data error";
|
MS_LOG(ERROR) << "Generate input data error";
|
||||||
delete(session);
|
delete (session);
|
||||||
|
delete (model);
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
if (!_flags->calibDataPath.empty()) {
|
if (!_flags->calibDataPath.empty()) {
|
||||||
status = MarkAccuracy();
|
status = MarkAccuracy();
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status;
|
MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status;
|
||||||
delete(session);
|
delete (session);
|
||||||
|
delete (model);
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
status = MarkPerformance();
|
status = MarkPerformance();
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status;
|
MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status;
|
||||||
delete(session);
|
delete (session);
|
||||||
|
delete (model);
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -447,7 +451,8 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
|
||||||
calibData.clear();
|
calibData.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(session);
|
delete (session);
|
||||||
|
delete (model);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -920,7 +920,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = session_->CompileGraph(model.get());
|
auto ret = session_->CompileGraph(model);
|
||||||
if (ret != lite::RET_OK) {
|
if (ret != lite::RET_OK) {
|
||||||
MS_LOG(ERROR) << "compile graph error";
|
MS_LOG(ERROR) << "compile graph error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -278,7 +278,7 @@ int TimeProfile::RunTimeProfile() {
|
||||||
}
|
}
|
||||||
auto model = lite::Model::Import(graphBuf, size);
|
auto model = lite::Model::Import(graphBuf, size);
|
||||||
|
|
||||||
auto ret = session_->CompileGraph(model.get());
|
auto ret = session_->CompileGraph(model);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Compile graph failed.";
|
MS_LOG(ERROR) << "Compile graph failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -336,6 +336,8 @@ int TimeProfile::RunTimeProfile() {
|
||||||
}
|
}
|
||||||
ms_inputs_.clear();
|
ms_inputs_.clear();
|
||||||
delete graphBuf;
|
delete graphBuf;
|
||||||
|
delete session_;
|
||||||
|
delete model;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue