!48584 [MS][LITE][parallel predict] fix thread pool

Merge pull request !48584 from yefeng/517-fix_shared_thread_pool
This commit is contained in:
i-robot 2023-02-09 06:29:27 +00:00 committed by Gitee
commit 22a5c32b1d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 81 additions and 32 deletions

View File

@ -51,6 +51,9 @@ void ParallelThreadPoolManager::Init(bool enable_shared_thread_pool, const std::
runner_id_pools_[runner_id] = runner_pools; runner_id_pools_[runner_id] = runner_pools;
remaining_thread_num_[runner_id] = remaining_thread_num; remaining_thread_num_[runner_id] = remaining_thread_num;
thread_num_limit_[runner_id] = thread_num_limit; thread_num_limit_[runner_id] = thread_num_limit;
idle_pool_num_[runner_id] = worker_num;
runner_worker_num_[runner_id] = worker_num;
worker_init_num_[runner_id] = 0;
#endif #endif
} }
@ -136,6 +139,7 @@ void ParallelThreadPoolManager::BindPoolToRunner(
auto worker = static_cast<ParallelWorker *>(all_workers[i]); auto worker = static_cast<ParallelWorker *>(all_workers[i]);
pool_workers_[parallel_pool].push_back(worker); pool_workers_[parallel_pool].push_back(worker);
} }
worker_init_num_[runner_id]++;
#endif #endif
} }
@ -150,7 +154,11 @@ bool ParallelThreadPoolManager::GetEnableSharedThreadPool(std::string runner_id)
void ParallelThreadPoolManager::ActivatePool(const std::string &runner_id, int model_id) { void ParallelThreadPoolManager::ActivatePool(const std::string &runner_id, int model_id) {
#ifdef THREAD_POOL_MANAGER #ifdef THREAD_POOL_MANAGER
std::shared_lock<std::shared_mutex> l(pool_manager_mutex_); std::shared_lock<std::shared_mutex> l(pool_manager_mutex_);
if (!enable_shared_thread_pool_[runner_id]) {
return;
}
auto &pool = runner_id_pools_[runner_id][model_id]; auto &pool = runner_id_pools_[runner_id][model_id];
idle_pool_num_[runner_id]--;
pool->UseThreadPool(1); pool->UseThreadPool(1);
auto &workers = pool_workers_[pool]; auto &workers = pool_workers_[pool];
for (auto &worker : workers) { for (auto &worker : workers) {
@ -162,15 +170,19 @@ void ParallelThreadPoolManager::ActivatePool(const std::string &runner_id, int m
void ParallelThreadPoolManager::SetFreePool(const std::string &runner_id, int model_id) { void ParallelThreadPoolManager::SetFreePool(const std::string &runner_id, int model_id) {
#ifdef THREAD_POOL_MANAGER #ifdef THREAD_POOL_MANAGER
std::shared_lock<std::shared_mutex> l(pool_manager_mutex_); std::shared_lock<std::shared_mutex> l(pool_manager_mutex_);
if (!enable_shared_thread_pool_[runner_id]) {
return;
}
auto &pool = runner_id_pools_[runner_id][model_id]; auto &pool = runner_id_pools_[runner_id][model_id];
pool->UseThreadPool(-1); pool->UseThreadPool(-1);
idle_pool_num_[runner_id]++;
#endif #endif
} }
#ifdef ENABLE_MINDRT #ifdef ENABLE_MINDRT
ParallelThreadPool *ParallelThreadPoolManager::GetIdleThreadPool(const std::string &runner_id, ParallelTask *task) { ParallelThreadPool *ParallelThreadPoolManager::GetIdleThreadPool(const std::string &runner_id, ParallelTask *task) {
#ifdef THREAD_POOL_MANAGER #ifdef THREAD_POOL_MANAGER
if (!has_idle_pool_[runner_id]) { if (runner_worker_num_[runner_id] != worker_init_num_[runner_id] || idle_pool_num_[runner_id] <= 0) {
return nullptr; return nullptr;
} }
std::shared_lock<std::shared_mutex> l(pool_manager_mutex_); std::shared_lock<std::shared_mutex> l(pool_manager_mutex_);
@ -205,6 +217,9 @@ void ParallelThreadPoolManager::ResetParallelThreadPoolManager(const std::string
enable_shared_thread_pool_.erase(runner_id); enable_shared_thread_pool_.erase(runner_id);
remaining_thread_num_.erase(runner_id); remaining_thread_num_.erase(runner_id);
thread_num_limit_.erase(runner_id); thread_num_limit_.erase(runner_id);
runner_worker_num_.erase(runner_id);
worker_init_num_.erase(runner_id);
idle_pool_num_.erase(runner_id);
#endif #endif
} }
@ -218,6 +233,9 @@ ParallelThreadPoolManager::~ParallelThreadPoolManager() {
enable_shared_thread_pool_.clear(); enable_shared_thread_pool_.clear();
remaining_thread_num_.clear(); remaining_thread_num_.clear();
thread_num_limit_.clear(); thread_num_limit_.clear();
runner_worker_num_.clear();
worker_init_num_.clear();
idle_pool_num_.clear();
THREAD_INFO("~ParallelThreadPoolManager end."); THREAD_INFO("~ParallelThreadPoolManager end.");
#endif #endif
} }

View File

@ -81,6 +81,9 @@ class ParallelThreadPoolManager {
std::shared_mutex pool_manager_mutex_; std::shared_mutex pool_manager_mutex_;
std::map<std::string, bool> has_idle_pool_; std::map<std::string, bool> has_idle_pool_;
std::map<std::string, bool> enable_shared_thread_pool_; std::map<std::string, bool> enable_shared_thread_pool_;
std::map<std::string, int> runner_worker_num_;
std::map<std::string, int> worker_init_num_;
std::map<std::string, int> idle_pool_num_;
std::map<std::string, int> remaining_thread_num_; std::map<std::string, int> remaining_thread_num_;
std::map<std::string, int> thread_num_limit_; std::map<std::string, int> thread_num_limit_;
#endif #endif

View File

@ -76,6 +76,7 @@ static const char *const kInnerSharingWeightCopyBufKey = "sharing_weight_copy_bu
static const char *const kInnerModelIDKey = "inner_model_id"; static const char *const kInnerModelIDKey = "inner_model_id";
static const char *const kInnerRunnerIDKey = "inner_runner_id"; static const char *const kInnerRunnerIDKey = "inner_runner_id";
static const char *const kInnerNumaIDKey = "inner_numa_id"; static const char *const kInnerNumaIDKey = "inner_numa_id";
static const char *const kInnerWorkerNumKey = "inner_worker_num";
// gpu context // gpu context
static const char *const kGPUContextSection = "gpu_context"; static const char *const kGPUContextSection = "gpu_context";
static const char *const kInputShapeKey = "input_shape"; static const char *const kInputShapeKey = "input_shape";

View File

@ -698,19 +698,22 @@ Status ModelPool::CreateWorkers(const char *graph_buf, size_t size, const ModelP
MS_LOG(INFO) << "runner_id_: " << runner_id_ << " | enable_shared_thread_pool_: " << enable_shared_thread_pool_ MS_LOG(INFO) << "runner_id_: " << runner_id_ << " | enable_shared_thread_pool_: " << enable_shared_thread_pool_
<< " | workers_num_: " << workers_num_ << " | remaining_thread_num_: " << remaining_thread_num_ << " | workers_num_: " << workers_num_ << " | remaining_thread_num_: " << remaining_thread_num_
<< " | thread_num_limit_: " << thread_num_limit_; << " | thread_num_limit_: " << thread_num_limit_;
ParallelThreadPoolManager::GetInstance()->Init(enable_shared_thread_pool_, runner_id_, workers_num_,
remaining_thread_num_, thread_num_limit_);
for (size_t i = 0; i < workers_num_; i++) { for (size_t i = 0; i < workers_num_; i++) {
int numa_node_id = model_pool_config[i]->numa_id; int numa_node_id = model_pool_config[i]->numa_id;
std::map<std::string, std::string> ids; std::map<std::string, std::string> ids;
ids[lite::kInnerModelIDKey] = std::to_string(i); ids[lite::kInnerModelIDKey] = std::to_string(i);
ids[lite::kInnerRunnerIDKey] = runner_id_; ids[lite::kInnerRunnerIDKey] = runner_id_;
ids[lite::kInnerNumaIDKey] = std::to_string(model_pool_config[i]->numa_id); ids[lite::kInnerNumaIDKey] = std::to_string(model_pool_config[i]->numa_id);
model_pool_config[i]->config_info[lite::kInnerModelParallelRunnerSection] = ids; if (enable_shared_thread_pool_) {
ids[lite::kInnerWorkerNumKey] = std::to_string(workers_num_);
ids[lite::kEnableSharedThreadPoolKey] = "true";
ids[lite::kThreadNumRemainingPerWorkerKey] = std::to_string(remaining_thread_num_);
ids[lite::kThreadNumLimitPerWorkerKey] = std::to_string(thread_num_limit_);
}
if (!copy_model || model_pool_config[i]->numa_id == 0) { if (!copy_model || model_pool_config[i]->numa_id == 0) {
ids[lite::kInnerSharingWeightCopyBufKey] = "false"; ids[lite::kInnerSharingWeightCopyBufKey] = "false";
} }
model_pool_config[i]->config_info[lite::kInnerModelParallelRunnerSection] = ids;
model_worker = std::make_shared<ModelWorker>(); model_worker = std::make_shared<ModelWorker>();
if (model_worker == nullptr) { if (model_worker == nullptr) {
MS_LOG(ERROR) << "model worker is nullptr."; MS_LOG(ERROR) << "model worker is nullptr.";
@ -911,8 +914,13 @@ Status ModelPool::ParseSharedThreadPoolParam(const std::shared_ptr<RunnerConfig>
} }
ModelPoolConfig ModelPool::Init(const std::shared_ptr<RunnerConfig> &runner_config) { ModelPoolConfig ModelPool::Init(const std::shared_ptr<RunnerConfig> &runner_config) {
auto status = ParseSharedThreadPoolParam(runner_config);
if (status != kSuccess) {
MS_LOG(WARNING) << "ParseSharedThreadPoolParam failed, Not use thread pool shared.";
enable_shared_thread_pool_ = false;
}
ModelPoolConfig model_pool_config = {}; ModelPoolConfig model_pool_config = {};
auto status = CanUseAllPhysicalResources(); status = CanUseAllPhysicalResources();
if (status != kSuccess) { if (status != kSuccess) {
MS_LOG(ERROR) << "parser sys file failed."; MS_LOG(ERROR) << "parser sys file failed.";
return model_pool_config; return model_pool_config;
@ -1080,25 +1088,15 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
auto available_worker = GetMaxWaitWorkerNum(&max_wait_worker_node_id, &max_wait_worker_num); auto available_worker = GetMaxWaitWorkerNum(&max_wait_worker_node_id, &max_wait_worker_num);
if (available_worker != nullptr) { if (available_worker != nullptr) {
// dispatch tasks directly to workers // dispatch tasks directly to workers
if (enable_shared_thread_pool_) {
ParallelThreadPoolManager::GetInstance()->SetHasIdlePool(runner_id_, true);
ParallelThreadPoolManager::GetInstance()->ActivatePool(runner_id_, available_worker->GetWorkerID());
}
auto ret = available_worker->Predict(inputs, outputs, before, after); auto ret = available_worker->Predict(inputs, outputs, before, after);
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "direct predict failed."; MS_LOG(ERROR) << "direct predict failed.";
return kLiteError; return kLiteError;
} }
predict_task_queue_->IncreaseWaitModelNum(1, max_wait_worker_node_id); predict_task_queue_->IncreaseWaitModelNum(1, max_wait_worker_node_id);
if (enable_shared_thread_pool_) {
ParallelThreadPoolManager::GetInstance()->SetFreePool(runner_id_, available_worker->GetWorkerID());
}
return kSuccess; return kSuccess;
} else { } else {
// do predict // do predict
if (enable_shared_thread_pool_) {
ParallelThreadPoolManager::GetInstance()->SetHasIdlePool(runner_id_, false);
}
size_t task_id; size_t task_id;
auto task = CreatePredictTask(inputs, outputs, before, after, &task_id); auto task = CreatePredictTask(inputs, outputs, before, after, &task_id);
if (task == nullptr) { if (task == nullptr) {
@ -1134,9 +1132,6 @@ ModelPool::~ModelPool() {
if (thread_.joinable()) { if (thread_.joinable()) {
thread_.join(); thread_.join();
} }
if (enable_shared_thread_pool_) {
ParallelThreadPoolManager::GetInstance()->ResetParallelThreadPoolManager(runner_id_);
}
MS_LOG(INFO) << "delete model pool task."; MS_LOG(INFO) << "delete model pool task.";
if (tasks_ != nullptr) { if (tasks_ != nullptr) {
delete[] tasks_; delete[] tasks_;

View File

@ -758,6 +758,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
MS_LOG(ERROR) << "Not support multi-threading"; MS_LOG(ERROR) << "Not support multi-threading";
return RET_ERROR; return RET_ERROR;
} }
ParallelThreadPoolManager::GetInstance()->ActivatePool(runner_id_, worker_id_);
STATUS ret = CheckTensorsInvalid(inputs_); STATUS ret = CheckTensorsInvalid(inputs_);
if (MS_UNLIKELY(ret != RET_OK)) { if (MS_UNLIKELY(ret != RET_OK)) {
is_running_.store(false); is_running_.store(false);
@ -781,27 +782,49 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af
input->set_shape_changed(false); input->set_shape_changed(false);
} }
} }
ParallelThreadPoolManager::GetInstance()->SetFreePool(runner_id_, worker_id_);
is_running_.store(false); is_running_.store(false);
return ret; return ret;
} }
int LiteSession::InitSharedThreadPool() {
int workers_num = -1;
int remaining_thread_num = -1;
int thread_num_limit = -1;
bool enable_shared_pool = false;
if (config_info_ != nullptr) {
auto runner_info_item = config_info_->find(kInnerModelParallelRunnerSection);
if (runner_info_item != config_info_->end()) {
auto item_runner = runner_info_item->second.find(kInnerRunnerIDKey);
if (item_runner != runner_info_item->second.end()) {
runner_id_ = runner_info_item->second.at(kInnerRunnerIDKey);
}
auto shared_pool_item = runner_info_item->second.find(kEnableSharedThreadPoolKey);
if (shared_pool_item != runner_info_item->second.end() &&
runner_info_item->second.at(kEnableSharedThreadPoolKey) == "true") {
workers_num = std::atoi(runner_info_item->second.at(kInnerWorkerNumKey).c_str());
remaining_thread_num = std::atoi(runner_info_item->second.at(kThreadNumRemainingPerWorkerKey).c_str());
thread_num_limit = std::atoi(runner_info_item->second.at(kThreadNumLimitPerWorkerKey).c_str());
worker_id_ = std::atoi(runner_info_item->second.at(kInnerModelIDKey).c_str());
enable_shared_pool = true;
}
}
}
MS_LOG(INFO) << "runner id: " << runner_id_ << " enable_shared_pool: " << enable_shared_pool
<< " workers_num: " << workers_num << " thread_num_limit: " << thread_num_limit
<< " remaining_thread_num: " << remaining_thread_num;
ParallelThreadPoolManager::GetInstance()->Init(enable_shared_pool, runner_id_, workers_num, remaining_thread_num,
thread_num_limit);
return RET_OK;
}
int LiteSession::ContextInit(const std::shared_ptr<InnerContext> &context) { int LiteSession::ContextInit(const std::shared_ptr<InnerContext> &context) {
if (context == nullptr) { if (context == nullptr) {
MS_LOG(ERROR) << "context is nullptr"; MS_LOG(ERROR) << "context is nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
this->context_ = context; this->context_ = context;
std::string runner_id; context_->SetBindRunnerId(runner_id_);
if (config_info_ != nullptr) {
auto it_id = config_info_->find(kInnerModelParallelRunnerSection);
if (it_id != config_info_->end()) {
auto item_runner = it_id->second.find(kInnerRunnerIDKey);
if (item_runner != it_id->second.end()) {
runner_id = it_id->second.at(kInnerRunnerIDKey);
}
}
}
context_->SetBindRunnerId(runner_id);
auto ret = this->context_->Init(); auto ret = this->context_->Init();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Context failed"; MS_LOG(ERROR) << "Init Context failed";
@ -819,8 +842,8 @@ int LiteSession::ContextInit(const std::shared_ptr<InnerContext> &context) {
context_->thread_pool_->SetMinSpinCount(kDefaulLiteIosSpinCount); context_->thread_pool_->SetMinSpinCount(kDefaulLiteIosSpinCount);
#endif #endif
if (context_->inter_op_parallel_num_ > 1 && !runner_id.empty() && if (context_->inter_op_parallel_num_ > 1 && !runner_id_.empty() &&
ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(runner_id)) { ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(runner_id_)) {
MS_LOG(INFO) << "Enable subgraph parallelism and enable thread pool sharing"; MS_LOG(INFO) << "Enable subgraph parallelism and enable thread pool sharing";
ParallelThreadPoolManager::GetInstance()->BindPoolToRunner(context_->thread_pool_, config_info_); ParallelThreadPoolManager::GetInstance()->BindPoolToRunner(context_->thread_pool_, config_info_);
} }
@ -982,6 +1005,12 @@ int LiteSession::Init(const std::shared_ptr<InnerContext> &context) {
return RET_NOT_SUPPORT; return RET_NOT_SUPPORT;
} }
auto status = InitSharedThreadPool();
if (status != RET_OK) {
MS_LOG(ERROR) << "init Shared thread pool failed";
is_running_.store(false);
return status;
}
auto ret = ContextInit(context); auto ret = ContextInit(context);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Init Context failed"; MS_LOG(ERROR) << "Init Context failed";
@ -1077,6 +1106,7 @@ LiteSession::~LiteSession() {
#endif #endif
delete ms_context_; delete ms_context_;
ms_context_ = nullptr; ms_context_ = nullptr;
ParallelThreadPoolManager::GetInstance()->ResetParallelThreadPoolManager(runner_id_);
lite::PackWeightManager::GetInstance()->FreePackWeight(runner_id_, model_id_); lite::PackWeightManager::GetInstance()->FreePackWeight(runner_id_, model_id_);
if (model_ != nullptr && is_shared_weight_) { if (model_ != nullptr && is_shared_weight_) {
model_->buf = nullptr; model_->buf = nullptr;

View File

@ -170,6 +170,7 @@ class LiteSession {
int CreateCoreMLDelegate(); int CreateCoreMLDelegate();
int DelegateInit(); int DelegateInit();
int InitGPURuntime(); int InitGPURuntime();
int InitSharedThreadPool();
private: private:
int IsolateOutputTensor(); int IsolateOutputTensor();
@ -242,6 +243,7 @@ class LiteSession {
std::vector<kernel::KernelExec *> non_tail_call_kernels_; std::vector<kernel::KernelExec *> non_tail_call_kernels_;
std::string model_id_; std::string model_id_;
std::string runner_id_; std::string runner_id_;
int worker_id_;
bool is_shared_weight_ = false; bool is_shared_weight_ = false;
}; };
} // namespace lite } // namespace lite