From ca6c84b806378803070dcac912503e5189322cd7 Mon Sep 17 00:00:00 2001
From: hangq
Date: Sat, 8 Aug 2020 15:29:42 +0800
Subject: [PATCH] modify return type of Model::Import from std::shared_ptr to Model *

---
 build.sh                                     |  2 +-
 mindspore/lite/include/model.h               |  8 +++----
 mindspore/lite/src/model.cc                  | 11 ++++++----
 mindspore/lite/src/model_impl.cc             |  4 ++--
 mindspore/lite/src/model_impl.h              |  2 +-
 mindspore/lite/test/ut/src/infer_test.cc     |  6 +++---
 mindspore/lite/tools/benchmark/benchmark.cc  | 21 ++++++++++++-------
 .../converter/quantizer/post_training.cc     |  2 +-
 .../lite/tools/time_profile/time_profile.cc  |  4 +++-
 9 files changed, 35 insertions(+), 25 deletions(-)

diff --git a/build.sh b/build.sh
index 7502e757c8..907e1c0e59 100755
--- a/build.sh
+++ b/build.sh
@@ -397,7 +397,7 @@ checkndk() {
     if [ "${ANDROID_NDK}" ]; then
         echo -e "\e[31mANDROID_NDK_PATH=$ANDROID_NDK \e[0m"
     else
-        echo -e "\e[31mplease set ANDROID_NDK_PATH in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
+        echo -e "\e[31mplease set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m"
         exit 1
     fi
 }
diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
index 5000d8a1d5..5813caeca8 100644
--- a/mindspore/lite/include/model.h
+++ b/mindspore/lite/include/model.h
@@ -45,7 +45,7 @@ class MS_API Model {
   /// \param[in] size Define bytes numbers of model buffer.
   ///
   /// \return Pointer of MindSpore Lite Model.
-  static std::shared_ptr<Model> Import(const char *model_buf, size_t size);
+  static Model *Import(const char *model_buf, size_t size);
 
   /// \brief Constructor of MindSpore Lite Model using default value for parameters.
   ///
@@ -53,7 +53,7 @@ class MS_API Model {
   Model() = default;
 
   /// \brief Destructor of MindSpore Lite Model.
-  virtual ~Model() = default;
+  virtual ~Model();
 
   /// \brief Get MindSpore Lite Primitive by name.
   ///
@@ -70,13 +70,13 @@ class MS_API Model {
   /// \brief Get MindSpore Lite ModelImpl.
   ///
   /// \return A pointer of MindSpore Lite ModelImpl.
-  std::shared_ptr<ModelImpl> model_impl();
+  ModelImpl *model_impl();
 
   /// \brief Free MetaGraph in MindSpore Lite Model.
   void FreeMetaGraph();
 
  protected:
-  std::shared_ptr<ModelImpl> model_impl_ = nullptr;
+  ModelImpl *model_impl_ = nullptr;
 };
 
 /// \brief ModelBuilder defined by MindSpore Lite.
diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc
index 55d10278f7..0fecef68e4 100644
--- a/mindspore/lite/src/model.cc
+++ b/mindspore/lite/src/model.cc
@@ -24,12 +24,16 @@
 
 namespace mindspore::lite {
 
-std::shared_ptr<Model> Model::Import(const char *model_buf, size_t size) {
-  auto model = std::make_shared<Model>();
+Model *Model::Import(const char *model_buf, size_t size) {
+  auto model = new Model();
   model->model_impl_ = ModelImpl::Import(model_buf, size);
   return model;
 }
 
+Model::~Model() {
+  delete(this->model_impl_);
+}
+
 lite::Primitive *Model::GetOp(const std::string &name) const {
   MS_EXCEPTION_IF_NULL(model_impl_);
   return const_cast<lite::Primitive *>(model_impl_->GetOp(name));
@@ -45,9 +49,8 @@ const schema::MetaGraph *Model::GetMetaGraph() const {
   return model_impl_->GetMetaGraph();
 }
 
-std::shared_ptr<ModelImpl> Model::model_impl() {
+ModelImpl *Model::model_impl() {
   MS_EXCEPTION_IF_NULL(model_impl_);
   return this->model_impl_;
 }
 }  // namespace mindspore::lite
-
diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc
index abead7cba1..045b3786e9 100644
--- a/mindspore/lite/src/model_impl.cc
+++ b/mindspore/lite/src/model_impl.cc
@@ -20,7 +20,7 @@
 #include "utils/log_adapter.h"
 
 namespace mindspore::lite {
-std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size) {
+ModelImpl *ModelImpl::Import(const char *model_buf, size_t size) {
   MS_EXCEPTION_IF_NULL(model_buf);
   flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
   if (!schema::VerifyMetaGraphBuffer(verify)) {
@@ -33,7 +33,7 @@ std::shared_ptr<ModelImpl> ModelImpl::Import(const char *model_buf, size_t size)
     return nullptr;
   }
   memcpy(inner_model_buf, model_buf, size);
-  auto model = std::make_shared<ModelImpl>(inner_model_buf, size);
+  auto model = new (std::nothrow) ModelImpl(inner_model_buf, size);
   if (model == nullptr) {
     MS_LOG(ERROR) << "Create modelImpl failed";
     return nullptr;
diff --git a/mindspore/lite/src/model_impl.h b/mindspore/lite/src/model_impl.h
index 14e0a1ccb9..8a1af3e9dc 100644
--- a/mindspore/lite/src/model_impl.h
+++ b/mindspore/lite/src/model_impl.h
@@ -27,7 +27,7 @@ namespace mindspore {
 namespace lite {
 class ModelImpl {
  public:
-  static std::shared_ptr<ModelImpl> Import(const char *model_buf, size_t size);
+  static ModelImpl *Import(const char *model_buf, size_t size);
   ModelImpl() = default;
   explicit ModelImpl(const char *model_buf, size_t size) : model_buf_(model_buf), buf_size_(size) {
     meta_graph = schema::GetMetaGraph(model_buf);
diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc
index 5931d89034..6bce0ddad2 100644
--- a/mindspore/lite/test/ut/src/infer_test.cc
+++ b/mindspore/lite/test/ut/src/infer_test.cc
@@ -109,7 +109,7 @@ TEST_F(InferTest, TestConvNode) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 1);
@@ -206,7 +206,7 @@ TEST_F(InferTest, TestAddNode) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 2);
@@ -257,7 +257,7 @@ TEST_F(InferTest, TestModel) {
   context->thread_num_ = 4;
   auto session = session::LiteSession::CreateSession(context);
   ASSERT_NE(nullptr, session);
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   ASSERT_EQ(lite::RET_OK, ret);
   auto inputs = session->GetInputs();
   ASSERT_EQ(inputs.size(), 1);
diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc
index c36fe113fd..57f2ed1e5a 100644
--- a/mindspore/lite/tools/benchmark/benchmark.cc
+++ b/mindspore/lite/tools/benchmark/benchmark.cc
@@ -371,7 +371,7 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
     return RET_ERROR;
   }
   delete[](graphBuf);
-  auto context = new(std::nothrow) lite::Context;
+  auto context = new (std::nothrow) lite::Context;
   if (context == nullptr) {
     MS_LOG(ERROR) << "New context failed while running %s", modelName.c_str();
     return RET_ERROR;
@@ -393,15 +393,16 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
   }
   context->thread_num_ = _flags->numThreads;
   session = session::LiteSession::CreateSession(context);
-  delete(context);
+  delete (context);
   if (session == nullptr) {
     MS_LOG(ERROR) << "CreateSession failed while running %s", modelName.c_str();
     return RET_ERROR;
   }
-  auto ret = session->CompileGraph(model.get());
+  auto ret = session->CompileGraph(model);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "CompileGraph failed while running %s", modelName.c_str();
-    delete(session);
+    delete (session);
+    delete (model);
     return ret;
   }
   msInputs = session->GetInputs();
@@ -419,21 +420,24 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
   auto status = LoadInput();
   if (status != 0) {
     MS_LOG(ERROR) << "Generate input data error";
-    delete(session);
+    delete (session);
+    delete (model);
     return status;
   }
   if (!_flags->calibDataPath.empty()) {
     status = MarkAccuracy();
     if (status != 0) {
       MS_LOG(ERROR) << "Run MarkAccuracy error: %d" << status;
-      delete(session);
+      delete (session);
+      delete (model);
       return status;
     }
   } else {
     status = MarkPerformance();
     if (status != 0) {
       MS_LOG(ERROR) << "Run MarkPerformance error: %d" << status;
-      delete(session);
+      delete (session);
+      delete (model);
       return status;
     }
   }
@@ -447,7 +451,8 @@ int Benchmark::RunBenchmark(const std::string &deviceType) {
     calibData.clear();
   }
 
-  delete(session);
+  delete (session);
+  delete (model);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc
index 10e0609add..fe4e16bdaa 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training.cc
@@ -920,7 +920,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
     return RET_ERROR;
   }
 
-  auto ret = session_->CompileGraph(model.get());
+  auto ret = session_->CompileGraph(model);
   if (ret != lite::RET_OK) {
     MS_LOG(ERROR) << "compile graph error";
     return RET_ERROR;
diff --git a/mindspore/lite/tools/time_profile/time_profile.cc b/mindspore/lite/tools/time_profile/time_profile.cc
index 65b375c284..a8536173ca 100644
--- a/mindspore/lite/tools/time_profile/time_profile.cc
+++ b/mindspore/lite/tools/time_profile/time_profile.cc
@@ -278,7 +278,7 @@ int TimeProfile::RunTimeProfile() {
   }
 
   auto model = lite::Model::Import(graphBuf, size);
-  auto ret = session_->CompileGraph(model.get());
+  auto ret = session_->CompileGraph(model);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Compile graph failed.";
     return RET_ERROR;
@@ -336,6 +336,8 @@ int TimeProfile::RunTimeProfile() {
   }
   ms_inputs_.clear();
   delete graphBuf;
+  delete session_;
+  delete model;
   return ret;
 }
 
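Note (not part of the patch): after this change, Model::Import returns a raw Model * that the caller owns, and Model::~Model() frees the internal ModelImpl. Call sites therefore pass the pointer straight to CompileGraph and must delete the model on every exit path, as the benchmark.cc and time_profile.cc hunks above do. The sketch below is a minimal illustration of that caller-side contract; the helper name RunOnce, the stack-allocated Context, and the exact include paths are assumptions made for the example, not code from this patch.

    #include <cstddef>

    #include "include/context.h"
    #include "include/lite_session.h"
    #include "include/model.h"

    // Minimal caller-side sketch of the new ownership contract.
    int RunOnce(const char *model_buf, size_t size) {
      // Import now returns a raw pointer that the caller owns.
      auto *model = mindspore::lite::Model::Import(model_buf, size);
      if (model == nullptr) {
        return -1;
      }
      mindspore::lite::Context context;
      auto *session = mindspore::session::LiteSession::CreateSession(&context);
      if (session == nullptr) {
        delete model;  // free the model on every early-return path
        return -1;
      }
      auto ret = session->CompileGraph(model);  // takes Model *, no more .get()
      // ... fill inputs and call session->RunGraph() here ...
      delete session;
      delete model;  // ~Model() releases the internal ModelImpl
      return ret;
    }

The switch from std::shared_ptr to a raw pointer makes ownership explicit at the call sites, which is why every error branch in benchmark.cc gained a matching delete (model) in this patch.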