forked from mindspore-Ecosystem/mindspore
!11995 [MSLITE][Develop] optimize arithmetic
From: @sunsuodong Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c2d9e1f396
|
@ -96,202 +96,59 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
|
|||
}
|
||||
|
||||
void ArithmeticCPUKernel::InitRunFunction() {
|
||||
switch (op_parameter_->type_) {
|
||||
case PrimitiveType_Mul:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = ElementMulRelu;
|
||||
arithmetic_run_int_ = ElementMulReluInt;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = ElementMulRelu6;
|
||||
arithmetic_run_int_ = ElementMulRelu6Int;
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = ElementMul;
|
||||
arithmetic_run_int_ = ElementMulInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Add:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = ElementAddRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = ElementAddRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = ElementAdd;
|
||||
arithmetic_run_int_ = ElementAddInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Sub:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = ElementSubRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = ElementSubRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = ElementSub;
|
||||
arithmetic_run_int_ = ElementSubInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Div:
|
||||
case PrimitiveType_RealDiv:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmetic_run_ = ElementDivRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmetic_run_ = ElementDivRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmetic_run_ = ElementDiv;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_LogicalAnd:
|
||||
arithmetic_run_ = ElementLogicalAnd;
|
||||
arithmetic_run_int_ = ElementLogicalAndInt;
|
||||
arithmetic_run_bool_ = ElementLogicalAndBool;
|
||||
break;
|
||||
case PrimitiveType_LogicalOr:
|
||||
arithmetic_run_ = ElementLogicalOr;
|
||||
break;
|
||||
case PrimitiveType_Maximum:
|
||||
arithmetic_run_ = ElementMaximum;
|
||||
arithmetic_run_int_ = ElementMaximumInt;
|
||||
break;
|
||||
case PrimitiveType_Minimum:
|
||||
arithmetic_run_ = ElementMinimum;
|
||||
arithmetic_run_int_ = ElementMinimumInt;
|
||||
break;
|
||||
case PrimitiveType_FloorDiv:
|
||||
arithmetic_run_ = ElementFloorDiv;
|
||||
arithmetic_run_int_ = ElementFloorDivInt;
|
||||
break;
|
||||
case PrimitiveType_FloorMod:
|
||||
arithmetic_run_ = ElementFloorMod;
|
||||
arithmetic_run_int_ = ElementFloorModInt;
|
||||
break;
|
||||
case PrimitiveType_Mod:
|
||||
arithmetic_run_ = ElementMod;
|
||||
arithmetic_run_int_ = ElementModInt;
|
||||
break;
|
||||
case PrimitiveType_SquaredDifference:
|
||||
arithmetic_run_ = ElementSquaredDifference;
|
||||
break;
|
||||
case PrimitiveType_Equal:
|
||||
case PrimitiveType_Less:
|
||||
case PrimitiveType_Greater:
|
||||
case PrimitiveType_NotEqual:
|
||||
case PrimitiveType_LessEqual:
|
||||
case PrimitiveType_GreaterEqual:
|
||||
arithmetic_run_ = nullptr;
|
||||
arithmetic_run_int_ = nullptr;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
|
||||
arithmetic_run_ = nullptr;
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
|
||||
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu,
|
||||
ElementOptMulReluInt},
|
||||
{PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr, ElementOptMulRelu6,
|
||||
ElementOptMulRelu6Int},
|
||||
{PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
|
||||
ElementOptMulInt},
|
||||
{PrimitiveType_Add, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr},
|
||||
{PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, nullptr},
|
||||
{PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
|
||||
ElementOptAddInt},
|
||||
{PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr},
|
||||
{PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, nullptr},
|
||||
{PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
|
||||
ElementOptSubInt},
|
||||
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
|
||||
{PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr},
|
||||
{PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||
ElementOptDivInt},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
|
||||
nullptr},
|
||||
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
|
||||
ElementOptDivInt},
|
||||
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
|
||||
ElementLogicalAndBool, nullptr, nullptr},
|
||||
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, nullptr, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr,
|
||||
nullptr},
|
||||
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
|
||||
nullptr, nullptr},
|
||||
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
|
||||
ElementOptModInt},
|
||||
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
|
||||
nullptr, nullptr}};
|
||||
|
||||
void ArithmeticCPUKernel::InitOptRunFunction() {
|
||||
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
||||
switch (arithmeticParameter_->op_parameter_.type_) {
|
||||
case PrimitiveType_Mul:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMulRelu;
|
||||
arithmetic_opt_run_int_ = ElementOptMulReluInt;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMulRelu6;
|
||||
arithmetic_opt_run_int_ = ElementOptMulRelu6Int;
|
||||
break;
|
||||
default:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMul;
|
||||
arithmetic_opt_run_int_ = ElementOptMulInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Add:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptAddRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptAddRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptAdd;
|
||||
arithmetic_opt_run_int_ = ElementOptAddInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Sub:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptSubRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptSubRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptSub;
|
||||
arithmetic_opt_run_int_ = ElementOptSubInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Div:
|
||||
case PrimitiveType_RealDiv:
|
||||
switch (arithmeticParameter_->activation_type_) {
|
||||
case schema::ActivationType_RELU:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDivRelu;
|
||||
break;
|
||||
case schema::ActivationType_RELU6:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDivRelu6;
|
||||
break;
|
||||
default:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptDiv;
|
||||
arithmetic_opt_run_int_ = ElementOptDivInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case PrimitiveType_Mod:
|
||||
arithmeticParameter_->broadcasting_ = false;
|
||||
arithmetic_opt_run_ = ElementOptMod;
|
||||
arithmetic_opt_run_int_ = ElementOptModInt;
|
||||
break;
|
||||
default:
|
||||
arithmetic_opt_run_ = nullptr;
|
||||
arithmetic_opt_run_int_ = nullptr;
|
||||
break;
|
||||
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
if (fun_table[i].primitive_type_ == op_parameter_->type_ &&
|
||||
fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) {
|
||||
arithmetic_run_ = fun_table[i].func_;
|
||||
arithmetic_run_int_ = fun_table[i].int_func_;
|
||||
arithmetic_run_bool_ = fun_table[i].bool_func_;
|
||||
arithmetic_opt_run_ = fun_table[i].opt_func_;
|
||||
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
arithmetic_opt_run_ = nullptr;
|
||||
arithmetic_opt_run_int_ = nullptr;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ArithmeticCPUKernel::InitParam() {
|
||||
|
@ -321,7 +178,6 @@ void ArithmeticCPUKernel::InitParam() {
|
|||
|
||||
int ArithmeticCPUKernel::ReSize() {
|
||||
InitParam();
|
||||
InitOptRunFunction();
|
||||
return InitBroadCastCase();
|
||||
}
|
||||
|
||||
|
@ -359,6 +215,66 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
|
||||
if (input0_broadcast_ == true || input1_broadcast_ == true) {
|
||||
return false;
|
||||
}
|
||||
if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ ||
|
||||
arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) {
|
||||
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < arithmeticParameter_->ndim_) {
|
||||
for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) {
|
||||
if (arithmeticParameter_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
||||
int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1];
|
||||
int batch_per_thread = UP_DIV(batch, thread_count_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1];
|
||||
int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1];
|
||||
int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1];
|
||||
|
||||
int offset0 = stride0 * start_batch;
|
||||
int offset1 = stride1 * start_batch;
|
||||
int out_offset = out_stride * start_batch;
|
||||
|
||||
int ret = RET_OK;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
ret = arithmetic_opt_run_(
|
||||
reinterpret_cast<float *>(input0_ptr_) + offset0, reinterpret_cast<float *>(input1_ptr_) + offset1,
|
||||
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
|
||||
} else {
|
||||
ret = arithmetic_opt_run_int_(
|
||||
reinterpret_cast<int *>(input0_ptr_) + offset0, reinterpret_cast<int *>(input1_ptr_) + offset1,
|
||||
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
|
||||
}
|
||||
offset0 += stride0;
|
||||
offset1 += stride1;
|
||||
out_offset += out_stride;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
|
||||
|
@ -370,27 +286,12 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
int error_code;
|
||||
if (arithmeticParameter_->broadcasting_) {
|
||||
/* need broadcast in runtime */
|
||||
stride = UP_DIV(outside_, thread_count_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
||||
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
} else {
|
||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
||||
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
}
|
||||
return error_code;
|
||||
if (CanBatchScalar()) {
|
||||
return BatchScalarCalc(task_id);
|
||||
}
|
||||
|
||||
if (arithmetic_opt_run_ != nullptr) {
|
||||
int error_code = RET_OK;
|
||||
if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) &&
|
||||
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) {
|
||||
/* run opt function
|
||||
* one of input is scalar */
|
||||
if (arithmeticParameter_->in_elements_num0_ == 1) {
|
||||
|
@ -413,11 +314,24 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
|
||||
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return error_code;
|
||||
}
|
||||
if (arithmeticParameter_->broadcasting_) {
|
||||
/* need broadcast in runtime */
|
||||
stride = UP_DIV(outside_, thread_count_);
|
||||
int out_count = MSMIN(stride, outside_ - stride * task_id);
|
||||
if (out_count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
int out_thread_stride = stride * task_id;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
|
||||
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
} else {
|
||||
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
|
||||
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
|
||||
}
|
||||
return error_code;
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,16 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
const ArithmeticParameter *param);
|
||||
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
|
||||
|
||||
typedef struct {
|
||||
int primitive_type_;
|
||||
int activation_type_;
|
||||
ArithmeticRun func_;
|
||||
ArithmeticIntRun int_func_;
|
||||
ArithmeticBoolRun bool_func_;
|
||||
ArithmeticOptRun opt_func_;
|
||||
ArithmeticOptIntRun opt_int_func_;
|
||||
} ARITHMETIC_FUNC_INFO_FP32;
|
||||
|
||||
public:
|
||||
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
|
@ -75,6 +85,8 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
void FreeTmpPtr();
|
||||
int InitBroadCastCase();
|
||||
void InitParamInRunTime();
|
||||
bool CanBatchScalar();
|
||||
int BatchScalarCalc(int task_id);
|
||||
|
||||
protected:
|
||||
bool input0_broadcast_ = false;
|
||||
|
|
Loading…
Reference in New Issue