!40255 [MS][LITE][parallel predict]model parallel runner

Merge pull request !40255 from yefeng/381-fix_model_parallel_runner_lock
i-robot 2022-08-15 06:17:14 +00:00 committed by Gitee
commit 4469fc0df3
4 changed files with 35 additions and 3 deletions


@@ -727,6 +727,7 @@ Status ModelPool::CreateWorkers(const char *graph_buf, size_t size, const ModelP
     return kLiteError;
   }
   bool create_worker_success = true;
+  MS_LOG(INFO) << "Strategy: " << strategy << " | worker num: " << model_pool_info_[strategy].all_workers_num_;
   for (size_t i = 0; i < model_pool_info_[strategy].all_workers_num_; i++) {
     model_pool_config[i]->strategy = strategy;
     int numa_node_id = model_pool_config[i]->numa_id;
@@ -741,6 +742,22 @@ Status ModelPool::CreateWorkers(const char *graph_buf, size_t size, const ModelP
     }
     auto task_queue_id = model_pool_config[i]->task_queue_id;
     model_pool_info_[strategy].predict_task_queue_->IncreaseWaitModelNum(1, task_queue_id);
+    MS_LOG(INFO) << "Strategy: " << strategy << " | create worker index: " << i
+                 << " | numa id: " << model_pool_config[i]->numa_id
+                 << " | worker affinity mode: " << model_pool_config[i]->context->GetThreadAffinityMode()
+                 << " | worker bind core list: " << model_pool_config[i]->context->GetThreadAffinityCoreList()
+                 << " | worker thread num: " << model_pool_config[i]->context->GetThreadNum()
+                 << " | inter op parallel num: " << model_pool_config[i]->context->GetInterOpParallelNum();
+    if (!model_pool_config[i]->config_info.empty()) {
+      for (auto &item : model_pool_config[i]->config_info) {
+        auto section = item.first;
+        MS_LOG(INFO) << "section: " << section;
+        auto configs = item.second;
+        for (auto &config : configs) {
+          MS_LOG(INFO) << "\t key: " << config.first << " | value: " << config.second;
+        }
+      }
+    }
     worker_thread_vec_.push_back(std::thread(&ModelWorker::CreateThreadWorker, model_worker, new_model_buf, size,
                                              model_pool_config[i], model_pool_info_[strategy].predict_task_queue_,
                                              &create_worker_success));
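For context on the values the new per-worker log lines read back: they come from each worker's mindspore::Context. Below is a minimal caller-side sketch, assuming the public MindSpore Lite C++ API in include/api/context.h; the function name and the literal values are illustrative and not taken from this commit, and how the pool derives each worker's context from the caller's is not shown in this diff.

#include <memory>
#include <vector>
#include "include/api/context.h"

// Illustrative only: populate the Context fields that the new MS_LOG(INFO)
// lines above print for each worker (thread num, inter-op parallel num,
// affinity mode / bind core list). The concrete values are made up.
std::shared_ptr<mindspore::Context> BuildWorkerContext() {
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(4);                                  // "worker thread num"
  context->SetInterOpParallelNum(2);                         // "inter op parallel num"
  context->SetThreadAffinity(std::vector<int>{0, 1, 2, 3});  // "worker bind core list"
  return context;
}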
@@ -824,6 +841,7 @@ Status ModelPool::InitAdvancedStrategy(const char *model_buf, size_t size, int b
     use_advanced_strategy_ = false;
     return kSuccess;
   }
+  MS_LOG(INFO) << "use advanced strategy";
   model_pool_info_[ADVANCED].use_numa = numa_available_;
   auto status = SetDefaultOptimalModelNum(advanced_thread_num, ADVANCED);
   if (status != kSuccess) {


@@ -55,8 +55,11 @@ void ModelWorker::CreateThreadWorker(const char *model_buf, size_t size,
   if (status != kSuccess) {
     PrintWorkerInfo();
     MS_LOG(ERROR) << "init failed in model worker.";
-    *create_success = false;
-    create_work_done_ = true;
+    {
+      std::unique_lock<std::mutex> create_work_lock(create_work_done_mutex_);
+      *create_success = false;
+      create_work_done_ = true;
+    }
     create_work_done_condition_.notify_one();
   }
   Run();
@@ -64,7 +67,11 @@ void ModelWorker::CreateThreadWorker(const char *model_buf, size_t size,
 void ModelWorker::Run() {
   int task_queue_id = worker_config_->task_queue_id;
-  create_work_done_ = true;
+  {
+    // The scope of the lock is only for this variable
+    std::unique_lock<std::mutex> create_work_lock(create_work_done_mutex_);
+    create_work_done_ = true;
+  }
   create_work_done_condition_.notify_one();
   while (!predict_task_queue_->IsPredictTaskDone()) {
     auto task = predict_task_queue_->GetPredictTask(task_queue_id, this);
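The two hunks above are the heart of the fix: create_work_done_ (and *create_success on the error path) is now written while holding create_work_done_mutex_, presumably the same mutex the waiting thread locks around its create_work_done_condition_ wait. A minimal standalone sketch of that pattern, with generic names rather than the MindSpore ones, shows why the write has to sit under the lock: an unlocked write both races with the waiter's read and can be missed if the notify fires between the waiter's predicate check and its block.

#include <condition_variable>
#include <mutex>
#include <thread>

std::mutex done_mutex;
std::condition_variable done_cv;
bool done = false;

void Producer() {
  // ... do the work being waited on ...
  {
    // Write the flag under the same mutex the waiter uses. Dropping the lock
    // before notify_one() is fine, but the write itself must be protected:
    // otherwise the waiter can read "false" under its lock, the notify can
    // slip in before it blocks, and the wakeup is lost.
    std::unique_lock<std::mutex> lock(done_mutex);
    done = true;
  }
  done_cv.notify_one();
}

void Consumer() {
  std::unique_lock<std::mutex> lock(done_mutex);
  done_cv.wait(lock, [] { return done; });  // predicate re-checked under the lock
}

int main() {
  std::thread producer(Producer);
  Consumer();
  producer.join();
  return 0;
}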


@@ -19,6 +19,11 @@
 namespace mindspore {
 PredictTaskQueue::~PredictTaskQueue() {
   if (predict_task_ != nullptr) {
+#ifdef USE_HQUEUE
+    for (size_t i = 0; i < task_queue_num_; i++) {
+      predict_task_[i].Clean();
+    }
+#endif
     delete[] predict_task_;
     predict_task_ = nullptr;
   }
@@ -39,6 +44,7 @@ Status PredictTaskQueue::InitTaskQueue(size_t num, size_t max_queue_size) {
     return kLiteError;
   }
 #ifdef USE_HQUEUE
+  task_queue_num_ = num;
   predict_task_ = new (std::nothrow) HQueue<PredictTask>[num]();
   if (predict_task_ == nullptr) {
     MS_LOG(ERROR) << "new predict task failed.";


@@ -68,6 +68,7 @@ class PredictTaskQueue {
 #else
   std::queue<PredictTask *> *predict_task_;
 #endif
+  size_t task_queue_num_ = -1;
   std::atomic_int *idle_worker_num_;
   std::mutex mtx_predict_task_;
   std::condition_variable task_pop_cond_;
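The last three hunks belong together: PredictTaskQueue now remembers how many HQueue<PredictTask> slots it allocated (task_queue_num_) so the destructor can call Clean() on each one before delete[]. A condensed sketch of that ownership pattern is below, using a hypothetical QueueLike type in place of HQueue; only the Clean() call is taken from the diff, everything else is illustrative.

#include <cstddef>
#include <new>

struct PredictTask {};

// Hypothetical stand-in for HQueue<T>; only Clean() mirrors the call in the diff.
template <typename T>
struct QueueLike {
  void Clean() { /* release any nodes still queued */ }
};

class TaskQueueOwner {
 public:
  bool Init(size_t num) {
    queue_num_ = num;  // remember the count, like task_queue_num_ = num;
    queues_ = new (std::nothrow) QueueLike<PredictTask>[num]();
    return queues_ != nullptr;
  }

  ~TaskQueueOwner() {
    if (queues_ != nullptr) {
      for (size_t i = 0; i < queue_num_; i++) {
        queues_[i].Clean();  // per-queue cleanup before freeing the array
      }
      delete[] queues_;
      queues_ = nullptr;
    }
  }

 private:
  QueueLike<PredictTask> *queues_ = nullptr;
  size_t queue_num_ = 0;
};

int main() {
  TaskQueueOwner owner;
  owner.Init(4);
  return 0;  // destructor cleans and frees the four queues
}

One observation on the header change: initializing the counter with -1 makes a size_t wrap to SIZE_MAX, i.e. a "not yet initialized" sentinel. In the hunks shown, the destructor loop only runs when predict_task_ is non-null, and under USE_HQUEUE that allocation happens right after task_queue_num_ = num in InitTaskQueue, so the sentinel value is never used as a loop bound.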