From 4f839a752179795fa5e39326965f8b21dabfbbcf Mon Sep 17 00:00:00 2001
From: yefeng
Date: Tue, 8 Feb 2022 17:22:34 +0800
Subject: [PATCH] model pool for r1.6

---
 include/api/model_parallel_runner.h           |  70 ++++++++
 mindspore/lite/CMakeLists.txt                 |   8 +
 mindspore/lite/src/CMakeLists.txt             |  10 ++
 .../model_pool/model_parallel_runner.cc       |  53 ++++++
 .../lite/src/cxx_api/model_pool/model_pool.cc | 170 ++++++++++++++++++
 .../lite/src/cxx_api/model_pool/model_pool.h  |  54 ++++++
 .../src/cxx_api/model_pool/model_thread.cc    | 108 +++++++++++
 .../src/cxx_api/model_pool/model_thread.h     |  61 +++++++
 .../cxx_api/model_pool/predict_task_queue.cc  |  57 ++++++
 .../cxx_api/model_pool/predict_task_queue.h   |  58 ++++++
 10 files changed, 649 insertions(+)
 create mode 100644 include/api/model_parallel_runner.h
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/model_parallel_runner.cc
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/model_pool.cc
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/model_pool.h
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/model_thread.cc
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/model_thread.h
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/predict_task_queue.cc
 create mode 100644 mindspore/lite/src/cxx_api/model_pool/predict_task_queue.h

diff --git a/include/api/model_parallel_runner.h b/include/api/model_parallel_runner.h
new file mode 100644
index 00000000000..8a6417925d5
--- /dev/null
+++ b/include/api/model_parallel_runner.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
+#define MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
+#include <vector>
+#include <memory>
+#include <utility>
+#include <string>
+#include "include/api/status.h"
+#include "include/api/context.h"
+
+namespace mindspore {
+class ModelPool;
+
+struct RunnerConfig {
+  RunnerConfig(std::shared_ptr<Context> &ctx, int num) : model_ctx(ctx), num_model(num) {}
+  std::shared_ptr<Context> model_ctx = nullptr;
+  int num_model;
+};
+
+/// \brief The ModelParallelRunner class defines a MindSpore model pool, which manages a set of Model
+/// workers and schedules inference requests across them.
+class MS_API ModelParallelRunner {
+ public:
+  ModelParallelRunner() = default;
+  ~ModelParallelRunner() = default;
+
+  /// \brief Build a model runner from a model path so that it can run on a device. Only valid for Lite.
+  ///
+  /// \param[in] model_path Define the model path. Only ModelType::kMindIR is valid for Lite.
+  /// \param[in] runner_config Define the runner config, including the shared Context and the number of
+  /// models in the pool. If it is nullptr, a default single-thread CPU config is used.
+  /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
+  /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
+  ///
+  /// \return Status.
+  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
+              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
+
+  /// \brief Obtains all input tensors of the model.
+  ///
+  /// \return The vector that includes all input tensors.
+  std::vector<MSTensor> GetInputs();
+
+  /// \brief Runs inference through the model pool.
+  ///
+  /// \param[in] inputs A vector where model inputs are arranged in sequence.
+  /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
+  /// \param[in] before CallBack before predict.
+  /// \param[in] after CallBack after predict.
+  ///
+  /// \return Status.
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
index f04d5515530..f4f95b57460 100644
--- a/mindspore/lite/CMakeLists.txt
+++ b/mindspore/lite/CMakeLists.txt
@@ -39,6 +39,7 @@ option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
 option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
 option(MSLITE_ENABLE_COVERAGE "enable code coverage" off)
 option(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL "enable sharing memory with OpenGL" off)
+option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off)
 
 #Option that can be configured through manually
 option(ENABLE_VERBOSE "" off)
@@ -140,6 +141,9 @@ endif()
 if(DEFINED ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
     set(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL $ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
 endif()
+if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE})
+    set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE})
+endif()
 
 if(MACHINE_LINUX_ARM64)
     add_compile_definitions(MACHINE_LINUX_ARM64)
@@ -159,6 +163,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
     set(TARGET_OHOS_LITE on)
     SET_PROPERTY(GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS TRUE)
 endif()
+if(MSLITE_ENABLE_SERVER_INFERENCE)
+    add_compile_definitions(SERVER_INFERENCE)
+endif()
 
 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
         AND NOT TARGET_HIMIX AND NOT TARGET_MIX210)
@@ -294,6 +301,7 @@ message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT         = \t${MSLITE_ENABLE_RUNTIME_
 message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG            = \t${MSLITE_ENABLE_RUNTIME_GLOG}")
 message(STATUS "\tMSLITE_ENABLE_COVERAGE                = \t${MSLITE_ENABLE_COVERAGE}")
 message(STATUS "\tMSLITE_ENABLE_SHARING_MEM_WITH_OPENGL = \t${MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL}")
+message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE        = \t${MSLITE_ENABLE_SERVER_INFERENCE}")
 
 if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
     NOT MSLITE_ENABLE_MINDRT
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index f8c706c7e53..95909357cc3 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -66,6 +66,16 @@ file(GLOB CXX_API_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/*.cc
     )
 
+if(MSLITE_ENABLE_SERVER_INFERENCE)
+    set(CXX_API_SRCS
+        ${CXX_API_SRCS}
+        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_parallel_runner.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_pool.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_thread.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/predict_task_queue.cc
+        )
+endif()
+
 file(GLOB C_API_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/c_api/*.cc
     )
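For orientation, here is a minimal usage sketch of the API declared in include/api/model_parallel_runner.h above. It is not part of the patch: the model path "model.ms" and the worker count of 4 are hypothetical placeholders, and error handling is reduced to early returns.

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"

int main() {
  // One CPU context shared by all workers in the pool.
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(1);
  auto &device_list = context->MutableDeviceInfo();
  device_list.push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  // Ask the runner for 4 concurrent model workers (hypothetical count).
  auto config = std::make_shared<mindspore::RunnerConfig>(context, 4);
  mindspore::ModelParallelRunner runner;
  if (runner.Init("model.ms", config) != mindspore::kSuccess) {
    return -1;
  }

  // Fill the input tensors with real data before predicting.
  std::vector<mindspore::MSTensor> inputs = runner.GetInputs();
  std::vector<mindspore::MSTensor> outputs;
  if (runner.Predict(inputs, &outputs) != mindspore::kSuccess) {
    return -1;
  }
  return 0;
}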
diff --git a/mindspore/lite/src/cxx_api/model_pool/model_parallel_runner.cc b/mindspore/lite/src/cxx_api/model_pool/model_parallel_runner.cc
new file mode 100644
index 00000000000..9fc56966801
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/model_parallel_runner.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "include/api/model_parallel_runner.h"
+#include "src/cxx_api/model_pool/model_pool.h"
+#include "src/common/log.h"
+
+namespace mindspore {
+Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
+                                 const Key &dec_key, const std::string &dec_mode) {
+  auto status = ModelPool::GetInstance()->Init(model_path, runner_config, dec_key, dec_mode);
+  if (status != kSuccess) {
+    MS_LOG(ERROR) << "model runner init failed.";
+    return kLiteError;
+  }
+  return status;
+}
+
+std::vector<MSTensor> ModelParallelRunner::GetInputs() {
+  auto inputs = ModelPool::GetInstance()->GetInputs();
+  if (inputs.empty()) {
+    MS_LOG(ERROR) << "model pool input is empty.";
+    return {};
+  }
+  return inputs;
+}
+
+Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                                    const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  if (outputs == nullptr) {
+    MS_LOG(ERROR) << "predict output is nullptr.";
+    return kLiteNullptr;
+  }
+  auto status = ModelPool::GetInstance()->Predict(inputs, outputs, before, after);
+  if (status != kSuccess) {
+    MS_LOG(ERROR) << "model runner predict failed.";
+    return kLiteError;
+  }
+  return kSuccess;
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/src/cxx_api/model_pool/model_pool.cc b/mindspore/lite/src/cxx_api/model_pool/model_pool.cc
new file mode 100644
index 00000000000..2474de4caea
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/model_pool.cc
@@ -0,0 +1,170 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/cxx_api/model_pool/model_pool.h"
+#if defined(_MSC_VER) || defined(_WIN32)
+#include <windows.h>
+#else
+#include <unistd.h>
+#endif
+#include <thread>
+#include "src/common/log.h"
+#include "include/lite_types.h"
+#include "src/common/config_file.h"
+namespace mindspore {
+void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
+  int core_num = 1;
+#if defined(_MSC_VER) || defined(_WIN32)
+  SYSTEM_INFO sysinfo;
+  GetSystemInfo(&sysinfo);
+  core_num = sysinfo.dwNumberOfProcessors;
+#else
+  core_num = sysconf(_SC_NPROCESSORS_CONF);
+#endif
+  if (thread_num == 0) {
+    MS_LOG(ERROR) << "thread num is zero.";
+    return;
+  }
+  num_models_ = core_num / thread_num;
+  int core_id = 0;
+  for (size_t i = 0; i < num_models_; i++) {
+    std::vector<int> bind_id;
+    for (int j = 0; j < thread_num; j++) {
+      if (core_id >= core_num) {
+        core_id = 0;
+      }
+      bind_id.push_back(core_id);
+      core_id++;
+    }
+    all_model_bind_list->push_back(bind_id);
+  }
+}
+
+ModelPool *ModelPool::GetInstance() {
+  static ModelPool instance;
+  return &instance;
+}
+
+std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConfig> &runner_config) {
+  auto model_context = std::make_shared<Context>();
+  if (model_context == nullptr) {
+    MS_LOG(ERROR) << "New context failed in ModelPool.";
+    return nullptr;
+  }
+  if (runner_config != nullptr) {
+    model_context = runner_config->model_ctx;
+    num_models_ = runner_config->num_model;
+    auto device_list = model_context->MutableDeviceInfo();
+    if (device_list.size() != 1) {
+      MS_LOG(ERROR) << "model pool only supports device num 1.";
+      return nullptr;
+    }
+    auto device = device_list.front();
+    if (device->GetDeviceType() != kCPU) {
+      MS_LOG(ERROR) << "model pool only supports cpu type.";
+      return nullptr;
+    }
+    auto cpu_context = device->Cast<CPUDeviceInfo>();
+    auto enable_fp16 = cpu_context->GetEnableFP16();
+    if (enable_fp16) {
+      MS_LOG(ERROR) << "model pool does not support enable fp16.";
+      return nullptr;
+    }
+  } else {
+    MS_LOG(DEBUG) << "use default config.";
+    model_context->SetThreadNum(1);
+    model_context->SetEnableParallel(false);
+    model_context->SetThreadAffinity(lite::NO_BIND);
+    auto &device_list = model_context->MutableDeviceInfo();
+    auto device_info = std::make_shared<CPUDeviceInfo>();
+    device_info->SetEnableFP16(false);
+    device_list.push_back(device_info);
+  }
+  return model_context;
+}
+
+ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config) {
+  auto model_context = InitContext(runner_config);
+  if (model_context == nullptr) {
+    MS_LOG(ERROR) << "context is nullptr.";
+    return {};
+  }
+  ModelPoolContex model_pool_context;
+  std::vector<std::vector<int>> all_model_bind_list;
+  if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
+    SetBindStrategy(&all_model_bind_list, static_cast<int>(model_context->GetThreadNum()));
+  } else if (model_context->GetThreadAffinityMode() == lite::MID_CPU) {
+    MS_LOG(ERROR) << "not support bind MID_CPU.";
+    return {};
+  }
+  for (size_t i = 0; i < num_models_; i++) {
+    auto context = std::make_shared<Context>();
+    if (context == nullptr) {
+      MS_LOG(ERROR) << "New Context failed.";
+      return {};
+    }
+    context->SetThreadNum(model_context->GetThreadNum());
+    context->SetEnableParallel(model_context->GetEnableParallel());
+    if (model_context->GetThreadAffinityMode() != lite::NO_BIND) {
+      // bind by core id
+      context->SetThreadAffinity(all_model_bind_list[i]);
+    } else {
+      // not bind core
+      context->SetThreadAffinity(model_context->GetThreadAffinityMode());
+    }
+    auto &new_device_list = context->MutableDeviceInfo();
+    std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
+    device_info->SetEnableFP16(false);
+    new_device_list.push_back(device_info);
+    model_pool_context.push_back(context);
+  }
+  return model_pool_context;
+}
+
+std::vector<MSTensor> ModelPool::GetInputs() {
+  if (model_inputs_.empty()) {
+    MS_LOG(ERROR) << "model input is empty.";
+    return {};
+  }
+  return model_inputs_;
+}
+
+Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
+                       const Key &dec_key, const std::string &dec_mode) {
+  auto model_pool_context = CreateModelContext(runner_config);
+  if (model_pool_context.size() != num_models_) {
+    MS_LOG(ERROR) << "create model pool context failed.";
+    return kLiteError;
+  }
+  for (size_t i = 0; i < num_models_; i++) {
+    auto model_thread = std::make_shared<ModelThread>();
+    auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model thread init failed.";
+      return kLiteError;
+    }
+    if (model_inputs_.empty()) {
+      model_inputs_ = model_thread->GetInputs();
+    }
+    model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
+  }
+  return kSuccess;
+}
+
+Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                          const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  outputs->clear();
+  auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
+  PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
+  PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs);
+  return kSuccess;
+}
+
+ModelPool::~ModelPool() {
+  for (auto &th : model_thread_vec_) {
+    if (th.joinable()) {
+      th.join();
+    }
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/src/cxx_api/model_pool/model_pool.h b/mindspore/lite/src/cxx_api/model_pool/model_pool.h
new file mode 100644
index 00000000000..cb165b410c2
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/model_pool.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
+#define MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
+#include <vector>
+#include <memory>
+#include <string>
+#include <thread>
+#include <mutex>
+#include <utility>
+#include "include/api/status.h"
+#include "include/api/context.h"
+#include "include/api/model_parallel_runner.h"
+#include "src/cxx_api/model_pool/model_thread.h"
+#include "src/cxx_api/model_pool/predict_task_queue.h"
+namespace mindspore {
+class ModelPool {
+ public:
+  static ModelPool *GetInstance();
+  ~ModelPool();
+
+  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
+              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
+
+  std::vector<MSTensor> GetInputs();
+
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
+
+ private:
+  ModelPool() = default;
+  void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
+  ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
+  std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);
+
+  std::vector<std::thread> model_thread_vec_;
+  std::vector<MSTensor> model_inputs_;
+  size_t num_models_ = 10;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
diff --git a/mindspore/lite/src/cxx_api/model_pool/model_thread.cc b/mindspore/lite/src/cxx_api/model_pool/model_thread.cc
new file mode 100644
index 00000000000..c6d26a11195
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/model_thread.cc
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/cxx_api/model_pool/model_thread.h"
+#include "src/common/log.h"
+#include "src/common/utils.h"
+namespace mindspore {
+void ModelThread::Run() {
+  while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) {
+    auto task = PredictTaskQueue::GetInstance()->GetPredictTask();
+    if (task == nullptr) {
+      break;
+    }
+    auto inputs = task->inputs;
+    auto *outputs = task->outputs;
+    auto before = task->before;
+    auto after = task->after;
+    auto status = Predict(*inputs, outputs, before, after);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model predict failed.";
+      return;
+    }
+    // detach the outputs from this worker's model by replacing each tensor with a deep copy
+    auto output_size = outputs->size();
+    for (size_t i = 0; i < output_size; i++) {
+      auto &src_tensor = outputs->front();
+      auto copy_tensor = mindspore::MSTensor::CreateTensor(src_tensor.Name(), src_tensor.DataType(),
+                                                           src_tensor.Shape(), src_tensor.MutableData(),
+                                                           src_tensor.DataSize());
+      if (copy_tensor == nullptr) {
+        MS_LOG(ERROR) << "create copy tensor failed.";
+        return;
+      }
+      outputs->erase(outputs->begin());
+      outputs->push_back(*copy_tensor);
+      // the vector element shares the copied data; destroy the wrapper returned by CreateTensor
+      MSTensor::DestroyTensorPtr(copy_tensor);
+    }
+    PredictTaskQueue::GetInstance()->ActiveTask();
+  }
+}
+
+Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
+                         const Key &dec_key, const std::string &dec_mode) {
+  model_ = std::make_shared<Model>();
+  mindspore::ModelType model_type = kMindIR;
+  auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
+  if (status != kSuccess) {
+    MS_LOG(ERROR) << "model build failed in ModelPool Init.";
+    return status;
+  }
+  return kSuccess;
+}
+
+std::vector<MSTensor> ModelThread::GetInputs() {
+  if (model_ == nullptr) {
+    MS_LOG(ERROR) << "model is nullptr in ModelThread.";
+    return {};
+  }
+  auto inputs = model_->GetInputs();
+  return inputs;
+}
+
+std::pair<std::vector<std::vector<int64_t>>, bool> ModelThread::GetModelResize(
+  const std::vector<MSTensor> &model_inputs, const std::vector<MSTensor> &inputs) {
+  std::unique_lock<std::mutex> model_lock(mtx_model_);
+  std::vector<std::vector<int64_t>> dims;
+  bool need_resize = false;
+  for (size_t i = 0; i < model_inputs.size(); i++) {
+    for (size_t j = 0; j < model_inputs[i].Shape().size(); j++) {
+      if (model_inputs[i].Shape()[j] != inputs[i].Shape()[j]) {
+        need_resize = true;
+      }
+    }
+    dims.push_back(inputs[i].Shape());
+  }
+  return std::make_pair(dims, need_resize);
+}
+
+Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                            const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  // model
+  auto model_input = model_->GetInputs();
+  if (model_input.size() != inputs.size()) {
+    MS_LOG(ERROR) << "model input size is: " << model_input.size() << ", but get input size is: " << inputs.size();
+    return kLiteError;
+  }
+  auto resize_pair = GetModelResize(model_input, inputs);
+  if (resize_pair.second) {
+    auto dims = resize_pair.first;
+    auto status = model_->Resize(model_->GetInputs(), dims);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model pool resize failed.";
+      return kLiteError;
+    }
+  }
+  auto status = model_->Predict(inputs, outputs, before, after);
+  if (status != kSuccess) {
+    MS_LOG(ERROR) << "model predict failed.";
+    return status;
+  }
+  return kSuccess;
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/src/cxx_api/model_pool/model_thread.h b/mindspore/lite/src/cxx_api/model_pool/model_thread.h
new file mode 100644
index 00000000000..72afa23b097
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/model_thread.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
+#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
+#include <queue>
+#include <string>
+#include <mutex>
+#include <condition_variable>
+#include <memory>
+#include <vector>
+#include <utility>
+#include "include/api/model.h"
+#include "src/cxx_api/model_pool/predict_task_queue.h"
+namespace mindspore {
+using ModelPoolContex = std::vector<std::shared_ptr<Context>>;
+
+class ModelThread {
+ public:
+  ModelThread() = default;
+
+  ~ModelThread() = default;
+
+  // the model pool is initialized once and can always accept model run requests
+  Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
+              const std::string &dec_mode = kDecModeAesGcm);
+
+  std::vector<MSTensor> GetInputs();
+
+  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
+                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
+
+  void Run();
+
+ private:
+  std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,
+                                                                    const std::vector<MSTensor> &inputs);
+
+ private:
+  std::shared_ptr<Model> model_ = nullptr;
+  std::mutex mtx_model_;
+  std::condition_variable model_cond_;
+
+  // the number of models is configured according to the hardware
+  int num_models_ = 0;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
diff --git a/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.cc b/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.cc
new file mode 100644
index 00000000000..69b0ed1bda7
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.cc
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/cxx_api/model_pool/predict_task_queue.h"
+namespace mindspore {
+PredictTaskQueue::~PredictTaskQueue() {
+  {
+    // set the done flag under the lock so sleeping workers cannot miss the wake-up
+    std::unique_lock<std::mutex> lock(mtx_predict_task_);
+    predict_task_done_ = true;
+  }
+  task_push_cond_.notify_all();
+}
+
+void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
+  std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
+  while (outputs->empty()) {
+    task_pop_cond_.wait(result_lock);
+  }
+}
+
+void PredictTaskQueue::ActiveTask() { task_pop_cond_.notify_all(); }
+
+PredictTaskQueue *PredictTaskQueue::GetInstance() {
+  static PredictTaskQueue instance;
+  return &instance;
+}
+
+void PredictTaskQueue::PushPredictTask(std::shared_ptr<PredictTask> task) {
+  std::unique_lock<std::mutex> data_lock(mtx_predict_task_);
+  predict_task_.push(task);
+  task_push_cond_.notify_one();
+}
+
+std::shared_ptr<PredictTask> PredictTaskQueue::GetPredictTask() {
+  std::unique_lock<std::mutex> task_lock(mtx_predict_task_);
+  while (predict_task_.empty() && !predict_task_done_) {
+    task_push_cond_.wait(task_lock);
+  }
+  if (predict_task_done_) {
+    return nullptr;
+  }
+  auto predict_task = predict_task_.front();
+  predict_task_.pop();
+  return predict_task;
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.h b/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.h
new file mode 100644
index 00000000000..96aeda66110
--- /dev/null
+++ b/mindspore/lite/src/cxx_api/model_pool/predict_task_queue.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
+#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
+
+#include <queue>
+#include <mutex>
+#include <memory>
+#include <vector>
+#include <condition_variable>
+#include "include/api/types.h"
+#include "include/api/status.h"
+namespace mindspore {
+struct PredictTask {
+  PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
+              MSKernelCallBack after)
+      : inputs(in), outputs(out), before(before), after(after) {}
+  const std::vector<MSTensor> *inputs;
+  std::vector<MSTensor> *outputs;
+  MSKernelCallBack before;
+  MSKernelCallBack after;
+};
+
+class PredictTaskQueue {
+ public:
+  static PredictTaskQueue *GetInstance();
+  ~PredictTaskQueue();
+
+  void PushPredictTask(std::shared_ptr<PredictTask> task);
+  void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
+  std::shared_ptr<PredictTask> GetPredictTask();
+  void ActiveTask();
+  bool IsPredictTaskDone() { return predict_task_done_; }
+
+ private:
+  PredictTaskQueue() = default;
+  std::queue<std::shared_ptr<PredictTask>> predict_task_;
+
+  std::mutex mtx_predict_task_;
+  std::condition_variable task_pop_cond_;
+  std::condition_variable task_push_cond_;
+  bool predict_task_done_ = false;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
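The queue above is what lets several client threads share one runner: every Predict call becomes a PredictTask, and whichever idle ModelThread worker pops it first serves it. A minimal sketch of concurrent clients, under the same assumptions as the earlier example (an already-initialized runner; inputs used as returned by GetInputs, without real data; the function name RunClients is hypothetical):

#include <thread>
#include <vector>
#include "include/api/model_parallel_runner.h"

void RunClients(mindspore::ModelParallelRunner *runner, int num_clients) {
  std::vector<std::thread> clients;
  for (int i = 0; i < num_clients; i++) {
    clients.emplace_back([runner]() {
      // Each client issues its own request; the pool serves them in parallel,
      // blocking the caller until its own outputs are filled.
      std::vector<mindspore::MSTensor> inputs = runner->GetInputs();
      std::vector<mindspore::MSTensor> outputs;
      (void)runner->Predict(inputs, &outputs);
    });
  }
  for (auto &client : clients) {
    client.join();
  }
}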