modify return type of Model::Import from std::shared_ptr<Model> to Model *

hangq 2020-08-08 15:29:42 +08:00
parent 123e43cd02
commit ca6c84b806
9 changed files with 35 additions and 25 deletions

View File

@@ -397,7 +397,7 @@ checkndk() {
   if [ "${ANDROID_NDK}" ]; then
     echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
   else
-    echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
+    echo -e "\e[31mplease set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
     exit 1
   fi
 }

View File

@@ -45,7 +45,7 @@ class MS_API Model {
   /// \param[in] size Define bytes numbers of model buffer.
   ///
   /// \return Pointer of MindSpore Lite Model.
-  static std::shared_ptr<Model> Import(const char *model_buf, size_t size);
+  static Model *Import(const char *model_buf, size_t size);

   /// \brief Constructor of MindSpore Lite Model using default value for parameters.
   ///
@@ -53,7 +53,7 @@ class MS_API Model {
   Model() = default;

   /// \brief Destructor of MindSpore Lite Model.
-  virtual ~Model() = default;
+  virtual ~Model();

   /// \brief Get MindSpore Lite Primitive by name.
   ///
@@ -70,13 +70,13 @@ class MS_API Model {
   /// \brief Get MindSpore Lite ModelImpl.
   ///
   /// \return A pointer of MindSpore Lite ModelImpl.
-  std::shared_ptr<ModelImpl> model_impl();
+  ModelImpl *model_impl();

   /// \brief Free MetaGraph in MindSpore Lite Model.
   void FreeMetaGraph();

  protected:
-  std::shared_ptr<ModelImpl> model_impl_ = nullptr;
+  ModelImpl *model_impl_ = nullptr;
 };

 /// \brief ModelBuilder defined by MindSpore Lite.
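
With this change, Import hands back a raw Model * that the caller owns; nothing frees it automatically anymore. A minimal caller-side sketch of the new contract (the buffer loading and session code around it are assumed, not part of this commit):

    // model_buf/buf_size are assumed to hold a serialized model read elsewhere.
    mindspore::lite::Model *model = mindspore::lite::Model::Import(model_buf, buf_size);
    // ... create a LiteSession and call session->CompileGraph(model) instead of
    // the previous session->CompileGraph(model.get()) ...
    delete model;  // caller must free the Model; ~Model() releases the ModelImpl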

View File

@@ -24,12 +24,16 @@
 namespace mindspore::lite {

-std::shared_ptr<Model> Model::Import(const char *model_buf, size_t size) {
-  auto model = std::make_shared<Model>();
+Model *Model::Import(const char *model_buf, size_t size) {
+  auto model = new Model();
   model->model_impl_ = ModelImpl::Import(model_buf, size);
   return model;
 }

+Model::~Model() {
+  delete(this->model_impl_);
+}
+
 lite::Primitive *Model::GetOp(const std::string &name) const {
   MS_EXCEPTION_IF_NULL(model_impl_);
   return const_cast<Primitive *>(model_impl_->GetOp(name));
@@ -45,9 +49,8 @@ const schema::MetaGraph *Model::GetMetaGraph() const {
   return model_impl_->GetMetaGraph();
 }

-std::shared_ptr<ModelImpl> Model::model_impl() {
+ModelImpl *Model::model_impl() {
   MS_EXCEPTION_IF_NULL(model_impl_);
   return this->model_impl_;
 }
 }  // namespace mindspore::lite
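
Callers that still want automatic cleanup can recover it by wrapping the raw pointer themselves; the following is only a sketch of one option under that assumption, not something this commit adds:

    #include <memory>

    // Wrap the factory result so ~Model() runs when `model` leaves scope.
    std::unique_ptr<mindspore::lite::Model> model(
        mindspore::lite::Model::Import(model_buf, buf_size));
    // Pass the raw pointer where the API expects Model *:
    // session->CompileGraph(model.get());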

View File

@@ -20,7 +20,7 @@
 #include "utils/log_adapter.h"

 namespace mindspore::lite {
-std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
+ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
   MS_EXCEPTION_IF_NULL(model_buf);
   flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
   if (!schema::VerifyMetaGraphBuffer(verify)) {
@@ -33,7 +33,7 @@ std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size)
     return nullptr;
   }
   memcpy(inner_model_buf, model_buf, size);
-  auto model = std::make_shared<ModelImpl>(inner_model_buf, size);
+  auto model = new (std::nothrow) ModelImpl(inner_model_buf, size);
   if (model == nullptr) {
     MS_LOG(ERROR) << "Create modelImpl failed";
     return nullptr;
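
The switch from std::make_shared to new (std::nothrow) also changes the failure mode: make_shared throws std::bad_alloc on allocation failure, while nothrow new returns nullptr, which is what makes the nullptr check above meaningful. A tiny standalone illustration of that pattern (not MindSpore code):

    #include <new>  // std::nothrow

    struct Payload { int value = 0; };

    Payload *MakePayload() {
      // Plain `new` throws on allocation failure; with std::nothrow it yields
      // nullptr, so the caller can log the error and bail out instead.
      Payload *p = new (std::nothrow) Payload();
      if (p == nullptr) {
        return nullptr;
      }
      return p;
    }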

View File

@@ -27,7 +27,7 @@ namespace mindspore {
 namespace lite {
 class ModelImpl {
  public:
-  static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size);
+  static ModelImpl *Import(const char *model_buf, size_t size);
   ModelImpl() = default;
   explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
     meta_graph = schema::GetMetaGraph(model_buf);

View File

@@ -109,7 +109,7 @@ TEST_F(InferTest, TestConvNode) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 1);
@@ -206,7 +206,7 @@ TEST_F(InferTest, TestAddNode) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 2);
@@ -257,7 +257,7 @@ TEST_F(InferTest, TestModel) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 1);

View File

@@ -371,7 +371,7 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
     return RET_ERROR;
   }
   delete[](graphBuf);
-  auto context = new(std::nothrow) lite::Context;
+  auto context = new (std::nothrow) lite::Context;
   if (context == nullptr) {
     MS_LOG(ERROR) << "New context failed while running %s", modelName.c_str();
     return RET_ERROR;
@@ -393,15 +393,16 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
   }
   context->thread_num_ = _flags->numThreads;
   session = session::LiteSession::CreateSession(context);
-  delete(context);
+  delete (context);
   if (session == nullptr) {
     MS_LOG(ERROR) << "CreateSession failed while running %s", modelName.c_str();
     return RET_ERROR;
   }
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "CompileGraph failed while running %s", modelName.c_str();
-    delete(session);
+    delete (session);
+    delete (model);
     return ret;
   }
   msInputs = session->GetInputs();
@@ -419,21 +420,24 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
   auto status = LoadInput();
   if (status != 0) {
     MS_LOG(ERROR) << "Generate input data error";
-    delete(session);
+    delete (session);
+    delete (model);
     return status;
   }
   if (!_flags->calibDataPath.empty()) {
     status = MarkAccuracy();
     if (status != 0) {
       MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status;
-      delete(session);
+      delete (session);
+      delete (model);
       return status;
     }
   } else {
     status = MarkPerformance();
     if (status != 0) {
       MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status;
-      delete(session);
+      delete (session);
+      delete (model);
       return status;
     }
   }
@@ -447,7 +451,8 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
     calibData.clear();
   }
-  delete(session);
+  delete (session);
+  delete (model);
   return RET_OK;
 }

View File

@@ -920,7 +920,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
     return RET_ERROR;
   }
-  auto ret = session_->CompileGraph(model.get());
+  auto ret = session_->CompileGraph(model);
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "compile graph error";
     return RET_ERROR;

View File

@@ -278,7 +278,7 @@ int TimeProfile::RunTimeProfile() {
   }
   auto model = lite::Model::Import(graphBuf, size);
-  auto ret = session_->CompileGraph(model.get());
+  auto ret = session_->CompileGraph(model);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Compile graph failed.";
     return RET_ERROR;
@@ -336,6 +336,8 @@ int TimeProfile::RunTimeProfile() {
   }
   ms_inputs_.clear();
   delete graphBuf;
+  delete session_;
+  delete model;
   return ret;
 }