diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index cade91f2278..943f96f3306 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -61,6 +61,36 @@ int ArithmeticFP16CPUKernel::CheckDataType() { return RET_OK; } +bool ArithmeticFP16CPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1 + if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_func_ != nullptr)) { + return true; + } else { + return false; + } +} + +bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() { + if (arithmetic_opt_func_ == nullptr) { + return false; + } + size_t break_axis = 0; + for (size_t i = 0; i < param_->ndim_; i++) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { + break_axis = i; + break; + } + } + if (break_axis < param_->ndim_) { + for (size_t i = break_axis; i < param_->ndim_; i++) { + if (param_->in_shape1_[i] != 1) { + return false; + } + } + } + break_pos_ = break_axis; + return true; +} + void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) { ARITHMETIC_FUNC_INFO_FP16 fun_table[] = { {PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16}, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index 5b25ffc7dba..85295f246ed 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -40,6 +40,8 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel { ~ArithmeticFP16CPUKernel() = default; int ReSize() override; int Run() override; + bool IsBatchScalarCalc() override; + bool IsScalarClac() override; private: void InitRunFunction(int primitive_type) override; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index bc7f35aa343..e4c79952c98 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -61,7 +61,7 @@ int ArithmeticCPUKernel::ReSize() { } } int ret = RET_OK; - if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) { + if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) { ret = ConstTensorBroadCast(); } return ret; @@ -77,7 +77,7 @@ int ArithmeticCPUKernel::CheckDataType() { return RET_OK; } -bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1 +bool ArithmeticCPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1 if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) { return true; } else { @@ -85,7 +85,7 @@ bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1 } } -bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1 +bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1 if (arithmetic_opt_run_ == nullptr) { return false; } @@ -107,7 +107,7 @@ bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1 return true; } -bool ArithmeticCPUKernel::isBiasCalc() { // 2 240 240 32, 1 1 1 32 +bool ArithmeticCPUKernel::IsBiasCalc() { // 2 240 240 32, 1 1 1 32 int last_shape0 = param_->in_shape0_[param_->ndim_ - 1]; int last_shape1 = param_->in_shape1_[param_->ndim_ - 1]; if (param_->in_elements_num0_ > param_->in_elements_num1_) { @@ -365,7 +365,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { } int offset = stride * task_id * data_type_len_; /* run opt function, one of input is scalar */ - if (isScalarClac()) { // 2 32 240 240, 1 1 1 1 + if (IsScalarClac()) { // 2 32 240 240, 1 1 1 1 if (param_->in_elements_num0_ == 1) { return Execute(input0_ptr_, static_cast(input1_ptr_) + offset, static_cast(output_ptr_) + offset, count, true); @@ -375,11 +375,11 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { } } /* run opt function, every batch one of input is scalar */ - if (isBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1 + if (IsBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1 return BatchScalarCalc(task_id); } /* each batch is eltwise calculation */ - if (isBiasCalc()) { // 2 240 240 32, 1 1 1 32 + if (IsBiasCalc()) { // 2 240 240 32, 1 1 1 32 return BiasCalc(task_id); } /* need broadcast in runtime */ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 5801a58206f..3366d9fd437 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -97,6 +97,8 @@ class ArithmeticCPUKernel : public LiteKernel { virtual void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape, const int *in_strides, const int *out_strides, const int *multiple); virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt); + virtual bool IsBatchScalarCalc(); + virtual bool IsScalarClac(); bool input0_broadcast_ = false; bool input1_broadcast_ = false; void *input0_ptr_ = nullptr; @@ -111,9 +113,7 @@ class ArithmeticCPUKernel : public LiteKernel { int BatchScalarCalc(int task_id); int BiasCalc(int task_id); void FreeConstTileBuff(); - bool isScalarClac(); - bool isBatchScalarCalc(); - bool isBiasCalc(); + bool IsBiasCalc(); ArithmeticRun arithmetic_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr; diff --git a/mindspore/lite/src/runtime/parallel_executor.cc b/mindspore/lite/src/runtime/parallel_executor.cc index 18cc0860b19..3910c0ff872 100644 --- a/mindspore/lite/src/runtime/parallel_executor.cc +++ b/mindspore/lite/src/runtime/parallel_executor.cc @@ -18,11 +18,10 @@ #include "src/runtime/parallel_executor.h" #include "src/runtime/runtime_api.h" -#define MAX_THREAD_NUM 8 namespace mindspore::lite { ParallelExecutor::~ParallelExecutor() { DestroyThreadPool(thread_pool_); } int ParallelExecutor::Prepare(const std::vector &kernels) { - thread_pool_ = CreateLiteThreadPool(MAX_THREAD_NUM, NO_BIND); + thread_pool_ = CreateLiteThreadPool(max_thread_num_, NO_BIND); if (thread_pool_ == nullptr) { MS_LOG(ERROR) << "Memory error: fail to new ThreadPool"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/parallel_executor.h b/mindspore/lite/src/runtime/parallel_executor.h index cc428dd6997..8ca15ef1b0f 100644 --- a/mindspore/lite/src/runtime/parallel_executor.h +++ b/mindspore/lite/src/runtime/parallel_executor.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_ #include +#include #include #include "src/runtime/allocator.h" #include "src/lite_kernel.h" @@ -43,6 +44,7 @@ class ParallelExecutor : public Executor { std::vector readyKernels; std::vector results; struct ThreadPool *thread_pool_ = nullptr; + int max_thread_num_ = std::thread::hardware_concurrency(); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/runtime/thread_pool.c b/mindspore/lite/src/runtime/thread_pool.c index ca4d1a05579..b4f0b34d4ee 100644 --- a/mindspore/lite/src/runtime/thread_pool.c +++ b/mindspore/lite/src/runtime/thread_pool.c @@ -21,10 +21,14 @@ #include #include #include +#include + +#ifdef __WIN32__ +#include +#endif #ifdef __ANDROID__ #define BIND_CORE -#include #include #endif #ifdef MS_COMPILE_IOS @@ -48,7 +52,6 @@ #define RET_TP_ERROR (-8) #define RET_TP_SYSTEM_ERROR (-1) -#define MAX_THREAD_NUM (200) #define DEFAULT_SPIN_COUNT (30000) typedef struct { @@ -831,8 +834,15 @@ int CreateNewThread(struct ThreadPool *thread_pool, int thread_id) { } ThreadPool *CreateThreadPool(int thread_num, int mode) { +#ifdef __WIN32__ + SYSTEM_INFO sys_info; + GetSystemInfo(&sys_info); + long max_thread_num = sys_info.dwNumberOfProcessors; +#else + long max_thread_num = sysconf(_SC_NPROCESSORS_ONLN); +#endif LOG_INFO("create thread pool, thread_num: %d, mode: %d", thread_num, mode); - if (thread_num <= 0 || thread_num > MAX_THREAD_NUM) { + if (thread_num <= 0 || thread_num > max_thread_num) { LOG_ERROR("invalid thread num: %d", thread_num); return NULL; } @@ -851,7 +861,7 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) { LOG_ERROR("Malloc ThreadPool failed"); return NULL; } - thread_pool->thread_num = thread_num > MAX_THREAD_NUM ? MAX_THREAD_NUM : thread_num; + thread_pool->thread_num = thread_num > max_thread_num ? max_thread_num : thread_num; thread_pool->is_alive = ATOMIC_VAR_INIT(true); thread_pool->mode = mode; thread_pool->thread_list = NULL;