forked from mindspore-Ecosystem/mindspore
parallel predict: support GPU && one Runner binds one model pool
parent c2a5cc1486
commit 63f1786356
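In short: RunnerConfig gains a workers_num field, and each ModelParallelRunner now owns its own ModelPool (with its own PredictTaskQueue) instead of sharing process-wide singletons; a GPU device in the context caps the pool at a single worker. A minimal usage sketch of the API after this change follows — the header paths, the model file name, and the worker count are illustrative assumptions, not taken from this commit:

// Illustrative sketch only; include paths and "model.ms" are assumptions.
#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  device_info->SetEnableFP16(false);
  context->MutableDeviceInfo().push_back(device_info);

  auto config = std::make_shared<mindspore::RunnerConfig>();
  config->context = context;
  config->workers_num = 2;  // 0 lets the pool pick a default worker count

  mindspore::ModelParallelRunner runner;  // owns its own ModelPool after this commit
  if (runner.Init("model.ms", config) != mindspore::kSuccess) {
    return -1;
  }
  auto inputs = runner.GetInputs();
  // ... fill input tensor data here ...
  std::vector<mindspore::MSTensor> outputs;
  if (runner.Predict(inputs, &outputs) != mindspore::kSuccess) {
    return -1;
  }
  return 0;
}

Because pool state is now per-runner, two ModelParallelRunner objects in one process no longer share workers or a task queue.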
@@ -26,6 +26,7 @@ struct RunnerConfig {
std::shared_ptr<Context> context = nullptr;
int workers_num = 0;
};
class ModelPool;

/// \brief The ModelParallelRunner class is used to define a MindSpore ModelParallelRunner, facilitating Model
/// management.

@@ -62,6 +63,9 @@ class MS_API ModelParallelRunner {
/// \return Status.
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

private:
std::shared_ptr<ModelPool> model_pool_ = nullptr;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H

@@ -19,7 +19,12 @@

namespace mindspore {
Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
auto status = ModelPool::GetInstance()->Init(model_path, runner_config);
model_pool_ = std::make_shared<ModelPool>();
if (model_pool_ == nullptr) {
MS_LOG(ERROR) << "model pool is nullptr.";
return kLiteNullptr;
}
auto status = model_pool_->Init(model_path, runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner init failed.";
return kLiteError;

@@ -27,15 +32,9 @@ Status ModelParallelRunner::Init(const std::string &model_path, const std::share
return status;
}

std::vector<MSTensor> ModelParallelRunner::GetInputs() {
auto inputs = ModelPool::GetInstance()->GetInputs();
return inputs;
}
std::vector<MSTensor> ModelParallelRunner::GetInputs() { return model_pool_->GetInputs(); }

std::vector<MSTensor> ModelParallelRunner::GetOutputs() {
auto outputs = ModelPool::GetInstance()->GetOutputs();
return outputs;
}
std::vector<MSTensor> ModelParallelRunner::GetOutputs() { return model_pool_->GetOutputs(); }

Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {

@@ -43,7 +42,7 @@ Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::ve
MS_LOG(ERROR) << "predict output is nullptr.";
return kLiteNullptr;
}
auto status = ModelPool::GetInstance()->Predict(inputs, outputs, before, after);
auto status = model_pool_->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner predict failed.";
return kLiteError;

@@ -28,6 +28,7 @@
namespace mindspore {
namespace {
constexpr int32_t kNumThreads = 4;
constexpr int kNumDeviceInfo = 2;
int GetCoreNum() {
int core_num = 1;
#if defined(_MSC_VER) || defined(_WIN32)

@@ -87,11 +88,6 @@ void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_li
}
}

ModelPool *ModelPool::GetInstance() {
static ModelPool instance;
return &instance;
}

Status ModelPool::SetDefaultOptimalModelNum(const std::shared_ptr<mindspore::Context> &context) {
if (use_numa_bind_mode_) {
// now only supports the same number of cores per numa node

@@ -117,6 +113,10 @@ Status ModelPool::InitDefaultContext(const std::shared_ptr<mindspore::Context> &
context->SetThreadAffinity(lite::HIGHER_CPU);
auto &device_list = context->MutableDeviceInfo();
auto device_info = std::make_shared<CPUDeviceInfo>();
if (device_info == nullptr) {
MS_LOG(ERROR) << "device_info is nullptr.";
return kLiteNullptr;
}
device_info->SetEnableFP16(false);
device_list.push_back(device_info);
// set model num

@@ -135,39 +135,41 @@ std::shared_ptr<Context> ModelPool::InitUserDefineContext(const std::shared_ptr<
return nullptr;
}
auto device_list = context->MutableDeviceInfo();
if (device_list.size() != 1) {
MS_LOG(ERROR) << "model pool only support device num 1.";
if (device_list.size() > kNumDeviceInfo) {
MS_LOG(ERROR) << "model pool only support device CPU or GPU.";
return nullptr;
}
auto device = device_list.front();
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
return nullptr;
}
auto cpu_context = device->Cast<CPUDeviceInfo>();
auto enable_fp16 = cpu_context->GetEnableFP16();
if (enable_fp16) {
MS_LOG(ERROR) << "model pool not support enable fp16.";
return nullptr;
}

if (device->GetDeviceType() == kGPU) {
workers_num_ = 1;
} else if (device->GetDeviceType() == kCPU) {
if (runner_config->workers_num == 0) {
// the user does not define the number of models, the default optimal number of models is used
auto status = SetDefaultOptimalModelNum(context);
if (status != kSuccess) {
MS_LOG(ERROR) << "SetDefaultOptimalModelNum failed.";
for (size_t i = 0; i < device_list.size(); i++) {
auto device = device_list[i];
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
return nullptr;
}
if (device->GetDeviceType() == kGPU) {
workers_num_ = 1;
return context;
} else if (device->GetDeviceType() == kCPU) {
auto cpu_context = device->Cast<CPUDeviceInfo>();
auto enable_fp16 = cpu_context->GetEnableFP16();
if (enable_fp16) {
MS_LOG(ERROR) << "model pool not support enable fp16.";
return nullptr;
}
if (runner_config->workers_num == 0) {
// the user does not define the number of models, the default optimal number of models is used
auto status = SetDefaultOptimalModelNum(context);
if (status != kSuccess) {
MS_LOG(ERROR) << "SetDefaultOptimalModelNum failed.";
return nullptr;
}
} else {
// User defined number of models
workers_num_ = runner_config->workers_num;
}
} else {
// User defined number of models
workers_num_ = runner_config->workers_num;
MS_LOG(ERROR) << "not support device: " << device->GetDeviceType();
return nullptr;
}
} else {
MS_LOG(ERROR) << "not support device: " << device->GetDeviceType();
return nullptr;
}
return context;
}

@@ -219,6 +221,11 @@ ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig
MS_LOG(ERROR) << "Invalid thread num " << model_context->GetThreadNum();
return {};
}
auto device_num = model_context->MutableDeviceInfo().size();
if (device_num > 1) {
used_numa_node_num_ = 1;
return {model_context};
}
ModelPoolContex model_pool_context;
std::vector<std::vector<int>> all_model_bind_list;
if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {

@@ -248,6 +255,10 @@ ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig
}
auto &new_device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
if (device_info == nullptr) {
MS_LOG(ERROR) << "device_info is nullptr.";
return {};
}
device_info->SetEnableFP16(false);
new_device_list.push_back(device_info);
model_pool_context.push_back(context);

@@ -272,15 +283,20 @@ std::vector<MSTensor> ModelPool::GetOutputs() {
}

Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
predict_task_queue_ = std::make_shared<PredictTaskQueue>();
if (predict_task_queue_ == nullptr) {
MS_LOG(ERROR) << "create PredictTaskQueue failed, predict task queue is nullptr.";
return kLiteNullptr;
}
auto model_pool_context = CreateModelContext(runner_config);
if (model_pool_context.empty()) {
MS_LOG(ERROR) << "CreateModelContext failed, context is empty.";
return kLiteError;
}
if (use_numa_bind_mode_) {
PredictTaskQueue::GetInstance()->SetTaskQueueNum(used_numa_node_num_);
predict_task_queue_->SetTaskQueueNum(used_numa_node_num_);
} else {
PredictTaskQueue::GetInstance()->SetTaskQueueNum(1);
predict_task_queue_->SetTaskQueueNum(1);
}
size_t size = 0;
if (graph_buf_ != nullptr) {

@@ -317,8 +333,8 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
MS_LOG(ERROR) << " model thread init failed.";
return kLiteError;
}
PredictTaskQueue::GetInstance()->IncreaseWaitModelNum(1, numa_node_id);
model_worker_vec_.push_back(std::thread(&ModelWorker::Run, model_worker, numa_node_id));
predict_task_queue_->IncreaseWaitModelNum(1, numa_node_id);
model_worker_vec_.push_back(std::thread(&ModelWorker::Run, model_worker, numa_node_id, predict_task_queue_));
}
if (model_worker != nullptr) {
model_inputs_ = model_worker->GetInputs();

@@ -497,9 +513,9 @@ Status ModelPool::FreeSplitTensor(std::vector<std::vector<MSTensor>> *new_inputs

void ModelPool::GetMaxWaitWorkerNum(int *max_wait_worker_node_id, int *max_wait_worker_num) {
*max_wait_worker_node_id = 0;
*max_wait_worker_num = PredictTaskQueue::GetInstance()->GetWaitModelNum(0);
*max_wait_worker_num = predict_task_queue_->GetWaitModelNum(0);
for (int i = 1; i < used_numa_node_num_; i++) {
int worker_num = PredictTaskQueue::GetInstance()->GetWaitModelNum(i);
int worker_num = predict_task_queue_->GetWaitModelNum(i);
if (*max_wait_worker_num < worker_num) {
*max_wait_worker_num = worker_num;
*max_wait_worker_node_id = i;

@@ -515,10 +531,10 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
GetMaxWaitWorkerNum(&max_wait_worker_node_id, &max_wait_worker_num);

auto batch = inputs[0].Shape()[0];
if (PredictTaskQueue::GetInstance()->GetTaskNum(max_wait_worker_node_id) == 0 && max_wait_worker_num > 1 &&
if (predict_task_queue_->GetTaskNum(max_wait_worker_node_id) == 0 && max_wait_worker_num > 1 &&
batch >= max_wait_worker_num) {
size_t batch_split_num = PredictTaskQueue::GetInstance()->GetWaitModelNum(max_wait_worker_node_id);
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(batch_split_num, max_wait_worker_node_id);
size_t batch_split_num = predict_task_queue_->GetWaitModelNum(max_wait_worker_node_id);
predict_task_queue_->DecreaseWaitModelNum(batch_split_num, max_wait_worker_node_id);
std::vector<std::vector<MSTensor>> new_inputs;
std::vector<std::vector<MSTensor>> new_outputs;
auto status = SplitInputTensorByBatch(inputs, &new_inputs, batch_split_num);

@@ -535,12 +551,16 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
std::vector<std::shared_ptr<PredictTask>> tasks;
for (size_t i = 0; i < batch_split_num; i++) {
auto predict_task = std::make_shared<PredictTask>(&new_inputs[i], &new_outputs.at(i), before, after);
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task, max_wait_worker_node_id);
if (predict_task == nullptr) {
MS_LOG(ERROR) << "predict task is nullptr.";
return kLiteNullptr;
}
predict_task_queue_->PushPredictTask(predict_task, max_wait_worker_node_id);
tasks.push_back(predict_task);
}
mtx_split_task_.unlock();
for (size_t i = 0; i < batch_split_num; i++) {
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(tasks[i]);
predict_task_queue_->WaitUntilPredictActive(tasks[i]);
}
status = ConcatPredictOutput(&new_outputs, outputs);
if (status != kSuccess) {

@@ -552,27 +572,32 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
MS_LOG(ERROR) << "free split tensor failed.";
return kLiteError;
}
PredictTaskQueue::GetInstance()->IncreaseWaitModelNum(batch_split_num, max_wait_worker_node_id);
predict_task_queue_->IncreaseWaitModelNum(batch_split_num, max_wait_worker_node_id);
} else {
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(1, max_wait_worker_node_id);
predict_task_queue_->DecreaseWaitModelNum(1, max_wait_worker_node_id);
auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task, max_wait_worker_node_id);
if (predict_task == nullptr) {
MS_LOG(ERROR) << "predict_task is nullptr.";
return kLiteNullptr;
}
predict_task_queue_->PushPredictTask(predict_task, max_wait_worker_node_id);
mtx_split_task_.unlock();
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(predict_task);
PredictTaskQueue::GetInstance()->IncreaseWaitModelNum(1, max_wait_worker_node_id);
predict_task_queue_->WaitUntilPredictActive(predict_task);
predict_task_queue_->IncreaseWaitModelNum(1, max_wait_worker_node_id);
}
return kSuccess;
}

ModelPool::~ModelPool() {
if (graph_buf_ != nullptr) {
delete[] graph_buf_;
graph_buf_ = nullptr;
}
predict_task_queue_->SetPredictTaskDone();
for (auto &th : model_worker_vec_) {
if (th.joinable()) {
th.join();
}
}
if (graph_buf_ != nullptr) {
delete[] graph_buf_;
graph_buf_ = nullptr;
}
}
} // namespace mindspore

@@ -31,7 +31,8 @@ using ModelPoolContex = std::vector<std::shared_ptr<Context>>;

class ModelPool {
public:
static ModelPool *GetInstance();
ModelPool() = default;

~ModelPool();

Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);

@@ -44,8 +45,6 @@ class ModelPool {
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

private:
ModelPool() = default;

ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);

@@ -67,6 +66,7 @@ class ModelPool {
void GetMaxWaitWorkerNum(int *max_wait_worker_node_id, int *max_wait_worker_num);

std::vector<std::thread> model_worker_vec_;
std::vector<std::shared_ptr<ModelWorker>> model_workers_;
std::vector<MSTensor> model_inputs_;
std::vector<MSTensor> model_outputs_;
char *graph_buf_ = nullptr;

@@ -76,6 +76,8 @@ class ModelPool {
int numa_node_num_ = 1;
int used_numa_node_num_ = 0;
bool use_numa_bind_mode_ = false;
bool use_gpu_ = false;
std::shared_ptr<PredictTaskQueue> predict_task_queue_ = nullptr;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_POOL_H_

@@ -18,9 +18,9 @@
#include "src/common/utils.h"
#include "src/common/common.h"
namespace mindspore {
void ModelWorker::Run(int node_id) {
while (!PredictTaskQueue::GetInstance()->IsPredictTaskDone()) {
auto task = PredictTaskQueue::GetInstance()->GetPredictTask(node_id);
void ModelWorker::Run(int node_id, const std::shared_ptr<PredictTaskQueue> &predict_task_queue) {
while (!predict_task_queue->IsPredictTaskDone()) {
auto task = predict_task_queue->GetPredictTask(node_id);
if (task == nullptr) {
break;
}

@@ -32,7 +32,7 @@ void ModelWorker::Run(int node_id) {
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed.";
task->ready = true;
PredictTaskQueue::GetInstance()->ActiveTask();
predict_task_queue->ActiveTask();
continue;
}
if (need_copy_output_) {

@@ -45,7 +45,7 @@ void ModelWorker::Run(int node_id) {
if (copy_tensor == nullptr) {
MS_LOG(ERROR) << "model thread copy output tensor failed.";
task->ready = true;
PredictTaskQueue::GetInstance()->ActiveTask();
predict_task_queue->ActiveTask();
continue;
}
new_outputs.push_back(*copy_tensor);

@@ -55,7 +55,7 @@ void ModelWorker::Run(int node_id) {
outputs->insert(outputs->end(), new_outputs.begin(), new_outputs.end());
}
task->ready = true;
PredictTaskQueue::GetInstance()->ActiveTask();
predict_task_queue->ActiveTask();
}
}

@@ -41,7 +41,7 @@ class ModelWorker {
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);

void Run(int node_id);
void Run(int node_id, const std::shared_ptr<PredictTaskQueue> &predict_task_queue);

private:
std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,

@@ -16,7 +16,7 @@

#include "src/cxx_api/model_pool/predict_task_queue.h"
namespace mindspore {
PredictTaskQueue::~PredictTaskQueue() {
void PredictTaskQueue::SetPredictTaskDone() {
predict_task_done_ = true;
task_push_cond_.notify_all();
}

@@ -36,11 +36,6 @@ void PredictTaskQueue::WaitUntilPredictActive(const std::shared_ptr<PredictTask>

void PredictTaskQueue::ActiveTask() { task_pop_cond_.notify_all(); }

PredictTaskQueue *PredictTaskQueue::GetInstance() {
static PredictTaskQueue instance;
return &instance;
}

void PredictTaskQueue::PushPredictTask(std::shared_ptr<PredictTask> task, int node_id) {
std::unique_lock<std::mutex> task_lock(mtx_predict_task_);
predict_task_.at(node_id).push(task);

@@ -37,8 +37,8 @@ struct PredictTask {

class PredictTaskQueue {
public:
static PredictTaskQueue *GetInstance();
~PredictTaskQueue();
PredictTaskQueue() = default;
~PredictTaskQueue() = default;

void PushPredictTask(std::shared_ptr<PredictTask> task, int node_id);
void WaitUntilPredictActive(const std::shared_ptr<PredictTask> &task);

@@ -47,13 +47,13 @@ class PredictTaskQueue {
int GetTaskNum(int node_id);
void SetTaskQueueNum(int num);

bool IsPredictTaskDone() { return predict_task_done_; }
int GetWaitModelNum(int node_id) { return waite_worker_num_.at(node_id); }
bool IsPredictTaskDone() const { return predict_task_done_; }
void SetPredictTaskDone();
int GetWaitModelNum(int node_id) const { return waite_worker_num_.at(node_id); }
void DecreaseWaitModelNum(int num, int node_id) { waite_worker_num_.at(node_id) -= num; }
void IncreaseWaitModelNum(int num, int node_id) { waite_worker_num_.at(node_id) += num; }

private:
PredictTaskQueue() = default;
std::vector<std::queue<std::shared_ptr<PredictTask>>> predict_task_;
std::vector<int> waite_worker_num_;
std::mutex mtx_predict_task_;
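A consequence of dropping the PredictTaskQueue and ModelPool singletons is that worker threads must now be handed their queue explicitly (ModelWorker::Run takes a shared_ptr<PredictTaskQueue>), and ~ModelPool() can shut down exactly its own workers via SetPredictTaskDone(). The stand-alone sketch below illustrates that ownership pattern with simplified stand-in types — it is not MindSpore source, and every name in it is hypothetical:

#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Simplified stand-ins for PredictTaskQueue / ModelPool, for illustration only.
struct TaskQueue {
  std::mutex mtx;
  std::condition_variable cv;
  std::queue<int> tasks;
  bool done = false;

  void Push(int t) {
    {
      std::lock_guard<std::mutex> lock(mtx);
      tasks.push(t);
    }
    cv.notify_one();
  }
  // Blocks until a task arrives; returns false once Shutdown() was called and the queue is drained.
  bool Pop(int *out) {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [this] { return done || !tasks.empty(); });
    if (tasks.empty()) {
      return false;
    }
    *out = tasks.front();
    tasks.pop();
    return true;
  }
  void Shutdown() {
    {
      std::lock_guard<std::mutex> lock(mtx);
      done = true;
    }
    cv.notify_all();
  }
};

class Pool {
 public:
  explicit Pool(int workers) : queue_(std::make_shared<TaskQueue>()) {
    for (int i = 0; i < workers; ++i) {
      // Each worker receives the pool's own queue, mirroring
      // ModelWorker::Run(node_id, predict_task_queue) in this commit.
      threads_.emplace_back([q = queue_] {
        int task = 0;
        while (q->Pop(&task)) {
          // run inference for `task` here
        }
      });
    }
  }
  ~Pool() {
    queue_->Shutdown();  // analogous to predict_task_queue_->SetPredictTaskDone()
    for (auto &t : threads_) {
      t.join();
    }
  }
  void Submit(int t) { queue_->Push(t); }

 private:
  std::shared_ptr<TaskQueue> queue_;
  std::vector<std::thread> threads_;
};

int main() {
  Pool pool_a(2);
  Pool pool_b(2);  // two pools coexist without sharing a singleton queue
  for (int i = 0; i < 8; ++i) {
    pool_a.Submit(i);
    pool_b.Submit(i);
  }
  return 0;
}

Keeping the queue behind a shared_ptr owned by the pool and captured by each worker gives every pool fully independent scheduling, which is what allows one Runner to bind one model pool.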