forked from mindspore-Ecosystem/mindspore
max thread num
This commit is contained in:
parent
a947e65ccd
commit
5ed08ebe51
|
@ -61,6 +61,36 @@ int ArithmeticFP16CPUKernel::CheckDataType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticFP16CPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_func_ != nullptr)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ArithmeticFP16CPUKernel::IsBatchScalarCalc() {
|
||||
if (arithmetic_opt_func_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < param_->ndim_) {
|
||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArithmeticFP16CPUKernel::InitRunFunction(int primitive_type) {
|
||||
ARITHMETIC_FUNC_INFO_FP16 fun_table[] = {
|
||||
{PrimitiveType_MulFusion, schema::ActivationType_RELU, ElementMulReluFp16, ElementOptMulReluFp16},
|
||||
|
|
|
@ -40,6 +40,8 @@ class ArithmeticFP16CPUKernel : public ArithmeticCPUKernel {
|
|||
~ArithmeticFP16CPUKernel() = default;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
bool IsBatchScalarCalc() override;
|
||||
bool IsScalarClac() override;
|
||||
|
||||
private:
|
||||
void InitRunFunction(int primitive_type) override;
|
||||
|
|
|
@ -61,7 +61,7 @@ int ArithmeticCPUKernel::ReSize() {
|
|||
}
|
||||
}
|
||||
int ret = RET_OK;
|
||||
if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) {
|
||||
if (!IsScalarClac() && !IsBatchScalarCalc() && !IsBiasCalc()) {
|
||||
ret = ConstTensorBroadCast();
|
||||
}
|
||||
return ret;
|
||||
|
@ -77,7 +77,7 @@ int ArithmeticCPUKernel::CheckDataType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
bool ArithmeticCPUKernel::IsScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
|
||||
return true;
|
||||
} else {
|
||||
|
@ -85,7 +85,7 @@ bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1
|
|||
}
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
||||
bool ArithmeticCPUKernel::IsBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
||||
if (arithmetic_opt_run_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::isBiasCalc() { // 2 240 240 32, 1 1 1 32
|
||||
bool ArithmeticCPUKernel::IsBiasCalc() { // 2 240 240 32, 1 1 1 32
|
||||
int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
|
||||
int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
|
||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
||||
|
@ -365,7 +365,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
}
|
||||
int offset = stride * task_id * data_type_len_;
|
||||
/* run opt function, one of input is scalar */
|
||||
if (isScalarClac()) { // 2 32 240 240, 1 1 1 1
|
||||
if (IsScalarClac()) { // 2 32 240 240, 1 1 1 1
|
||||
if (param_->in_elements_num0_ == 1) {
|
||||
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
||||
|
@ -375,11 +375,11 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
}
|
||||
}
|
||||
/* run opt function, every batch one of input is scalar */
|
||||
if (isBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
|
||||
if (IsBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
|
||||
return BatchScalarCalc(task_id);
|
||||
}
|
||||
/* each batch is eltwise calculation */
|
||||
if (isBiasCalc()) { // 2 240 240 32, 1 1 1 32
|
||||
if (IsBiasCalc()) { // 2 240 240 32, 1 1 1 32
|
||||
return BiasCalc(task_id);
|
||||
}
|
||||
/* need broadcast in runtime */
|
||||
|
|
|
@ -97,6 +97,8 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
virtual void TileConstTensor(const void *in_data, void *out_data, size_t ndim, const int *in_shape,
|
||||
const int *in_strides, const int *out_strides, const int *multiple);
|
||||
virtual int Execute(const void *input0, const void *input1, void *output, int size, bool is_opt);
|
||||
virtual bool IsBatchScalarCalc();
|
||||
virtual bool IsScalarClac();
|
||||
bool input0_broadcast_ = false;
|
||||
bool input1_broadcast_ = false;
|
||||
void *input0_ptr_ = nullptr;
|
||||
|
@ -111,9 +113,7 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
int BatchScalarCalc(int task_id);
|
||||
int BiasCalc(int task_id);
|
||||
void FreeConstTileBuff();
|
||||
bool isScalarClac();
|
||||
bool isBatchScalarCalc();
|
||||
bool isBiasCalc();
|
||||
bool IsBiasCalc();
|
||||
ArithmeticRun arithmetic_run_ = nullptr;
|
||||
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
||||
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
||||
|
|
|
@ -18,11 +18,10 @@
|
|||
#include "src/runtime/parallel_executor.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
#define MAX_THREAD_NUM 8
|
||||
namespace mindspore::lite {
|
||||
ParallelExecutor::~ParallelExecutor() { DestroyThreadPool(thread_pool_); }
|
||||
int ParallelExecutor::Prepare(const std::vector<mindspore::kernel::LiteKernel *> &kernels) {
|
||||
thread_pool_ = CreateLiteThreadPool(MAX_THREAD_NUM, NO_BIND);
|
||||
thread_pool_ = CreateLiteThreadPool(max_thread_num_, NO_BIND);
|
||||
if (thread_pool_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory error: fail to new ThreadPool";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_PARALLEL_EXECUTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include "src/runtime/allocator.h"
|
||||
#include "src/lite_kernel.h"
|
||||
|
@ -43,6 +44,7 @@ class ParallelExecutor : public Executor {
|
|||
std::vector<kernel::LiteKernel *> readyKernels;
|
||||
std::vector<int> results;
|
||||
struct ThreadPool *thread_pool_ = nullptr;
|
||||
int max_thread_num_ = std::thread::hardware_concurrency();
|
||||
};
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -21,10 +21,14 @@
|
|||
#include <semaphore.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#ifdef __WIN32__
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#ifdef __ANDROID__
|
||||
#define BIND_CORE
|
||||
#include <unistd.h>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
#ifdef MS_COMPILE_IOS
|
||||
|
@ -48,7 +52,6 @@
|
|||
#define RET_TP_ERROR (-8)
|
||||
#define RET_TP_SYSTEM_ERROR (-1)
|
||||
|
||||
#define MAX_THREAD_NUM (200)
|
||||
#define DEFAULT_SPIN_COUNT (30000)
|
||||
|
||||
typedef struct {
|
||||
|
@ -831,8 +834,15 @@ int CreateNewThread(struct ThreadPool *thread_pool, int thread_id) {
|
|||
}
|
||||
|
||||
ThreadPool *CreateThreadPool(int thread_num, int mode) {
|
||||
#ifdef __WIN32__
|
||||
SYSTEM_INFO sys_info;
|
||||
GetSystemInfo(&sys_info);
|
||||
long max_thread_num = sys_info.dwNumberOfProcessors;
|
||||
#else
|
||||
long max_thread_num = sysconf(_SC_NPROCESSORS_ONLN);
|
||||
#endif
|
||||
LOG_INFO("create thread pool, thread_num: %d, mode: %d", thread_num, mode);
|
||||
if (thread_num <= 0 || thread_num > MAX_THREAD_NUM) {
|
||||
if (thread_num <= 0 || thread_num > max_thread_num) {
|
||||
LOG_ERROR("invalid thread num: %d", thread_num);
|
||||
return NULL;
|
||||
}
|
||||
|
@ -851,7 +861,7 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) {
|
|||
LOG_ERROR("Malloc ThreadPool failed");
|
||||
return NULL;
|
||||
}
|
||||
thread_pool->thread_num = thread_num > MAX_THREAD_NUM ? MAX_THREAD_NUM : thread_num;
|
||||
thread_pool->thread_num = thread_num > max_thread_num ? max_thread_num : thread_num;
|
||||
thread_pool->is_alive = ATOMIC_VAR_INIT(true);
|
||||
thread_pool->mode = mode;
|
||||
thread_pool->thread_list = NULL;
|
||||
|
|
Loading…
Reference in New Issue