forked from mindspore-Ecosystem/mindspore
[MS][LITE] add check input thread_num exceed core_num
This commit is contained in:
parent
495a852c1e
commit
ab7b8d6bdd
|
@ -16,6 +16,7 @@
|
|||
#include "src/extendrt/cxx_api/model_pool/model_pool.h"
|
||||
#include <unistd.h>
|
||||
#include <future>
|
||||
#include <algorithm>
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/lite_types.h"
|
||||
|
@ -342,24 +343,43 @@ Status ModelPool::CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &run
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
|
||||
Status ModelPool::CheckThreadNum(const std::shared_ptr<RunnerConfig> &runner_config) {
|
||||
auto context = runner_config->GetContext();
|
||||
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
|
||||
MS_CHECK_TRUE_MSG(context != nullptr, kLiteError, "user set config context nullptr.");
|
||||
auto user_thread_num = context->GetThreadNum();
|
||||
if (user_thread_num < 0) {
|
||||
MS_LOG(ERROR) << "Invalid thread num " << context->GetThreadNum();
|
||||
return nullptr;
|
||||
return kLiteError;
|
||||
}
|
||||
int core_num = static_cast<int>(std::max<size_t>(1, std::thread::hardware_concurrency()));
|
||||
int core_num_times = 5;
|
||||
int Threshold_thread_num = core_num_times * core_num;
|
||||
if (user_thread_num > Threshold_thread_num) {
|
||||
MS_LOG(WARNING) << "Thread num: " << user_thread_num << " is more than 5 times core num: " << Threshold_thread_num
|
||||
<< ", change it to 5 times core num. Please check whether Thread num is reasonable.";
|
||||
user_thread_num = Threshold_thread_num;
|
||||
}
|
||||
if (user_thread_num == 0) {
|
||||
// Defaults are automatically adjusted based on computer performance
|
||||
auto default_thread_num = GetDefaultThreadNum();
|
||||
if (default_thread_num == 0) {
|
||||
MS_LOG(ERROR) << "computer thread num failed.";
|
||||
return nullptr;
|
||||
return kLiteError;
|
||||
}
|
||||
context->SetThreadNum(default_thread_num);
|
||||
}
|
||||
auto status = CheckAffinityCoreList(runner_config);
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<Context> ModelPool::GetUserDefineContext(const std::shared_ptr<RunnerConfig> &runner_config) {
|
||||
auto context = runner_config->GetContext();
|
||||
MS_CHECK_TRUE_MSG(context != nullptr, nullptr, "user set config context nullptr.");
|
||||
auto status = CheckThreadNum(runner_config);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "user set thread num failed.";
|
||||
return nullptr;
|
||||
}
|
||||
status = CheckAffinityCoreList(runner_config);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "user set core list failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -121,6 +121,8 @@ class ModelPool {
|
|||
|
||||
Status CheckAffinityCoreList(const std::shared_ptr<RunnerConfig> &runner_config);
|
||||
|
||||
Status CheckThreadNum(const std::shared_ptr<RunnerConfig> &runner_config);
|
||||
|
||||
Status InitAdvancedStrategy(const char *model_buf, size_t size, int base_thread_num);
|
||||
|
||||
Status InitBaseStrategy(const char *model_buf, size_t size, const std::shared_ptr<RunnerConfig> &runner_config);
|
||||
|
|
|
@ -193,7 +193,7 @@ InnerContext::~InnerContext() {
|
|||
}
|
||||
}
|
||||
|
||||
int InnerContext::IsValid() const {
|
||||
int InnerContext::IsValid() {
|
||||
if (this->device_list_.empty()) {
|
||||
MS_LOG(ERROR) << "Device list is empty.";
|
||||
return RET_NOT_SUPPORT;
|
||||
|
@ -206,6 +206,14 @@ int InnerContext::IsValid() const {
|
|||
MS_LOG(ERROR) << "Thread num smaller than 1 is not allowed.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
int core_num = static_cast<int>(std::max<size_t>(1, std::thread::hardware_concurrency()));
|
||||
int core_num_times = 5;
|
||||
int Threshold_thread_num = core_num_times * core_num;
|
||||
if (thread_num_ > Threshold_thread_num) {
|
||||
MS_LOG(WARNING) << "Thread num: " << thread_num_ << " is more than 5 times core num: " << Threshold_thread_num
|
||||
<< ", change it to 5 times core num. Please check whether Thread num is reasonable.";
|
||||
thread_num_ = Threshold_thread_num;
|
||||
}
|
||||
|
||||
if (inter_op_parallel_num_ < 1) {
|
||||
MS_LOG(ERROR) << "InterOpParallelNum smaller than 1 is not allowed.";
|
||||
|
|
|
@ -57,7 +57,7 @@ struct InnerContext : public Context {
|
|||
|
||||
DeviceInfo GetDeviceInfo(DeviceType type) const;
|
||||
|
||||
int IsValid() const;
|
||||
int IsValid();
|
||||
|
||||
ThreadPool *thread_pool() const;
|
||||
|
||||
|
|
Loading…
Reference in New Issue