[MS][LITE] add check input thread_num exceed core_num

This commit is contained in:
luoyuan 2022-08-16 15:09:20 +08:00
parent 495a852c1e
commit ab7b8d6bdd
4 changed files with 37 additions and 7 deletions

View File

@ -16,6 +16,7 @@
#include "src/extendrt/cxx_api/model_pool/model_pool.h"
#include <unistd.h>
#include <future>
#include <algorithm>
#include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h"
#include "src/common/log_adapter.h"
#include "include/lite_types.h"
@ -342,24 +343,43 @@ Status ModelPool::CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &run
return kSuccess;
}
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
Status ModelPool::CheckThreadNum(const std::shared_ptr<RunnerConfig> &runner_config) {
auto context = runner_config->GetContext();
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
MS_CHECK_TRUE_MSG(context != nullptr, kLiteError, "user set config context nullptr.");
auto user_thread_num = context->GetThreadNum();
if (user_thread_num < 0) {
MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
return nullptr;
return kLiteError;
}
int core_num = static_cast<int>(std::max<size_t>(1, std::thread::hardware_concurrency()));
int core_num_times = 5;
int Threshold_thread_num = core_num_times * core_num;
if (user_thread_num > Threshold_thread_num) {
MS_LOG(WARNING) << "Thread num: " << user_thread_num << " is more than 5 times core num: " << Threshold_thread_num
<< ", change it to 5 times core num. Please check whether Thread num is reasonable.";
user_thread_num = Threshold_thread_num;
}
if (user_thread_num == 0) {
// Defaults are automatically adjusted based on computer performance
auto default_thread_num = GetDefaultThreadNum();
if (default_thread_num == 0) {
MS_LOG(ERROR) << "computer thread num failed.";
return nullptr;
return kLiteError;
}
context->SetThreadNum(default_thread_num);
}
auto status = CheckAffinityCoreList(runner_config);
return kSuccess;
}
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
auto context = runner_config->GetContext();
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
auto status = CheckThreadNum(runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "user set thread num failed.";
return nullptr;
}
status = CheckAffinityCoreList(runner_config);
if (status != kSuccess) {
MS_LOG(ERROR) << "user set core list failed.";
return nullptr;

View File

@ -121,6 +121,8 @@ class ModelPool {
Status CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &runner_config);
Status CheckThreadNum(const std::shared_ptr<RunnerConfig> &runner_config);
Status InitAdvancedStrategy(const char *model_buf, size_t size, int base_thread_num);
Status InitBaseStrategy(const char *model_buf, size_t size, const std::shared_ptr<RunnerConfig> &runner_config);

View File

@ -193,7 +193,7 @@ InnerContext::~InnerContext() {
}
}
int InnerContext::IsValid() const {
int InnerContext::IsValid() {
if (this->device_list_.empty()) {
MS_LOG(ERROR) << "Device list is empty.";
return RET_NOT_SUPPORT;
@ -206,6 +206,14 @@ int InnerContext::IsValid() const {
MS_LOG(ERROR) << "Thread num smaller than 1 is not allowed.";
return RET_NOT_SUPPORT;
}
int core_num = static_cast<int>(std::max<size_t>(1, std::thread::hardware_concurrency()));
int core_num_times = 5;
int Threshold_thread_num = core_num_times * core_num;
if (thread_num_ > Threshold_thread_num) {
MS_LOG(WARNING) << "Thread num: " << thread_num_ << " is more than 5 times core num: " << Threshold_thread_num
<< ", change it to 5 times core num. Please check whether Thread num is reasonable.";
thread_num_ = Threshold_thread_num;
}
if (inter_op_parallel_num_ < 1) {
MS_LOG(ERROR) << "InterOpParallelNum smaller than 1 is not allowed.";

View File

@ -57,7 +57,7 @@ struct InnerContext : public Context {
DeviceInfo GetDeviceInfo(DeviceType type) const;
int IsValid() const;
int IsValid();
ThreadPool *thread_pool() const;