add model parallel runner impl

yefeng 2022-09-28 16:07:35 +08:00
parent b2c39cc90c
commit 385ad6e354
7 changed files with 231 additions and 87 deletions

View File

@@ -29,18 +29,28 @@ class RunnerConfig {
public:
struct Data;
RunnerConfig();
~RunnerConfig() = default;
~RunnerConfig();
/// \brief Set the number of workers at runtime. Only valid for ModelParallelRunner.
///
/// \param[in] workers_num the number of workers at runtime.
void SetWorkersNum(int32_t workers_num);
/// \brief Get the number of parallel workers currently configured. Only valid for ModelParallelRunner.
///
/// \return The number of parallel workers currently configured.
int32_t GetWorkersNum() const;
/// \brief Set the context at runtime. Only valid for ModelParallelRunner.
///
/// \param[in] context The context that stores environment variables used at runtime.
void SetContext(const std::shared_ptr<Context> &context);
/// \brief Get the current context setting. Only valid for ModelParallelRunner.
///
/// \return The current context setting.
std::shared_ptr<Context> GetContext() const;
/// \brief Set the config before runtime. Only valid for ModelParallelRunner.
///
/// \param[in] section The category of the configuration parameter.
@@ -52,16 +62,6 @@ class RunnerConfig {
/// \return The current config setting.
inline std::map<std::string, std::map<std::string, std::string>> GetConfigInfo() const;
/// \brief Get the current operators parallel workers number setting. Only valid for ModelParallelRunner.
///
/// \return The current operators parallel workers number setting.
int32_t GetWorkersNum() const;
/// \brief Get the current context setting. Only valid for ModelParallelRunner.
///
/// \return The current operators context setting.
std::shared_ptr<Context> GetContext() const;
/// \brief Set the config path before runtime. Only valid for ModelParallelRunner.
///
/// \param[in] config_path The path of the configuration parameter.
@@ -92,14 +92,14 @@ void RunnerConfig::SetConfigPath(const std::string &config_path) { SetConfigPath
std::string RunnerConfig::GetConfigPath() const { return CharToString(GetConfigPathChar()); }
class ModelPool;
class ModelParallelRunnerImpl;
/// \brief The ModelParallelRunner class defines a MindSpore model parallel runner, which manages a pool of
/// models so that inference requests can be served concurrently.
class MS_API ModelParallelRunner {
public:
ModelParallelRunner() = default;
~ModelParallelRunner() = default;
ModelParallelRunner();
~ModelParallelRunner();
/// \brief Build a model parallel runner from a model path so that it can run on a device.
///
@@ -142,7 +142,7 @@ class MS_API ModelParallelRunner {
private:
Status Init(const std::vector<char> &model_path, const std::shared_ptr<RunnerConfig> &runner_config);
std::shared_ptr<ModelPool> model_pool_ = nullptr;
std::shared_ptr<ModelParallelRunnerImpl> model_parallel_runner_impl_ = nullptr;
};
Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
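For reference, a minimal usage sketch of the public API declared above. The model path, device choice, and worker count are placeholders rather than anything from this commit; the calls themselves follow the signatures in this header.
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"
#include <iostream>
#include <memory>
#include <string>
#include <vector>
int main() {
  // Hypothetical model file; substitute a real converted model.
  const std::string model_path = "model.ms";
  auto context = std::make_shared<mindspore::Context>();
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
  auto config = std::make_shared<mindspore::RunnerConfig>();
  config->SetContext(context);
  config->SetWorkersNum(2);  // run two parallel workers
  mindspore::ModelParallelRunner runner;
  if (runner.Init(model_path, config) != mindspore::kSuccess) {
    std::cerr << "runner init failed" << std::endl;
    return -1;
  }
  auto inputs = runner.GetInputs();
  // Fill input tensor data here before calling Predict.
  std::vector<mindspore::MSTensor> outputs;
  if (runner.Predict(inputs, &outputs) != mindspore::kSuccess) {
    std::cerr << "predict failed" << std::endl;
    return -1;
  }
  std::cout << "got " << outputs.size() << " output tensors" << std::endl;
  return 0;
}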

View File

@@ -76,6 +76,7 @@ if(MSLITE_ENABLE_PARALLEL_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/extendrt/cxx_api/model_pool/model_worker.cc
${CMAKE_CURRENT_SOURCE_DIR}/extendrt/cxx_api/model_pool/model_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/extendrt/cxx_api/model_pool/model_parallel_runner.cc
${CMAKE_CURRENT_SOURCE_DIR}/extendrt/cxx_api/model_pool/model_parallel_runner_impl.cc
)
endif()

View File

@@ -51,6 +51,7 @@ set(MSLIB_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_worker.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_pool.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_parallel_runner.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_parallel_runner_impl.cc
${API_MS_INFER_SRC}
${API_ACL_SRC}
${API_OPS_SRC}

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include "include/api/model_parallel_runner.h"
#include "src/extendrt/cxx_api/model_pool/model_pool.h"
#include "src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.h"
#include "src/extendrt/cxx_api/model_pool/runner_config.h"
#include "src/common/log_adapter.h"
#include "src/litert/cpu_info.h"
#ifdef CAPTURE_SIGNALS
#include "src/extendrt/signal_handler.h"
#endif
@@ -32,8 +31,12 @@ extern void mindspore_log_init();
}
#endif
std::mutex g_model_parallel_runner_mutex;
RunnerConfig::RunnerConfig() : data_(std::make_shared<Data>()) {}
RunnerConfig::~RunnerConfig() {}
void RunnerConfig::SetWorkersNum(int32_t workers_num) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Runner config data is nullptr.";
@@ -111,101 +114,70 @@ std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> Runn
return MapMapStringToChar(data_->config_info);
}
ModelParallelRunner::ModelParallelRunner() {}
ModelParallelRunner::~ModelParallelRunner() {}
Status ModelParallelRunner::Init(const std::vector<char> &model_path,
const std::shared_ptr<RunnerConfig> &runner_config) {
{
std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
#ifdef USE_GLOG
mindspore::mindspore_log_init();
mindspore::mindspore_log_init();
#endif
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
if (model_parallel_runner_impl_ == nullptr) {
model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
if (model_parallel_runner_impl_ == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
}
}
auto new_model_pool = std::make_shared<ModelPool>();
if (new_model_pool == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
if (!PlatformInstructionSetSupportCheck()) {
return kLiteNotSupport;
}
auto status = new_model_pool->InitByPath(CharToString(model_path), runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "ModelParallelRunner init failed.";
return kLiteError;
}
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
}
model_pool_ = new_model_pool;
#ifdef CAPTURE_SIGNALS
CaptureSignal();
#endif
return status;
return model_parallel_runner_impl_->Init(CharToString(model_path), runner_config);
}
Status ModelParallelRunner::Init(const void *model_data, size_t data_size,
const std::shared_ptr<RunnerConfig> &runner_config) {
{
std::lock_guard<std::mutex> l(g_model_parallel_runner_mutex);
#ifdef USE_GLOG
mindspore::mindspore_log_init();
mindspore::mindspore_log_init();
#endif
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
if (model_parallel_runner_impl_ == nullptr) {
model_parallel_runner_impl_ = std::make_shared<ModelParallelRunnerImpl>();
if (model_parallel_runner_impl_ == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
}
}
auto new_model_pool = std::make_shared<ModelPool>();
if (new_model_pool == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
if (!PlatformInstructionSetSupportCheck()) {
return kLiteNotSupport;
}
auto status = new_model_pool->InitByBuf(static_cast<const char *>(model_data), data_size, runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner init failed.";
return kLiteError;
}
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
}
model_pool_ = new_model_pool;
#ifdef CAPTURE_SIGNALS
CaptureSignal();
#endif
return status;
return model_parallel_runner_impl_->Init(model_data, data_size, runner_config);
}
std::vector<MSTensor> ModelParallelRunner::GetInputs() {
if (model_pool_ == nullptr) {
if (model_parallel_runner_impl_ == nullptr) {
std::vector<MSTensor> empty;
MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetInput API.";
return empty;
}
return model_pool_->GetInputs();
return model_parallel_runner_impl_->GetInputs();
}
std::vector<MSTensor> ModelParallelRunner::GetOutputs() {
if (model_pool_ == nullptr) {
if (model_parallel_runner_impl_ == nullptr) {
std::vector<MSTensor> empty;
MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetOutputs API.";
return empty;
}
return model_pool_->GetOutputs();
return model_parallel_runner_impl_->GetOutputs();
}
Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr || model_pool_ == nullptr) {
MS_LOG(ERROR) << "predict output is nullptr or ModelParallelRunner Not Initialize.";
if (model_parallel_runner_impl_ == nullptr) {
MS_LOG(ERROR) << "ModelParallelRunner Not Initialize.";
return kLiteNullptr;
}
auto status = model_pool_->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner predict failed.";
return status;
}
return kSuccess;
return model_parallel_runner_impl_->Predict(inputs, outputs, before, after);
}
} // namespace mindspore
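The shape of this change is a classic pimpl handoff: the public ModelParallelRunner keeps only a shared_ptr to the impl, creates it lazily under a global mutex, and forwards every call. A stripped-down sketch of the same pattern, using simplified stand-in names rather than the actual MindSpore types:
#include <iostream>
#include <memory>
#include <mutex>
// Stand-in for ModelParallelRunnerImpl; structure only.
class Impl {
 public:
  int Init() {
    std::cout << "impl init" << std::endl;
    return 0;
  }
};
std::mutex g_mutex;  // file-scope guard, like g_model_parallel_runner_mutex
class Facade {
 public:
  int Init() {
    {
      // Lazy creation under the global mutex: concurrent Init calls
      // construct at most one impl, mirroring the code above.
      std::lock_guard<std::mutex> lock(g_mutex);
      if (impl_ == nullptr) {
        impl_ = std::make_shared<Impl>();
      }
    }
    return impl_->Init();  // every public call is forwarded to the impl
  }
 private:
  std::shared_ptr<Impl> impl_;
};
int main() {
  Facade f;
  return f.Init();
}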

View File

@@ -0,0 +1,129 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/cxx_api/model_pool/model_parallel_runner_impl.h"
#include "src/extendrt/cxx_api/model_pool/runner_config.h"
#include "src/common/log_adapter.h"
#include "src/litert/cpu_info.h"
#ifdef CAPTURE_SIGNALS
#include "src/extendrt/signal_handler.h"
#endif
namespace mindspore {
Status ModelParallelRunnerImpl::Init(const std::string &model_path,
const std::shared_ptr<RunnerConfig> &runner_config) {
std::unique_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
}
model_pool_ = new (std::nothrow) ModelPool;
if (model_pool_ == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
if (!PlatformInstructionSetSupportCheck()) {
delete model_pool_;
model_pool_ = nullptr;
return kLiteNotSupport;
}
auto status = model_pool_->InitByPath(model_path, runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "ModelParallelRunner init failed.";
delete model_pool_;
model_pool_ = nullptr;
return kLiteError;
}
#ifdef CAPTURE_SIGNALS
CaptureSignal();
#endif
return status;
}
Status ModelParallelRunnerImpl::Init(const void *model_data, size_t data_size,
const std::shared_ptr<RunnerConfig> &runner_config) {
std::unique_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
MS_LOG(WARNING) << "ModelParallelRunner is already initialized, not need to initialize it again";
return kSuccess;
}
model_pool_ = new (std::nothrow) ModelPool;
if (model_pool_ == nullptr) {
MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
return kLiteNullptr;
}
if (!PlatformInstructionSetSupportCheck()) {
delete model_pool_;
model_pool_ = nullptr;
return kLiteNotSupport;
}
auto status = model_pool_->InitByBuf(static_cast<const char *>(model_data), data_size, runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner init failed.";
delete model_pool_;
model_pool_ = nullptr;
return kLiteError;
}
#ifdef CAPTURE_SIGNALS
CaptureSignal();
#endif
return status;
}
std::vector<MSTensor> ModelParallelRunnerImpl::GetInputs() {
std::shared_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (model_pool_ == nullptr) {
std::vector<MSTensor> empty;
MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetInput API.";
return empty;
}
return model_pool_->GetInputs();
}
std::vector<MSTensor> ModelParallelRunnerImpl::GetOutputs() {
std::shared_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (model_pool_ == nullptr) {
std::vector<MSTensor> empty;
MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetOutputs API.";
return empty;
}
return model_pool_->GetOutputs();
}
Status ModelParallelRunnerImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
std::shared_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (outputs == nullptr || model_pool_ == nullptr) {
MS_LOG(ERROR) << "predict output is nullptr or ModelParallelRunner Not Initialize.";
return kLiteNullptr;
}
auto status = model_pool_->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner predict failed.";
return status;
}
return kSuccess;
}
ModelParallelRunnerImpl::~ModelParallelRunnerImpl() {
std::unique_lock<std::shared_mutex> l(model_parallel_runner_impl_mutex_);
if (model_pool_ != nullptr) {
MS_LOG(INFO) << "delete model pool impl.";
delete model_pool_;
model_pool_ = nullptr;
} else {
MS_LOG(INFO) << "model pool is nullptr.";
}
MS_LOG(INFO) << "delete model pool done.";
}
} // namespace mindspore
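One detail worth noticing in this file: Init and the destructor take a std::unique_lock (exclusive), while GetInputs, GetOutputs, and Predict take a std::shared_lock, so concurrent Predict calls proceed in parallel and only serialize against (re)initialization and teardown. A minimal sketch of that locking discipline, with simplified types rather than the real pool:
#include <iostream>
#include <mutex>
#include <shared_mutex>
class Runner {
 public:
  void Init() {
    // Exclusive lock: waits for all readers and blocks new ones.
    std::unique_lock<std::shared_mutex> lock(mutex_);
    initialized_ = true;
  }
  int Predict() {
    // Shared lock: any number of Predict calls may hold it at once.
    std::shared_lock<std::shared_mutex> lock(mutex_);
    if (!initialized_) {
      std::cerr << "not initialized" << std::endl;
      return -1;
    }
    return 0;
  }
 private:
  std::shared_mutex mutex_;
  bool initialized_ = false;
};
int main() {
  Runner r;
  r.Init();
  return r.Predict();
}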

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_POOL_MODEL_PARALLEL_RUNNER_IMPL_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_POOL_MODEL_PARALLEL_RUNNER_IMPL_H_
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include <string>
#include <shared_mutex>
#include "src/extendrt/cxx_api/model_pool/model_pool.h"
#include "include/api/context.h"
namespace mindspore {
class ModelParallelRunnerImpl {
public:
ModelParallelRunnerImpl() = default;
~ModelParallelRunnerImpl();
Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);
Status Init(const void *model_data, size_t data_size, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);
std::vector<MSTensor> GetInputs();
std::vector<MSTensor> GetOutputs();
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
private:
ModelPool *model_pool_ = nullptr;
std::shared_mutex model_parallel_runner_impl_mutex_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_POOL_MODEL_PARALLEL_RUNNER_IMPL_H_
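A side note on the members above: std::shared_mutex is neither copyable nor movable, so ModelParallelRunnerImpl implicitly loses its copy and move operations, and the raw ModelPool pointer is released only in the destructor shown earlier. Purely as an illustration of the same lifetime, not what this commit does (it uses new (std::nothrow) and delete), the ownership could also be expressed with std::unique_ptr:
#include <memory>
#include <mutex>
#include <shared_mutex>
class Pool {};  // stand-in for ModelPool
class RunnerImpl {
 public:
  bool Init() {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    if (pool_ == nullptr) {
      pool_ = std::make_unique<Pool>();  // freed automatically on destruction
    }
    return pool_ != nullptr;
  }
  // No user-written destructor needed: unique_ptr releases the pool, and
  // the shared_mutex member already makes this class non-copyable.
 private:
  std::unique_ptr<Pool> pool_;
  std::shared_mutex mutex_;
};
int main() {
  RunnerImpl impl;
  return impl.Init() ? 0 : 1;
}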

View File

@@ -764,7 +764,6 @@ ModelPoolConfig ModelPool::CreateBaseStrategyModelPoolConfig(const std::shared_p
}
std::vector<MSTensor> ModelPool::GetInputs() {
std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
std::vector<MSTensor> inputs;
if (inputs_info_.empty()) {
MS_LOG(ERROR) << "model input is empty.";
@@ -787,7 +786,6 @@ std::vector<MSTensor> ModelPool::GetInputs() {
}
std::vector<MSTensor> ModelPool::GetOutputs() {
std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
std::vector<MSTensor> outputs;
if (outputs_info_.empty()) {
MS_LOG(ERROR) << "model output is empty.";
@@ -1093,7 +1091,6 @@ Status ModelPool::Init(const char *model_buf, size_t size, const std::shared_ptr
}
Status ModelPool::InitByBuf(const char *model_data, size_t size, const std::shared_ptr<RunnerConfig> &runner_config) {
std::unique_lock<std::shared_mutex> l(model_pool_mutex_);
auto status = Init(model_data, size, runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "init by buf failed.";
@@ -1104,7 +1101,6 @@ Status ModelPool::InitByBuf(const char *model_data, size_t size, const std::shar
}
Status ModelPool::InitByPath(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
std::unique_lock<std::shared_mutex> l(model_pool_mutex_);
model_path_ = model_path;
size_t size = 0;
graph_buf_ = lite::ReadFile(model_path.c_str(), &size);
@@ -1212,7 +1208,6 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
return kLiteInputTensorError;
}
}
std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
predict_task_mutex_.lock();
int max_wait_worker_node_id = 0;
int max_wait_worker_num = 0;
@@ -1250,7 +1245,6 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
}
ModelPool::~ModelPool() {
std::unique_lock<std::shared_mutex> l(model_pool_mutex_);
is_initialized_ = false;
for (auto &item : model_pool_info_) {
auto strategy = item.first;
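The deletions in this last file are the counterpart of the new impl layer: ModelPool drops its own model_pool_mutex_ acquisitions in GetInputs, GetOutputs, InitByBuf, InitByPath, Predict, and the destructor, because ModelParallelRunnerImpl now serializes access one level up, and taking a second lock on every call would add cost without adding safety. A tiny sketch of the resulting layering, with illustrative names:
#include <shared_mutex>
#include <vector>
// Inner layer: does no locking of its own; relies on the caller holding one.
class Pool {
 public:
  std::vector<int> GetInputs() { return inputs_; }
 private:
  std::vector<int> inputs_ = {1, 2, 3};
};
// Outer layer: the single place where synchronization happens.
class RunnerImpl {
 public:
  std::vector<int> GetInputs() {
    std::shared_lock<std::shared_mutex> lock(mutex_);  // one lock per call
    return pool_.GetInputs();
  }
 private:
  Pool pool_;
  std::shared_mutex mutex_;
};
int main() {
  RunnerImpl impl;
  return impl.GetInputs().size() == 3 ? 0 : 1;
}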