model pool for r1.6
This commit is contained in:
parent 9391dcd1e2
commit 4f839a7521
@@ -0,0 +1,70 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#define MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "include/api/status.h"
#include "include/api/context.h"

namespace mindspore {
class ModelPool;

struct RunnerConfig {
  RunnerConfig(std::shared_ptr<Context> &ctx, int num) : model_ctx(ctx), num_model(num) {}
  std::shared_ptr<Context> model_ctx = nullptr;
  int num_model;
};

/// \brief The ModelParallelRunner class defines a MindSpore model pool, facilitating parallel Model management.
class MS_API ModelParallelRunner {
 public:
  ModelParallelRunner() = default;
  ~ModelParallelRunner() = default;

  /// \brief Build a model runner from a model path so that it can run on a device. Only valid for Lite.
  ///
  /// \param[in] model_path Define the model path. Only ModelType::kMindIR is valid for Lite.
  /// \param[in] runner_config Define the runner configuration: the context used during execution and the number
  /// of models in the pool.
  /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
  /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
  ///
  /// \return Status.
  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);

  /// \brief Obtains all input tensors of the model.
  ///
  /// \return The vector that includes all input tensors.
  std::vector<MSTensor> GetInputs();

  /// \brief Runs inference through the model pool.
  ///
  /// \param[in] inputs A vector where model inputs are arranged in sequence.
  /// \param[out] outputs A pointer to a vector. The model outputs are filled in the container in sequence.
  /// \param[in] before CallBack before predict.
  /// \param[in] after CallBack after predict.
  ///
  /// \return Status.
  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
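A minimal usage sketch of the API above, not part of the commit: the "model.ms" path, the single-threaded CPU context, and the pool size of 4 are assumptions for illustration; error handling is trimmed.

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"

int RunPooledInference() {
  // Build a CPU context; the values here are illustrative assumptions.
  auto ctx = std::make_shared<mindspore::Context>();
  ctx->SetThreadNum(1);
  ctx->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  auto config = std::make_shared<mindspore::RunnerConfig>(ctx, 4);  // pool of 4 models

  mindspore::ModelParallelRunner runner;
  if (runner.Init("model.ms", config) != mindspore::kSuccess) {
    return -1;
  }
  auto inputs = runner.GetInputs();  // fill each tensor's data before predicting
  std::vector<mindspore::MSTensor> outputs;
  return runner.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : -1;
}

Since ModelPool is a process-wide singleton (see model_pool.cc below), all ModelParallelRunner instances in a process currently share one pool.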
@@ -39,6 +39,7 @@ option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
option(MSLITE_ENABLE_COVERAGE "enable code coverage" off)
option(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL "enable sharing memory with OpenGL" off)
option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off)

# Options that can be configured manually
option(ENABLE_VERBOSE "" off)
@@ -140,6 +141,9 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
    set(MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL $ENV{MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL})
endif()
if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE})
    set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE})
endif()

if(MACHINE_LINUX_ARM64)
    add_compile_definitions(MACHINE_LINUX_ARM64)
@@ -159,6 +163,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
    set(TARGET_OHOS_LITE on)
    SET_PROPERTY(GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS TRUE)
endif()
if(MSLITE_ENABLE_SERVER_INFERENCE)
    add_compile_definitions(SERVER_INFERENCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
        AND NOT TARGET_HIMIX AND NOT TARGET_MIX210)
@@ -294,6 +301,7 @@ message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_
message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG = \t${MSLITE_ENABLE_RUNTIME_GLOG}")
message(STATUS "\tMSLITE_ENABLE_COVERAGE = \t${MSLITE_ENABLE_COVERAGE}")
message(STATUS "\tMSLITE_ENABLE_SHARING_MEM_WITH_OPENGL = \t${MSLITE_ENABLE_SHARING_MEM_WITH_OPENGL}")
message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE = \t${MSLITE_ENABLE_SERVER_INFERENCE}")

if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
    NOT MSLITE_ENABLE_MINDRT
@@ -66,6 +66,16 @@ file(GLOB CXX_API_SRCS
        ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/*.cc
        )

if(MSLITE_ENABLE_SERVER_INFERENCE)
    set(CXX_API_SRCS
            ${CXX_API_SRCS}
            ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_parallel_runner.cc
            ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_pool.cc
            ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_thread.cc
            ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/predict_task_queue.cc
            )
endif()

file(GLOB C_API_SRCS
        ${CMAKE_CURRENT_SOURCE_DIR}/c_api/*.cc
        )
@@ -0,0 +1,53 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/model_parallel_runner.h"
#include "src/cxx_api/model_pool/model_pool.h"
#include "src/common/log.h"

namespace mindspore {
Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
                                 const Key &dec_key, const std::string &dec_mode) {
  auto status = ModelPool::GetInstance()->Init(model_path, runner_config, dec_key, dec_mode);
  if (status != kSuccess) {
    MS_LOG(ERROR) << "model runner init failed.";
    return kLiteError;
  }
  return status;
}

std::vector<MSTensor> ModelParallelRunner::GetInputs() {
  auto inputs = ModelPool::GetInstance()->GetInputs();
  if (inputs.empty()) {
    MS_LOG(ERROR) << "model pool input is empty.";
    return {};
  }
  return inputs;
}

Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                                    const MSKernelCallBack &before, const MSKernelCallBack &after) {
  if (outputs == nullptr) {
    MS_LOG(ERROR) << "predict output is nullptr.";
    return kLiteNullptr;
  }
  auto status = ModelPool::GetInstance()->Predict(inputs, outputs, before, after);
  if (status != kSuccess) {
    MS_LOG(ERROR) << "model runner predict failed.";
    return kLiteError;
  }
  return kSuccess;
}
}  // namespace mindspore
@@ -0,0 +1,170 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/cxx_api/model_pool/model_pool.h"
// <unistd.h> is POSIX-only; guard the platform headers so both branches below compile.
#if defined(_MSC_VER) || defined(_WIN32)
#include <windows.h>
#else
#include <unistd.h>
#endif
#include <future>
#include "src/common/log.h"
#include "include/lite_types.h"
#include "src/common/config_file.h"
namespace mindspore {
void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
  int core_num = 1;
#if defined(_MSC_VER) || defined(_WIN32)
  SYSTEM_INFO sysinfo;
  GetSystemInfo(&sysinfo);
  core_num = sysinfo.dwNumberOfProcessors;
#else
  core_num = sysconf(_SC_NPROCESSORS_CONF);
#endif
  if (thread_num == 0) {
    MS_LOG(ERROR) << "thread num is zero.";
    return;
  }
  // One model per group of thread_num cores; cores are assigned round-robin.
  num_models_ = core_num / thread_num;
  int core_id = 0;
  for (size_t i = 0; i < num_models_; i++) {
    std::vector<int> bind_id;
    for (int j = 0; j < thread_num; j++) {
      if (core_id >= core_num) {
        core_id = 0;
      }
      bind_id.push_back(core_id);
      core_id++;
    }
    all_model_bind_list->push_back(bind_id);
  }
}

ModelPool *ModelPool::GetInstance() {
  static ModelPool instance;
  return &instance;
}

std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConfig> &runner_config) {
  auto model_context = std::make_shared<mindspore::Context>();
  if (model_context == nullptr) {
    MS_LOG(ERROR) << "New context failed in ModelPool.";
    return nullptr;
  }
  if (runner_config != nullptr) {
    model_context = runner_config->model_ctx;
    num_models_ = runner_config->num_model;
    auto device_list = model_context->MutableDeviceInfo();
    if (device_list.size() != 1) {
      MS_LOG(ERROR) << "model pool only supports device num 1.";
      return nullptr;
    }
    auto device = device_list.front();
    if (device->GetDeviceType() != kCPU) {
      MS_LOG(ERROR) << "model pool only supports cpu type.";
      return nullptr;
    }
    auto cpu_context = device->Cast<CPUDeviceInfo>();
    auto enable_fp16 = cpu_context->GetEnableFP16();
    if (enable_fp16) {
      MS_LOG(ERROR) << "model pool does not support fp16.";
      return nullptr;
    }
  } else {
    MS_LOG(DEBUG) << "use default config.";
    model_context->SetThreadNum(1);
    model_context->SetEnableParallel(false);
    model_context->SetThreadAffinity(lite::NO_BIND);
    auto &device_list = model_context->MutableDeviceInfo();
    // make_shared here: a default-constructed shared_ptr would be null and crash on SetEnableFP16.
    auto device_info = std::make_shared<CPUDeviceInfo>();
    device_info->SetEnableFP16(false);
    device_list.push_back(device_info);
  }
  return model_context;
}

ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config) {
  auto model_context = InitContext(runner_config);
  if (model_context == nullptr) {
    MS_LOG(ERROR) << "context is nullptr.";
    return {};
  }
  ModelPoolContex model_pool_context;
  std::vector<std::vector<int>> all_model_bind_list;
  if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
    SetBindStrategy(&all_model_bind_list, static_cast<int>(model_context->GetThreadNum()));
  } else if (model_context->GetThreadAffinityMode() == lite::MID_CPU) {
    MS_LOG(ERROR) << "not support bind MID_CPU.";
    return {};
  }
  for (size_t i = 0; i < num_models_; i++) {
    auto context = std::make_shared<Context>();
    if (context == nullptr) {
      MS_LOG(ERROR) << "New Context failed.";
      return {};
    }
    context->SetThreadNum(model_context->GetThreadNum());
    context->SetEnableParallel(model_context->GetEnableParallel());
    if (model_context->GetThreadAffinityMode() != lite::NO_BIND) {
      // bind by core id
      context->SetThreadAffinity(all_model_bind_list[i]);
    } else {
      // not bind core
      context->SetThreadAffinity(model_context->GetThreadAffinityMode());
    }
    auto &new_device_list = context->MutableDeviceInfo();
    std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
    device_info->SetEnableFP16(false);
    new_device_list.push_back(device_info);
    model_pool_context.push_back(context);
  }
  return model_pool_context;
}

std::vector<MSTensor> ModelPool::GetInputs() {
  if (model_inputs_.empty()) {
    MS_LOG(ERROR) << "model input is empty.";
    return {};
  }
  return model_inputs_;
}

Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
                       const Key &dec_key, const std::string &dec_mode) {
  auto model_pool_context = CreateModelContext(runner_config);
  if (model_pool_context.size() != num_models_) {
    MS_LOG(ERROR) << "create model context failed.";
    return kLiteError;
  }
  for (size_t i = 0; i < num_models_; i++) {
    auto model_thread = std::make_shared<ModelThread>();
    auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
    if (status != kSuccess) {
      MS_LOG(ERROR) << "model thread init failed.";
      return kLiteError;
    }
    if (model_inputs_.empty()) {
      model_inputs_ = model_thread->GetInputs();
    }
    model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
  }
  return kSuccess;
}

Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                          const MSKernelCallBack &before, const MSKernelCallBack &after) {
  outputs->clear();
  auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
  PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
  PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs);
  return kSuccess;
}

ModelPool::~ModelPool() {
  for (auto &th : model_thread_vec_) {
    if (th.joinable()) {
      th.join();
    }
  }
}
}  // namespace mindspore
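For intuition, a standalone sketch of SetBindStrategy's arithmetic, not part of the commit; the 8-core machine and 2 threads per model are assumed figures. It yields four models bound to cores {0,1}, {2,3}, {4,5}, {6,7}.

#include <cstdio>
#include <vector>

int main() {
  const int core_num = 8;    // assumed hardware
  const int thread_num = 2;  // assumed threads per model
  const int num_models = core_num / thread_num;  // 4 models
  int core_id = 0;
  for (int i = 0; i < num_models; i++) {
    std::vector<int> bind_id;
    for (int j = 0; j < thread_num; j++) {
      if (core_id >= core_num) core_id = 0;  // wrap around, as in SetBindStrategy
      bind_id.push_back(core_id++);
    }
    std::printf("model %d -> cores %d..%d\n", i, bind_id.front(), bind_id.back());
  }
  return 0;
}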
@@ -0,0 +1,54 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
#define MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include <thread>
#include "include/api/status.h"
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"
#include "src/cxx_api/model_pool/model_thread.h"
#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
class ModelPool {
 public:
  static ModelPool *GetInstance();
  ~ModelPool();

  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);

  std::vector<MSTensor> GetInputs();

  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

 private:
  ModelPool() = default;
  void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
  ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
  std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);

  std::vector<std::thread> model_thread_vec_;
  std::vector<MSTensor> model_inputs_;
  size_t num_models_ = 10;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
@@ -0,0 +1,108 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/cxx_api/model_pool/model_thread.h"
#include "src/common/log.h"
#include "src/common/utils.h"
namespace mindspore {
void ModelThread::Run() {
  while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) {
    auto task = PredictTaskQueue::GetInstance()->GetPredictTask();
    if (task == nullptr) {
      break;
    }
    auto inputs = task->inputs;
    auto *outputs = task->outputs;
    auto before = task->before;
    auto after = task->after;
    auto status = Predict(*inputs, outputs, before, after);
    if (status != kSuccess) {
      MS_LOG(ERROR) << "model predict failed.";
      return;
    }
    // Rebuild the outputs with owned copies so the caller does not hold tensors
    // that alias this worker's internal buffers.
    auto output_size = outputs->size();
    for (size_t i = 0; i < output_size; i++) {
      auto copy_tensor =
        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
      if (copy_tensor == nullptr) {
        MS_LOG(ERROR) << "copy output tensor failed.";
        return;
      }
      outputs->erase(outputs->begin());
      outputs->push_back(*copy_tensor);
      // push_back stored a copy; release the heap wrapper returned by CreateTensor.
      MSTensor::DestroyTensorPtr(copy_tensor);
    }
    PredictTaskQueue::GetInstance()->ActiveTask();
  }
}

Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
                         const Key &dec_key, const std::string &dec_mode) {
  model_ = std::make_shared<Model>();
  mindspore::ModelType model_type = kMindIR;
  auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
  if (status != kSuccess) {
    MS_LOG(ERROR) << "model build failed in ModelPool Init";
    return status;
  }
  return kSuccess;
}

std::vector<MSTensor> ModelThread::GetInputs() {
  if (model_ == nullptr) {
    MS_LOG(ERROR) << "model is nullptr in ModelThread.";
    return {};
  }
  auto inputs = model_->GetInputs();
  return inputs;
}

std::pair<std::vector<std::vector<int64_t>>, bool> ModelThread::GetModelResize(
  const std::vector<MSTensor> &model_inputs, const std::vector<MSTensor> &inputs) {
  std::unique_lock<std::mutex> model_lock(mtx_model_);
  std::vector<std::vector<int64_t>> dims;
  bool need_resize = false;
  for (size_t i = 0; i < model_inputs.size(); i++) {
    for (size_t j = 0; j < model_inputs[i].Shape().size(); j++) {
      if (model_inputs[i].Shape()[j] != inputs[i].Shape()[j]) {
        need_resize = true;
      }
    }
    dims.push_back(inputs[i].Shape());
  }
  return std::make_pair(dims, need_resize);
}

Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                            const MSKernelCallBack &before, const MSKernelCallBack &after) {
  auto model_input = model_->GetInputs();
  if (model_input.size() != inputs.size()) {
    MS_LOG(ERROR) << "model input size is: " << model_input.size() << ", but get input size is: " << inputs.size();
    return kLiteError;
  }
  auto resize_pair = GetModelResize(model_input, inputs);
  if (resize_pair.second) {
    auto dims = resize_pair.first;
    auto status = model_->Resize(model_->GetInputs(), dims);
    if (status != kSuccess) {
      MS_LOG(ERROR) << "model pool resize failed.";
      return kLiteError;
    }
  }
  auto status = model_->Predict(inputs, outputs, before, after);
  if (status != kSuccess) {
    MS_LOG(ERROR) << "model predict failed.";
    return status;
  }
  return kSuccess;
}
}  // namespace mindspore
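A worked example of the resize check in GetModelResize, not part of the commit; the shapes are assumptions: a model built for input {1, 224, 224, 3} receiving a batch-2 request yields need_resize = true, and dims carries the new shape for Model::Resize.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> model_shapes = {{1, 224, 224, 3}};  // shapes the model was built with
  std::vector<std::vector<int64_t>> input_shapes = {{2, 224, 224, 3}};  // shapes of the incoming request
  bool need_resize = false;
  std::vector<std::vector<int64_t>> dims;
  for (size_t i = 0; i < model_shapes.size(); i++) {
    for (size_t j = 0; j < model_shapes[i].size(); j++) {
      if (model_shapes[i][j] != input_shapes[i][j]) {
        need_resize = true;  // any differing dimension triggers a resize
      }
    }
    dims.push_back(input_shapes[i]);
  }
  std::cout << std::boolalpha << "need_resize = " << need_resize << "\n";  // true
  return 0;
}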
@@ -0,0 +1,61 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
#include <queue>
#include <string>
#include <mutex>
#include <future>
#include <vector>
#include <utility>
#include <memory>
#include "include/api/model.h"
#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
using ModelPoolContex = std::vector<std::shared_ptr<Context>>;

class ModelThread {
 public:
  ModelThread() = default;

  ~ModelThread() = default;

  // the model pool is initialized once and can always accept model run requests
  Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
              const std::string &dec_mode = kDecModeAesGcm);

  std::vector<MSTensor> GetInputs();

  Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                 const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

  void Run();

 private:
  std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,
                                                                    const std::vector<MSTensor> &inputs);

 private:
  std::shared_ptr<mindspore::Model> model_ = nullptr;
  std::mutex mtx_model_;
  std::condition_variable model_cond_;

  // the number of models is configured according to the hardware
  int num_models_ = 0;
};
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
@@ -0,0 +1,57 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
PredictTaskQueue::~PredictTaskQueue() {
  predict_task_done_ = true;
  task_push_cond_.notify_all();
}

void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
  std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
  while (outputs->empty()) {
    task_pop_cond_.wait(result_lock);
  }
}

void PredictTaskQueue::ActiveTask() { task_pop_cond_.notify_all(); }

PredictTaskQueue *PredictTaskQueue::GetInstance() {
  static PredictTaskQueue instance;
  return &instance;
}

void PredictTaskQueue::PushPredictTask(std::shared_ptr<PredictTask> task) {
  std::unique_lock<std::mutex> data_lock(mtx_predict_task_);
  predict_task_.push(task);
  task_push_cond_.notify_one();
}

std::shared_ptr<PredictTask> PredictTaskQueue::GetPredictTask() {
  std::unique_lock<std::mutex> task_lock(mtx_predict_task_);
  while (predict_task_.empty() && !predict_task_done_) {
    task_push_cond_.wait(task_lock);
  }
  if (predict_task_done_) {
    return nullptr;
  }
  auto predict_task = predict_task_.front();
  predict_task_.pop();
  return predict_task;
}
}  // namespace mindspore
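A minimal standalone analogue of the handshake these two condition variables implement, illustrative only and not the library API: the caller pushes a task and blocks until a worker computes the result and notifies, mirroring PushPredictTask / GetPredictTask / ActiveTask / WaitUntilPredictActive.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

std::queue<int> tasks;
int result = 0;
std::mutex mtx;
std::condition_variable push_cond, pop_cond;

int main() {
  std::thread worker([] {
    std::unique_lock<std::mutex> lock(mtx);
    push_cond.wait(lock, [] { return !tasks.empty(); });  // GetPredictTask
    result = tasks.front() * 2;                           // run the model
    tasks.pop();
    pop_cond.notify_all();                                // ActiveTask
  });
  {
    std::unique_lock<std::mutex> lock(mtx);
    tasks.push(21);                                   // PushPredictTask
    push_cond.notify_one();
    pop_cond.wait(lock, [] { return result != 0; });  // WaitUntilPredictActive
  }
  worker.join();
  std::cout << result << "\n";  // 42
  return 0;
}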
@@ -0,0 +1,58 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_

#include <queue>
#include <mutex>
#include <memory>
#include <vector>
#include <condition_variable>
#include "include/api/types.h"
#include "include/api/status.h"
namespace mindspore {
struct PredictTask {
  PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
              MSKernelCallBack after)
      : inputs(in), outputs(out), before(before), after(after) {}
  const std::vector<MSTensor> *inputs;
  std::vector<MSTensor> *outputs;
  MSKernelCallBack before;
  MSKernelCallBack after;
};

class PredictTaskQueue {
 public:
  static PredictTaskQueue *GetInstance();
  ~PredictTaskQueue();

  void PushPredictTask(std::shared_ptr<PredictTask> task);
  void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
  std::shared_ptr<PredictTask> GetPredictTask();
  void ActiveTask();
  bool IsPredictTaskDone() { return predict_task_done_; }

 private:
  PredictTaskQueue() = default;
  std::queue<std::shared_ptr<PredictTask>> predict_task_;

  std::mutex mtx_predict_task_;
  std::condition_variable task_pop_cond_;
  std::condition_variable task_push_cond_;
  bool predict_task_done_ = false;
};
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_PREDICT_TASK_QUEUE_H_