max thread num

This commit is contained in:
lzk 2021-03-18 01:42:01 -07:00
parent a947e65ccd
commit 5ed08ebe51
7 changed files with 59 additions and 16 deletions

View File

@@ -61,6 +61,36 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
   return RET_OK;
 }
 
+bool ArithmeticFP16CPUKernel::IsScalarClac() {  // 2 32 240 240, 1 1 1 1
+  if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_func_ != nullptr)) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() {
+  if (arithmetic_opt_func_ == nullptr) {
+    return false;
+  }
+  size_t break_axis = 0;
+  for (size_t i = 0; i < param_->ndim_; i++) {
+    if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
+      break_axis = i;
+      break;
+    }
+  }
+  if (break_axis < param_->ndim_) {
+    for (size_t i = break_axis; i < param_->ndim_; i++) {
+      if (param_->in_shape1_[i] != 1) {
+        return false;
+      }
+    }
+  }
+  break_pos_ = break_axis;
+  return true;
+}
+
 void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
   ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
     {PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
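
Note: the two predicates added above decide when the FP16 kernel may take an optimized elementwise path. IsScalarClac() (the identifier keeps the upstream spelling) matches when one whole input is a single element; IsBatchScalarCalc() matches when the shapes agree up to some axis and the second input is all 1s from there on, i.e. one scalar per batch. A minimal standalone sketch of the same shape test; the free-function form and the name IsBatchScalarShape are illustrative, not the MindSpore API:

#include <cstddef>

// Hypothetical free-function version of the IsBatchScalarCalc() shape test.
bool IsBatchScalarShape(const int *shape0, const int *shape1, size_t ndim) {
  // Find the first axis where the two input shapes diverge.
  size_t break_axis = 0;
  for (size_t i = 0; i < ndim; i++) {
    if (shape0[i] != shape1[i]) {
      break_axis = i;
      break;
    }
  }
  // From that axis on, input1 must be all 1s: every batch of input0
  // is then combined with a single scalar taken from input1.
  for (size_t i = break_axis; i < ndim; i++) {
    if (shape1[i] != 1) {
      return false;
    }
  }
  return true;
}

// {2, 32, 240, 240} vs {2, 32, 1, 1} -> true  (one scalar per 240x240 plane)
// {2, 240, 240, 32} vs {1, 1, 1, 32} -> false (bias pattern, handled elsewhere)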

View File

@@ -40,6 +40,8 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
   ~ArithmeticFP16CPUKernel() = default;
   int ReSize() override;
   int Run() override;
+  bool IsBatchScalarCalc() override;
+  bool IsScalarClac() override;
 
  private:
   void InitRunFunction(int primitive_type) override;

View File

@@ -61,7 +61,7 @@ int ArithmeticCPUKernel::ReSize() {
     }
   }
   int ret = RET_OK;
-  if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) {
+  if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
     ret = ConstTensorBroadCast();
   }
   return ret;
@@ -77,7 +77,7 @@ int ArithmeticCPUKernel::CheckDataType() {
   return RET_OK;
 }
 
-bool ArithmeticCPUKernel::isScalarClac() {  // 2 32 240 240, 1 1 1 1
+bool ArithmeticCPUKernel::IsScalarClac() {  // 2 32 240 240, 1 1 1 1
   if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
     return true;
   } else {
@@ -85,7 +85,7 @@ bool ArithmeticCPUKernel::isScalarClac() {  // 2 32 240 240, 1 1 1 1
   }
 }
 
-bool ArithmeticCPUKernel::isBatchScalarCalc() {  // 2 32 240 240, 2 32 1 1
+bool ArithmeticCPUKernel::IsBatchScalarCalc() {  // 2 32 240 240, 2 32 1 1
   if (arithmetic_opt_run_ == nullptr) {
     return false;
   }
@@ -107,7 +107,7 @@ bool ArithmeticCPUKernel::isBatchScalarCalc() {  // 2 32 240 240, 2 32 1 1
   return true;
 }
 
-bool ArithmeticCPUKernel::isBiasCalc() {  // 2 240 240 32, 1 1 1 32
+bool ArithmeticCPUKernel::IsBiasCalc() {  // 2 240 240 32, 1 1 1 32
   int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
   int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
   if (param_->in_elements_num0_ > param_->in_elements_num1_) {
@@ -365,7 +365,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   }
   int offset = stride * task_id * data_type_len_;
   /* run opt function, one of input is scalar */
-  if (isScalarClac()) {  // 2 32 240 240, 1 1 1 1
+  if (IsScalarClac()) {  // 2 32 240 240, 1 1 1 1
     if (param_->in_elements_num0_ == 1) {
       return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
                      static_cast<uint8_t *>(output_ptr_) + offset, count, true);
@@ -375,11 +375,11 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
     }
   }
   /* run opt function, every batch one of input is scalar */
-  if (isBatchScalarCalc()) {  // 2 32 240 240, 2 32 1 1
+  if (IsBatchScalarCalc()) {  // 2 32 240 240, 2 32 1 1
     return BatchScalarCalc(task_id);
   }
   /* each batch is eltwise calculation */
-  if (isBiasCalc()) {  // 2 240 240 32, 1 1 1 32
+  if (IsBiasCalc()) {  // 2 240 240 32, 1 1 1 32
     return BiasCalc(task_id);
   }
   /* need broadcast in runtime */
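
Note: taken together, DoArithmetic probes three fast paths in order before falling back to runtime broadcasting. A hedged sketch of that classification using the example shapes from the comments; the names Classify and BroadcastPattern are illustrative, the bias test is an assumption based on the `// 2 240 240 32, 1 1 1 32` comment (the full IsBiasCalc body is not part of this hunk), and for brevity input1 is assumed to be the smaller operand:

#include <cstddef>

enum class BroadcastPattern { kScalar, kBatchScalar, kBias, kRuntimeBroadcast };

// Classify two input shapes in the same order DoArithmetic tries its fast paths.
BroadcastPattern Classify(const int *s0, const int *s1, size_t ndim, int elems0, int elems1) {
  if (ndim == 0) {
    return BroadcastPattern::kRuntimeBroadcast;
  }
  if (elems0 == 1 || elems1 == 1) {
    return BroadcastPattern::kScalar;  // {2,32,240,240} vs {1,1,1,1}
  }
  // Find the first axis where the shapes diverge (stays 0 if they never do).
  size_t break_axis = 0;
  for (size_t i = 0; i < ndim; ++i) {
    if (s0[i] != s1[i]) {
      break_axis = i;
      break;
    }
  }
  bool batch_scalar = true;
  for (size_t i = break_axis; i < ndim; ++i) {
    batch_scalar = batch_scalar && (s1[i] == 1);
  }
  if (batch_scalar) {
    return BroadcastPattern::kBatchScalar;  // {2,32,240,240} vs {2,32,1,1}
  }
  // Assumed bias test: input1 is a vector along the last axis only.
  bool bias = (s0[ndim - 1] == s1[ndim - 1]);
  for (size_t i = 0; i + 1 < ndim; ++i) {
    bias = bias && (s1[i] == 1);
  }
  if (bias) {
    return BroadcastPattern::kBias;  // {2,240,240,32} vs {1,1,1,32}
  }
  return BroadcastPattern::kRuntimeBroadcast;  // needs ConstTensorBroadCast
}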

View File

@@ -97,6 +97,8 @@ class ArithmeticCPUKernel : public LiteKernel {
   virtual void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
                                const int *in_strides, const int *out_strides, const int *multiple);
   virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
+  virtual bool IsBatchScalarCalc();
+  virtual bool IsScalarClac();
   bool input0_broadcast_ = false;
   bool input1_broadcast_ = false;
   void *input0_ptr_ = nullptr;
@@ -111,9 +113,7 @@ class ArithmeticCPUKernel : public LiteKernel {
   int BatchScalarCalc(int task_id);
   int BiasCalc(int task_id);
   void FreeConstTileBuff();
-  bool isScalarClac();
-  bool isBatchScalarCalc();
-  bool isBiasCalc();
+  bool IsBiasCalc();
   ArithmeticRun arithmetic_run_ = nullptr;
   ArithmeticOptRun arithmetic_opt_run_ = nullptr;
   ArithmeticIntRun arithmetic_run_int_ = nullptr;

View File

@@ -18,11 +18,10 @@
 #include "src/runtime/parallel_executor.h"
 #include "src/runtime/runtime_api.h"
-#define MAX_THREAD_NUM 8
 namespace mindspore::lite {
 ParallelExecutor::~ParallelExecutor() { DestroyThreadPool(thread_pool_); }
 
 int ParallelExecutor::Prepare(const std::vector<mindspore::kernel::LiteKernel *> &kernels) {
-  thread_pool_ = CreateLiteThreadPool(MAX_THREAD_NUM, NO_BIND);
+  thread_pool_ = CreateLiteThreadPool(max_thread_num_, NO_BIND);
   if (thread_pool_ == nullptr) {
     MS_LOG(ERROR) << "Memory error: fail to new ThreadPool";
     return RET_ERROR;

View File

@@ -18,6 +18,7 @@
 #define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
 #include <vector>
+#include <thread>
 #include <unordered_map>
 #include "src/runtime/allocator.h"
 #include "src/lite_kernel.h"
@@ -43,6 +44,7 @@ class ParallelExecutor : public Executor {
   std::vector<kernel::LiteKernel *> readyKernels;
   std::vector<int> results;
   struct ThreadPool *thread_pool_ = nullptr;
+  int max_thread_num_ = std::thread::hardware_concurrency();
 };
 }  // namespace mindspore::lite
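
Note: one caveat with the new default: std::thread::hardware_concurrency() is allowed to return 0 when the value cannot be determined, which would leave max_thread_num_ at 0 and make CreateLiteThreadPool reject it. A defensive initializer (illustrative only, not part of this commit; the name DefaultThreadNum is hypothetical) could fall back to a single thread:

#include <thread>

// Fall back to a single thread when the core count is unknown (returns 0).
inline int DefaultThreadNum() {
  const unsigned int n = std::thread::hardware_concurrency();
  return n == 0 ? 1 : static_cast<int>(n);
}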

View File

@@ -21,10 +21,14 @@
 #include <semaphore.h>
 #include <string.h>
 #include <stdlib.h>
+#include <unistd.h>
+#ifdef __WIN32__
+#include <windows.h>
+#endif
 #ifdef __ANDROID__
 #define BIND_CORE
-#include <unistd.h>
 #include <sched.h>
 #endif
 #ifdef MS_COMPILE_IOS
@@ -48,7 +52,6 @@
 #define RET_TP_ERROR (-8)
 #define RET_TP_SYSTEM_ERROR (-1)
-#define MAX_THREAD_NUM (200)
 #define DEFAULT_SPIN_COUNT (30000)
 
 typedef struct {
@@ -831,8 +834,15 @@ int CreateNewThread(struct ThreadPool *thread_pool, int thread_id) {
 }
 
 ThreadPool *CreateThreadPool(int thread_num, int mode) {
+#ifdef __WIN32__
+  SYSTEM_INFO sys_info;
+  GetSystemInfo(&sys_info);
+  long max_thread_num = sys_info.dwNumberOfProcessors;
+#else
+  long max_thread_num = sysconf(_SC_NPROCESSORS_ONLN);
+#endif
   LOG_INFO("create thread pool, thread_num: %d, mode: %d", thread_num, mode);
-  if (thread_num <= 0 || thread_num > MAX_THREAD_NUM) {
+  if (thread_num <= 0 || thread_num > max_thread_num) {
     LOG_ERROR("invalid thread num: %d", thread_num);
     return NULL;
   }
@@ -851,7 +861,7 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) {
     LOG_ERROR("Malloc ThreadPool failed");
     return NULL;
   }
-  thread_pool->thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
+  thread_pool->thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
   thread_pool->is_alive = ATOMIC_VAR_INIT(true);
   thread_pool->mode = mode;
   thread_pool->thread_list = NULL;