!11995 [MSLITE][Develop] optimize arithmetic

From: @sunsuodong
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-04 10:14:52 +08:00 committed by Gitee
commit c2d9e1f396
2 changed files with 144 additions and 218 deletions

View File

@ -96,202 +96,59 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
}
void ArithmeticCPUKernel::InitRunFunction() {
switch (op_parameter_->type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementMulRelu;
arithmetic_run_int_ = ElementMulReluInt;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementMulRelu6;
arithmetic_run_int_ = ElementMulRelu6Int;
break;
default:
arithmetic_run_ = ElementMul;
arithmetic_run_int_ = ElementMulInt;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementAddRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementAddRelu6;
break;
default:
arithmetic_run_ = ElementAdd;
arithmetic_run_int_ = ElementAddInt;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementSubRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementSubRelu6;
break;
default:
arithmetic_run_ = ElementSub;
arithmetic_run_int_ = ElementSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmetic_run_ = ElementDivRelu;
break;
case schema::ActivationType_RELU6:
arithmetic_run_ = ElementDivRelu6;
break;
default:
arithmetic_run_ = ElementDiv;
break;
}
break;
case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAnd;
arithmetic_run_int_ = ElementLogicalAndInt;
arithmetic_run_bool_ = ElementLogicalAndBool;
break;
case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOr;
break;
case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximum;
arithmetic_run_int_ = ElementMaximumInt;
break;
case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimum;
arithmetic_run_int_ = ElementMinimumInt;
break;
case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDiv;
arithmetic_run_int_ = ElementFloorDivInt;
break;
case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_Mod:
arithmetic_run_ = ElementMod;
arithmetic_run_int_ = ElementModInt;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_run_ = nullptr;
arithmetic_run_int_ = nullptr;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr;
break;
}
return;
}
ARITHMETIC_FUNC_INFO_FP32 fun_table[] = {
{PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu,
ElementOptMulReluInt},
{PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr, ElementOptMulRelu6,
ElementOptMulRelu6Int},
{PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul,
ElementOptMulInt},
{PrimitiveType_Add, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr},
{PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, nullptr},
{PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd,
ElementOptAddInt},
{PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr},
{PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, nullptr},
{PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub,
ElementOptSubInt},
{PrimitiveType_Div, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
{PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr},
{PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
ElementOptDivInt},
{PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr},
{PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6,
nullptr},
{PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv,
ElementOptDivInt},
{PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt,
ElementLogicalAndBool, nullptr, nullptr},
{PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, nullptr, nullptr,
nullptr},
{PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr,
nullptr},
{PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr,
nullptr},
{PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr,
nullptr, nullptr},
{PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr,
nullptr, nullptr},
{PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod,
ElementOptModInt},
{PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr,
nullptr, nullptr}};
void ArithmeticCPUKernel::InitOptRunFunction() {
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) {
case PrimitiveType_Mul:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulRelu;
arithmetic_opt_run_int_ = ElementOptMulReluInt;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMulRelu6;
arithmetic_opt_run_int_ = ElementOptMulRelu6Int;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMul;
arithmetic_opt_run_int_ = ElementOptMulInt;
break;
}
break;
case PrimitiveType_Add:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAddRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptAdd;
arithmetic_opt_run_int_ = ElementOptAddInt;
break;
}
break;
case PrimitiveType_Sub:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSubRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSub;
arithmetic_opt_run_int_ = ElementOptSubInt;
break;
}
break;
case PrimitiveType_Div:
case PrimitiveType_RealDiv:
switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu;
break;
case schema::ActivationType_RELU6:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDivRelu6;
break;
default:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptDiv;
arithmetic_opt_run_int_ = ElementOptDivInt;
break;
}
break;
case PrimitiveType_Mod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMod;
arithmetic_opt_run_int_ = ElementOptModInt;
break;
default:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32);
for (size_t i = 0; i < length; i++) {
if (fun_table[i].primitive_type_ == op_parameter_->type_ &&
fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) {
arithmetic_run_ = fun_table[i].func_;
arithmetic_run_int_ = fun_table[i].int_func_;
arithmetic_run_bool_ = fun_table[i].bool_func_;
arithmetic_opt_run_ = fun_table[i].opt_func_;
arithmetic_opt_run_int_ = fun_table[i].opt_int_func_;
return;
}
} else {
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
}
return;
}
void ArithmeticCPUKernel::InitParam() {
@ -321,7 +178,6 @@ void ArithmeticCPUKernel::InitParam() {
int ArithmeticCPUKernel::ReSize() {
InitParam();
InitOptRunFunction();
return InitBroadCastCase();
}
@ -359,6 +215,66 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
return RET_OK;
}
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
if (input0_broadcast_ == true || input1_broadcast_ == true) {
return false;
}
if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ ||
arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < arithmeticParameter_->ndim_) {
for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) {
if (arithmeticParameter_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1];
int batch_per_thread = UP_DIV(batch, thread_count_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1];
int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1];
int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1];
int offset0 = stride0 * start_batch;
int offset1 = stride1 * start_batch;
int out_offset = out_stride * start_batch;
int ret = RET_OK;
for (int i = 0; i < batch_size; i++) {
if (data_type_ == kDataTypeFloat) {
ret = arithmetic_opt_run_(
reinterpret_cast<float *>(input0_ptr_) + offset0, reinterpret_cast<float *>(input1_ptr_) + offset1,
reinterpret_cast<float *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
} else {
ret = arithmetic_opt_run_int_(
reinterpret_cast<int *>(input0_ptr_) + offset0, reinterpret_cast<int *>(input1_ptr_) + offset1,
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_);
}
offset0 += stride0;
offset1 += stride1;
out_offset += out_stride;
}
return ret;
}
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();
@ -370,27 +286,12 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
return RET_ERROR;
}
int error_code;
if (arithmeticParameter_->broadcasting_) {
/* need broadcast in runtime */
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code;
if (CanBatchScalar()) {
return BatchScalarCalc(task_id);
}
if (arithmetic_opt_run_ != nullptr) {
int error_code = RET_OK;
if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) &&
(arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) {
/* run opt function
* one of input is scalar */
if (arithmeticParameter_->in_elements_num0_ == 1) {
@ -413,11 +314,24 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
reinterpret_cast<int *>(input0_ptr_) + stride * task_id, reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_);
}
} else {
MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar";
return RET_ERROR;
}
return error_code;
}
if (arithmeticParameter_->broadcasting_) {
/* need broadcast in runtime */
stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(input0_ptr_), reinterpret_cast<float *>(input1_ptr_),
reinterpret_cast<float *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<int *>(input0_ptr_), reinterpret_cast<int *>(input1_ptr_),
reinterpret_cast<int *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
return error_code;
}

View File

@ -52,6 +52,16 @@ class ArithmeticCPUKernel : public LiteKernel {
const ArithmeticParameter *param);
typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size);
typedef struct {
int primitive_type_;
int activation_type_;
ArithmeticRun func_;
ArithmeticIntRun int_func_;
ArithmeticBoolRun bool_func_;
ArithmeticOptRun opt_func_;
ArithmeticOptIntRun opt_int_func_;
} ARITHMETIC_FUNC_INFO_FP32;
public:
ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
@ -75,6 +85,8 @@ class ArithmeticCPUKernel : public LiteKernel {
void FreeTmpPtr();
int InitBroadCastCase();
void InitParamInRunTime();
bool CanBatchScalar();
int BatchScalarCalc(int task_id);
protected:
bool input0_broadcast_ = false;