forked from mindspore-Ecosystem/mindspore
!40255 [MS][LITE][parallel predict]model parallel runner
Merge pull request !40255 from yefeng/381-fix_model_parallel_runner_lock
This commit is contained in:
commit
4469fc0df3
|
@ -727,6 +727,7 @@ Status ModelPool::CreateWorkers(const char *graph_buf, size_t size, const ModelP
|
|||
return kLiteError;
|
||||
}
|
||||
bool create_worker_success = true;
|
||||
MS_LOG(INFO) << "Strategy: " << strategy << " | worker num: " << model_pool_info_[strategy].all_workers_num_;
|
||||
for (size_t i = 0; i < model_pool_info_[strategy].all_workers_num_; i++) {
|
||||
model_pool_config[i]->strategy = strategy;
|
||||
int numa_node_id = model_pool_config[i]->numa_id;
|
||||
|
@ -741,6 +742,22 @@ Status ModelPool::CreateWorkers(const char *graph_buf, size_t size, const ModelP
|
|||
}
|
||||
auto task_queue_id = model_pool_config[i]->task_queue_id;
|
||||
model_pool_info_[strategy].predict_task_queue_->IncreaseWaitModelNum(1, task_queue_id);
|
||||
MS_LOG(INFO) << "Strategy: " << strategy << " | create worker index: " << i
|
||||
<< " | numa id: " << model_pool_config[i]->numa_id
|
||||
<< " | worker affinity mode: " << model_pool_config[i]->context->GetThreadAffinityMode()
|
||||
<< " | worker bind core list: " << model_pool_config[i]->context->GetThreadAffinityCoreList()
|
||||
<< " | worker thread num: " << model_pool_config[i]->context->GetThreadNum()
|
||||
<< " | inter op parallel num: " << model_pool_config[i]->context->GetInterOpParallelNum();
|
||||
if (!model_pool_config[i]->config_info.empty()) {
|
||||
for (auto &item : model_pool_config[i]->config_info) {
|
||||
auto section = item.first;
|
||||
MS_LOG(INFO) << "section: " << section;
|
||||
auto configs = item.second;
|
||||
for (auto &config : configs) {
|
||||
MS_LOG(INFO) << "\t key: " << config.first << " | value: " << config.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
worker_thread_vec_.push_back(std::thread(&ModelWorker::CreateThreadWorker, model_worker, new_model_buf, size,
|
||||
model_pool_config[i], model_pool_info_[strategy].predict_task_queue_,
|
||||
&create_worker_success));
|
||||
|
@ -824,6 +841,7 @@ Status ModelPool::InitAdvancedStrategy(const char *model_buf, size_t size, int b
|
|||
use_advanced_strategy_ = false;
|
||||
return kSuccess;
|
||||
}
|
||||
MS_LOG(INFO) << "use advanced strategy";
|
||||
model_pool_info_[ADVANCED].use_numa = numa_available_;
|
||||
auto status = SetDefaultOptimalModelNum(advanced_thread_num, ADVANCED);
|
||||
if (status != kSuccess) {
|
||||
|
|
|
@ -55,8 +55,11 @@ void ModelWorker::CreateThreadWorker(const char *model_buf, size_t size,
|
|||
if (status != kSuccess) {
|
||||
PrintWorkerInfo();
|
||||
MS_LOG(ERROR) << "init failed in model worker.";
|
||||
*create_success = false;
|
||||
create_work_done_ = true;
|
||||
{
|
||||
std::unique_lock<std::mutex> create_work_lock(create_work_done_mutex_);
|
||||
*create_success = false;
|
||||
create_work_done_ = true;
|
||||
}
|
||||
create_work_done_condition_.notify_one();
|
||||
}
|
||||
Run();
|
||||
|
@ -64,7 +67,11 @@ void ModelWorker::CreateThreadWorker(const char *model_buf, size_t size,
|
|||
|
||||
void ModelWorker::Run() {
|
||||
int task_queue_id = worker_config_->task_queue_id;
|
||||
create_work_done_ = true;
|
||||
{
|
||||
// The scope of the lock is only for this variable
|
||||
std::unique_lock<std::mutex> create_work_lock(create_work_done_mutex_);
|
||||
create_work_done_ = true;
|
||||
}
|
||||
create_work_done_condition_.notify_one();
|
||||
while (!predict_task_queue_->IsPredictTaskDone()) {
|
||||
auto task = predict_task_queue_->GetPredictTask(task_queue_id, this);
|
||||
|
|
|
@ -19,6 +19,11 @@
|
|||
namespace mindspore {
|
||||
PredictTaskQueue::~PredictTaskQueue() {
|
||||
if (predict_task_ != nullptr) {
|
||||
#ifdef USE_HQUEUE
|
||||
for (size_t i = 0; i < task_queue_num_; i++) {
|
||||
predict_task_[i].Clean();
|
||||
}
|
||||
#endif
|
||||
delete[] predict_task_;
|
||||
predict_task_ = nullptr;
|
||||
}
|
||||
|
@ -39,6 +44,7 @@ Status PredictTaskQueue::InitTaskQueue(size_t num, size_t max_queue_size) {
|
|||
return kLiteError;
|
||||
}
|
||||
#ifdef USE_HQUEUE
|
||||
task_queue_num_ = num;
|
||||
predict_task_ = new (std::nothrow) HQueue<PredictTask>[num]();
|
||||
if (predict_task_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new predict task failed.";
|
||||
|
|
|
@ -68,6 +68,7 @@ class PredictTaskQueue {
|
|||
#else
|
||||
std::queue<PredictTask *> *predict_task_;
|
||||
#endif
|
||||
size_t task_queue_num_ = -1;
|
||||
std::atomic_int *idle_worker_num_;
|
||||
std::mutex mtx_predict_task_;
|
||||
std::condition_variable task_pop_cond_;
|
||||
|
|
Loading…
Reference in New Issue