model pool for r1.6

yefeng 2022-02-08 17:22:34 +08:00
parent 9391dcd1e2
commit 4f839a7521
10 changed files with 649 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#define MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore {
class ModelPool;
struct RunnerConfig {
RunnerConfig(std::shared_ptr<Context> &ctx, int num) : model_ctx(ctx), num_model(num) {}
std::shared_ptr<Context> model_ctx = nullptr;
int num_model;
};
/// \brief The ModelParallelRunner class defines a MindSpore model pool that manages multiple Model instances for parallel inference.
class MS_API ModelParallelRunner {
public:
ModelParallelRunner() = default;
~ModelParallelRunner() = default;
  /// \brief Build a ModelParallelRunner from a model path so that it can run on a device. Only valid for Lite.
  ///
  /// \param[in] model_path Define the model path. Only MindIR (ModelType::kMindIR) models are supported.
  /// \param[in] runner_config Define the runner configuration, that is, the shared Context and the number of models
  /// in the pool.
  /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
  /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
  ///
  /// \return Status.
Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
/// \brief Obtains all input tensors of the model.
///
/// \return The vector that includes all input tensors.
std::vector<MSTensor> GetInputs();
  /// \brief Run inference through the model pool.
  ///
  /// \param[in] inputs A vector where model inputs are arranged in sequence.
  /// \param[out] outputs A pointer to a vector into which the model outputs are filled in sequence.
  /// \param[in] before Callback invoked before inference.
  /// \param[in] after Callback invoked after inference.
///
/// \return Status.
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
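
For reference, a minimal usage sketch of the new API (not part of this commit). It builds a shared CPU Context, wraps it in a RunnerConfig with a pool of four model workers, and runs a single prediction; the model file name "model.mindir", the thread count, and the pool size are hypothetical placeholders.

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"

int main() {
  // Shared context: 2 threads per model worker, CPU only, fp16 disabled.
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(2);
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  device_info->SetEnableFP16(false);
  context->MutableDeviceInfo().push_back(device_info);
  // Pool of 4 model workers sharing this context.
  auto config = std::make_shared<mindspore::RunnerConfig>(context, 4);
  mindspore::ModelParallelRunner runner;
  if (runner.Init("model.mindir", config) != mindspore::kSuccess) {
    return -1;
  }
  auto inputs = runner.GetInputs();
  // ... fill inputs[i] with data before calling Predict ...
  std::vector<mindspore::MSTensor> outputs;
  if (runner.Predict(inputs, &outputs) != mindspore::kSuccess) {
    return -1;
  }
  return 0;
}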

View File

@@ -39,6 +39,7 @@ option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
option(MSLITE_ENABLE_COVERAGE "enable code coverage" off)
option(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL "enable sharing memory with OpenGL" off)
option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off)
# Options that can be configured manually
option(ENABLE_VERBOSE "" off)
@@ -140,6 +141,9 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
set(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL $ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
endif()
if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE})
set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE})
endif()
if(MACHINE_LINUX_ARM64)
add_compile_definitions(MACHINE_LINUX_ARM64)
@@ -159,6 +163,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
set(TARGET_OHOS_LITE on)
SET_PROPERTY(GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS TRUE)
endif()
if(MSLITE_ENABLE_SERVER_INFERENCE)
add_compile_definitions(SERVER_INFERENCE)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
AND NOT TARGET_HIMIX AND NOT TARGET_MIX210)
@@ -294,6 +301,7 @@ message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_
message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG = \t${MSLITE_ENABLE_RUNTIME_GLOG}")
message(STATUS "\tMSLITE_ENABLE_COVERAGE = \t${MSLITE_ENABLE_COVERAGE}")
message(STATUS "\tMSLITE_ENABLE_SHARING_MEM_WITH_OPENGL = \t${MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL}")
message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE = \t${MSLITE_ENABLE_SERVER_INFERENCE}")
if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
NOT MSLITE_ENABLE_MINDRT

View File

@@ -66,6 +66,16 @@ file(GLOB CXX_API_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/*.cc
)
if(MSLITE_ENABLE_SERVER_INFERENCE)
set(CXX_API_SRCS
${CXX_API_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_parallel_runner.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_thread.cc
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/predict_task_queue.cc
)
endif()
file(GLOB C_API_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/c_api/*.cc
)

View File

@@ -0,0 +1,53 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/model_parallel_runner.h"
#include "src/cxx_api/model_pool/model_pool.h"
#include "src/common/log.h"
namespace mindspore {
Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
const Key &dec_key, const std::string &dec_mode) {
auto status = ModelPool::GetInstance()->Init(model_path, runner_config, dec_key, dec_mode);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner init failed.";
return kLiteError;
}
return status;
}
std::vector<MSTensor> ModelParallelRunner::GetInputs() {
auto inputs = ModelPool::GetInstance()->GetInputs();
if (inputs.empty()) {
MS_LOG(ERROR) << "model pool input is empty.";
return {};
}
return inputs;
}
Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "predict output is nullptr.";
return kLiteNullptr;
}
auto status = ModelPool::GetInstance()->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner predict failed.";
return kLiteError;
}
return kSuccess;
}
} // namespace mindspore

View File

@@ -0,0 +1,170 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/model_pool/model_pool.h"
#include <unistd.h>
#include <future>
#include "src/common/log.h"
#include "include/lite_types.h"
#include "src/common/config_file.h"
namespace mindspore {
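// Bind strategy: split the detected CPU cores into core_num / thread_num disjoint groups,
// one group per model worker, so pooled models do not compete for the same cores.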
void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
int core_num = 1;
#if defined(_MSC_VER) || defined(_WIN32)
SYSTEM_INFO sysinfo;
GetSystemInfo(&sysinfo);
core_num = sysinfo.dwNumberOfProcessors;
#else
core_num = sysconf(_SC_NPROCESSORS_CONF);
#endif
if (thread_num == 0) {
MS_LOG(ERROR) << "thread num is zero.";
return;
}
num_models_ = core_num / thread_num;
int core_id = 0;
for (size_t i = 0; i < num_models_; i++) {
std::vector<int> bind_id;
for (int j = 0; j < thread_num; j++) {
if (core_id >= core_num) {
core_id = 0;
}
bind_id.push_back(core_id);
core_id++;
}
all_model_bind_list->push_back(bind_id);
}
}
ModelPool *ModelPool::GetInstance() {
static ModelPool instance;
return &instance;
}
std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConfig> &runner_config) {
auto model_context = std::make_shared<mindspore::Context>();
if (model_context == nullptr) {
MS_LOG(ERROR) << "New context failed in ModelPool.";
return nullptr;
}
if (runner_config != nullptr) {
model_context = runner_config->model_ctx;
num_models_ = runner_config->num_model;
auto device_list = model_context->MutableDeviceInfo();
if (device_list.size() != 1) {
MS_LOG(ERROR) << "model pool only support device num 1.";
return nullptr;
}
auto device = device_list.front();
if (device->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "model pool only support cpu type.";
return nullptr;
}
auto cpu_context = device->Cast<CPUDeviceInfo>();
auto enable_fp16 = cpu_context->GetEnableFP16();
if (enable_fp16) {
MS_LOG(ERROR) << "model pool not support enable fp16.";
return nullptr;
}
} else {
MS_LOG(DEBUG) << "use default config.";
model_context->SetThreadNum(1);
model_context->SetEnableParallel(false);
model_context->SetThreadAffinity(lite::NO_BIND);
auto &device_list = model_context->MutableDeviceInfo();
    auto device_info = std::make_shared<CPUDeviceInfo>();  // was a default-constructed (null) shared_ptr; make_shared is required before use
    device_info->SetEnableFP16(false);
    device_list.push_back(device_info);
}
return model_context;
}
ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config) {
auto model_context = InitContext(runner_config);
if (model_context == nullptr) {
MS_LOG(ERROR) << "context is nullptr.";
return {};
}
ModelPoolContex model_pool_context;
std::vector<std::vector<int>> all_model_bind_list;
if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
SetBindStrategy(&all_model_bind_list, static_cast<int>(model_context->GetThreadNum()));
} else if (model_context->GetThreadAffinityMode() == lite::MID_CPU) {
MS_LOG(ERROR) << "not support bind MID_CPU.";
return {};
}
for (size_t i = 0; i < num_models_; i++) {
auto context = std::make_shared<Context>();
if (context == nullptr) {
MS_LOG(ERROR) << "New Context failed.";
return {};
}
context->SetThreadNum(model_context->GetThreadNum());
context->SetEnableParallel(model_context->GetEnableParallel());
if (model_context->GetThreadAffinityMode() != lite::NO_BIND) {
// bind by core id
context->SetThreadAffinity(all_model_bind_list[i]);
} else {
// not bind core
context->SetThreadAffinity(model_context->GetThreadAffinityMode());
}
auto &new_device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(false);
new_device_list.push_back(device_info);
model_pool_context.push_back(context);
}
return model_pool_context;
}
std::vector<MSTensor> ModelPool::GetInputs() {
if (model_inputs_.empty()) {
MS_LOG(ERROR) << "model input is empty.";
return {};
}
return model_inputs_;
}
Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
const Key &dec_key, const std::string &dec_mode) {
  auto model_pool_context = CreateModelContext(runner_config);
  if (model_pool_context.empty()) {
    MS_LOG(ERROR) << "create model context failed in ModelPool Init.";
    return kLiteError;
  }
  for (size_t i = 0; i < num_models_; i++) {
    auto model_thread = std::make_shared<ModelThread>();
    auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
    if (status != kSuccess) {
      MS_LOG(ERROR) << "model thread init failed in ModelPool Init.";
      return status;
    }
    if (model_inputs_.empty()) {
      model_inputs_ = model_thread->GetInputs();
    }
    model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
  }
  return kSuccess;
}
Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
outputs->clear();
auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs);
return kSuccess;
}
ModelPool::~ModelPool() {
for (auto &th : model_thread_vec_) {
if (th.joinable()) {
th.join();
}
}
}
} // namespace mindspore

View File

@@ -0,0 +1,54 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
#define MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include "include/api/status.h"
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"
#include "src/cxx_api/model_pool/model_thread.h"
#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
class ModelPool {
public:
static ModelPool *GetInstance();
~ModelPool();
Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
std::vector<MSTensor> GetInputs();
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
private:
ModelPool() = default;
void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);
std::vector<std::thread> model_thread_vec_;
std::vector<MSTensor> model_inputs_;
size_t num_models_ = 10;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H

View File

@@ -0,0 +1,108 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/model_pool/model_thread.h"
#include "src/common/log.h"
#include "src/common/utils.h"
namespace mindspore {
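// Worker loop: pop tasks from the shared PredictTaskQueue, run inference, copy the outputs
// back into the caller-provided vector, then wake the caller waiting in ModelPool::Predict.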
void ModelThread::Run() {
while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) {
auto task = PredictTaskQueue::GetInstance()->GetPredictTask();
if (task == nullptr) {
break;
}
auto inputs = task->inputs;
auto *outputs = task->outputs;
auto before = task->before;
auto after = task->after;
auto status = Predict(*inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed.";
return;
}
    auto output_size = outputs->size();
    for (size_t i = 0; i < output_size; i++) {
      // deep-copy each output so the caller's tensors stay valid after this worker reuses its buffers
      auto copy_tensor =
        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
      if (copy_tensor == nullptr) {
        MS_LOG(ERROR) << "create copy tensor failed.";
        break;
      }
      outputs->erase(outputs->begin());
      outputs->push_back(*copy_tensor);
    }
PredictTaskQueue::GetInstance()->ActiveTask();
}
}
Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
model_ = std::make_shared<Model>();
mindspore::ModelType model_type = kMindIR;
auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
if (status != kSuccess) {
MS_LOG(ERROR) << "model build failed in ModelPool Init";
return status;
}
return kSuccess;
}
std::vector<MSTensor> ModelThread::GetInputs() {
if (model_ == nullptr) {
MS_LOG(ERROR) << "model is nullptr in ModelThread.";
return {};
}
auto inputs = model_->GetInputs();
return inputs;
}
std::pair<std::vector<std::vector<int64_t>>, bool> ModelThread::GetModelResize(
const std::vector<MSTensor> &model_inputs, const std::vector<MSTensor> &inputs) {
std::unique_lock<std::mutex> model_lock(mtx_model_);
std::vector<std::vector<int64_t>> dims;
bool need_resize = false;
for (size_t i = 0; i < model_inputs.size(); i++) {
for (size_t j = 0; j < model_inputs[i].Shape().size(); j++) {
if (model_inputs[i].Shape()[j] != inputs[i].Shape()[j]) {
need_resize = true;
}
}
dims.push_back(inputs[i].Shape());
}
return std::make_pair(dims, need_resize);
}
Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
  // check the user inputs against the model inputs
auto model_input = model_->GetInputs();
if (model_input.size() != inputs.size()) {
MS_LOG(ERROR) << "model input size is: " << model_input.size() << ", but get input size is: " << inputs.size();
return kLiteError;
}
auto resize_pair = GetModelResize(model_input, inputs);
if (resize_pair.second) {
auto dims = resize_pair.first;
auto status = model_->Resize(model_->GetInputs(), dims);
if (status != kSuccess) {
MS_LOG(ERROR) << "model pool resize failed.";
return kLiteError;
}
}
auto status = model_->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed.";
return status;
}
return kSuccess;
}
} // namespace mindspore

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
#include <queue>
#include <string>
#include <mutex>
#include <future>
#include <vector>
#include <utility>
#include <memory>
#include "include/api/model.h"
#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
using ModelPoolContex = std::vector<std::shared_ptr<Context>>;
class ModelThread {
public:
ModelThread() = default;
~ModelThread() = default;
  // a model thread is initialized once and then keeps serving predict requests
Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
std::vector<MSTensor> GetInputs();
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
void Run();
private:
std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,
const std::vector<MSTensor> &inputs);
private:
std::shared_ptr<mindspore::Model> model_ = nullptr;
std::mutex mtx_model_;
std::condition_variable model_cond_;
  // number of model workers, configured according to the hardware
  int num_models_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
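// On shutdown, mark the queue as done and wake every worker blocked in GetPredictTask()
// so the ModelThread loops can exit and be joined by ~ModelPool().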
PredictTaskQueue::~PredictTaskQueue() {
predict_task_done_ = true;
task_push_cond_.notify_all();
}
void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
while (outputs->empty()) {
task_pop_cond_.wait(result_lock);
}
return;
}
void PredictTaskQueue::ActiveTask() { task_pop_cond_.notify_all(); }
PredictTaskQueue *PredictTaskQueue::GetInstance() {
static PredictTaskQueue instance;
return &instance;
}
void PredictTaskQueue::PushPredictTask(std::shared_ptr<PredictTask> task) {
std::unique_lock<std::mutex> data_lock(mtx_predict_task_);
predict_task_.push(task);
task_push_cond_.notify_one();
}
std::shared_ptr<PredictTask> PredictTaskQueue::GetPredictTask() {
std::unique_lock<std::mutex> task_lock(mtx_predict_task_);
while (predict_task_.empty() && !predict_task_done_) {
task_push_cond_.wait(task_lock);
}
if (predict_task_done_) {
return nullptr;
}
auto predict_task = predict_task_.front();
predict_task_.pop();
return predict_task;
}
} // namespace mindspore

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
#include <queue>
#include <mutex>
#include <memory>
#include <vector>
#include <condition_variable>
#include "include/api/types.h"
#include "include/api/status.h"
namespace mindspore {
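// A single inference request: non-owning pointers to the caller's input/output vectors plus
// the optional before/after callbacks, queued for consumption by a ModelThread worker.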
struct PredictTask {
PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
MSKernelCallBack after)
: inputs(in), outputs(out), before(before), after(after) {}
const std::vector<MSTensor> *inputs;
std::vector<MSTensor> *outputs;
MSKernelCallBack before;
MSKernelCallBack after;
};
class PredictTaskQueue {
public:
static PredictTaskQueue *GetInstance();
~PredictTaskQueue();
void PushPredictTask(std::shared_ptr<PredictTask> task);
void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
std::shared_ptr<PredictTask> GetPredictTask();
void ActiveTask();
bool IsPredictTaskDone() { return predict_task_done_; }
private:
PredictTaskQueue() = default;
std::queue<std::shared_ptr<PredictTask>> predict_task_;
std::mutex mtx_predict_task_;
std::condition_variable task_pop_cond_;
std::condition_variable task_push_cond_;
bool predict_task_done_ = false;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_