!11995 [MSLITE][Develop] optimize arithmetic

From: @sunsuodong
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-04 10:14:52 +08:00 committed by Gitee
commit c2d9e1f396
2 changed files with 144 additions and 218 deletions

View File

@ -96,202 +96,59 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
} }
void ArithmeticCPUKernel::InitRunFunction() { void ArithmeticCPUKernel::InitRunFunction() {
switch (op_parameter_->type_) { ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
case PrimitiveType_Mul: {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu,
switch (arithmeticParameter_->activation_type_) { ElementOptMulReluInt},
case schema::ActivationType_RELU: {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr, ElementOptMulRelu6,
arithmetic_run_ = ElementMulRelu; ElementOptMulRelu6Int},
arithmetic_run_int_ = ElementMulReluInt; {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
break; ElementOptMulInt},
case schema::ActivationType_RELU6: {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr},
arithmetic_run_ = ElementMulRelu6; {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, nullptr},
arithmetic_run_int_ = ElementMulRelu6Int; {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
break; ElementOptAddInt},
default: {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr},
arithmetic_run_ = ElementMul; {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, nullptr},
arithmetic_run_int_ = ElementMulInt; {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
break; ElementOptSubInt},
} {PrimitiveType_Div, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
break; {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr},
case PrimitiveType_Add: {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
switch (arithmeticParameter_->activation_type_) { ElementOptDivInt},
case schema::ActivationType_RELU: {PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
arithmetic_run_ = ElementAddRelu; {PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
break; nullptr},
case schema::ActivationType_RELU6: {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
arithmetic_run_ = ElementAddRelu6; ElementOptDivInt},
break; {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
default: ElementLogicalAndBool, nullptr, nullptr},
arithmetic_run_ = ElementAdd; {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, nullptr, nullptr,
arithmetic_run_int_ = ElementAddInt; nullptr},
break; {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr,
} nullptr},
break; {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr,
case PrimitiveType_Sub: nullptr},
switch (arithmeticParameter_->activation_type_) { {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
case schema::ActivationType_RELU: nullptr, nullptr},
arithmetic_run_ = ElementSubRelu; {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
break; nullptr, nullptr},
case schema::ActivationType_RELU6: {PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
arithmetic_run_ = ElementSubRelu6; ElementOptModInt},
break; {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
default: nullptr, nullptr}};
arithmetic_run_ = ElementSub;
arithmetic_run_int_ = ElementSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6;
break;
default:
arithmetic_run_ = ElementDiv;
break;
}
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
arithmetic_run_int_ = ElementLogicalAndInt;
arithmetic_run_bool_ = ElementLogicalAndBool;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr;
break;
case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximum;
arithmetic_run_int_ = ElementMaximumInt;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum;
arithmetic_run_int_ = ElementMinimumInt;
break;
case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDiv;
arithmetic_run_int_ = ElementFloorDivInt;
break;
case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_Mod:
arithmetic_run_ = ElementMod;
arithmetic_run_int_ = ElementModInt;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_run_ = nullptr;
arithmetic_run_int_ = nullptr;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr;
break;
}
return;
}
void ArithmeticCPUKernel::InitOptRunFunction() { size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { for (size_t i = 0; i < length; i++) {
switch (arithmeticParameter_->op_parameter_.type_) { if (fun_table[i].primitive_type_ == op_parameter_->type_ &&
case PrimitiveType_Mul: fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) {
switch (arithmeticParameter_->activation_type_) { arithmetic_run_ = fun_table[i].func_;
case schema::ActivationType_RELU: arithmetic_run_int_ = fun_table[i].int_func_;
arithmeticParameter_->broadcasting_ = false; arithmetic_run_bool_ = fun_table[i].bool_func_;
arithmetic_opt_run_ = ElementOptMulRelu; arithmetic_opt_run_ = fun_table[i].opt_func_;
arithmetic_opt_run_int_ = ElementOptMulReluInt; arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
break; return;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulRelu6;
arithmetic_opt_run_int_ = ElementOptMulRelu6Int;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMul;
arithmetic_opt_run_int_ = ElementOptMulInt;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAdd;
arithmetic_opt_run_int_ = ElementOptAddInt;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSub;
arithmetic_opt_run_int_ = ElementOptSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDiv;
arithmetic_opt_run_int_ = ElementOptDivInt;
break;
}
break;
case PrimitiveType_Mod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMod;
arithmetic_opt_run_int_ = ElementOptModInt;
break;
default:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
} }
} else {
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
} }
return;
} }
void ArithmeticCPUKernel::InitParam() { void ArithmeticCPUKernel::InitParam() {
@ -321,7 +178,6 @@ void ArithmeticCPUKernel::InitParam() {
int ArithmeticCPUKernel::ReSize() { int ArithmeticCPUKernel::ReSize() {
InitParam(); InitParam();
InitOptRunFunction();
return InitBroadCastCase(); return InitBroadCastCase();
} }
@ -359,6 +215,66 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
return RET_OK; return RET_OK;
} }
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
if (input0_broadcast_ == true || input1_broadcast_ == true) {
return false;
}
if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ ||
arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < arithmeticParameter_->ndim_) {
for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, thread_count_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1];
int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1];
int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1];
int offset0 = stride0 * start_batch;
int offset1 = stride1 * start_batch;
int out_offset = out_stride * start_batch;
int ret = RET_OK;
for (int i = 0; i < batch_size; i++) {
if (data_type_ == kDataTypeFloat) {
ret = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_) + offset0, reinterpret_cast<float *>(input1_ptr_) + offset1,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
} else {
ret = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_) + offset0, reinterpret_cast<int *>(input1_ptr_) + offset1,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
}
offset0 += stride0;
offset1 += stride1;
out_offset += out_stride;
}
return ret;
}
int ArithmeticCPUKernel::DoArithmetic(int task_id) { int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum(); auto element_num = out_tensors_[0]->ElementsNum();
@ -370,27 +286,12 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
return RET_ERROR; return RET_ERROR;
} }
if (CanBatchScalar()) {
int error_code; return BatchScalarCalc(task_id);
if (arithmeticParameter_->broadcasting_) {
/* need broadcast in runtime */
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code;
} }
int error_code = RET_OK;
if (arithmetic_opt_run_ != nullptr) { if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) &&
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) {
/* run opt function /* run opt function
* one of input is scalar */ * one of input is scalar */
if (arithmeticParameter_->in_elements_num0_ == 1) { if (arithmeticParameter_->in_elements_num0_ == 1) {
@ -413,11 +314,24 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_), reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
} }
} else {
MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar";
return RET_ERROR;
} }
return error_code;
}
if (arithmeticParameter_->broadcasting_) {
/* need broadcast in runtime */
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code; return error_code;
} }

View File

@ -52,6 +52,16 @@ class ArithmeticCPUKernel : public LiteKernel {
const ArithmeticParameter *param); const ArithmeticParameter *param);
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size); typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
// One row of the operator dispatch table: a (primitive type, activation type) key together with
// the function pointers to install for that combination. A nullptr entry means the row provides
// no function for that slot.
struct ARITHMETIC_FUNC_INFO_FP32 {
  int primitive_type_;               // key: operator primitive type
  int activation_type_;              // key: fused activation type
  ArithmeticRun func_;               // float function
  ArithmeticIntRun int_func_;        // int function
  ArithmeticBoolRun bool_func_;      // bool function
  ArithmeticOptRun opt_func_;        // float "opt" function
  ArithmeticOptIntRun opt_int_func_; // int "opt" function
};
public: public:
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
@ -75,6 +85,8 @@ class ArithmeticCPUKernel : public LiteKernel {
void FreeTmpPtr(); void FreeTmpPtr();
int InitBroadCastCase(); int InitBroadCastCase();
void InitParamInRunTime(); void InitParamInRunTime();
bool CanBatchScalar();
int BatchScalarCalc(int task_id);
protected: protected:
bool input0_broadcast_ = false; bool input0_broadcast_ = false;