!12998 [MSLITE][Develop] Optimize arithmetic performance and memory

From: @sunsuodong
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2021-03-11 15:29:23 +08:00 committed by Gitee
commit aa9bee0ce3
3 changed files with 94 additions and 43 deletions

View File

@ -119,10 +119,6 @@ void ArithmeticFP16CPUKernel::TileConstTensor(const void *in_data, void *out_dat
int ArithmeticFP16CPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
int ret = RET_OK;
if (in_tensors_[0]->data_type() != kNumberTypeFloat16) {
MS_LOG(ERROR) << "data type is not fp16";
return RET_ERROR;
}
if (is_opt) {
CHECK_NULL_RETURN(arithmetic_opt_func_, RET_ERROR);
ret = arithmetic_opt_func_(reinterpret_cast<const float16_t *>(input0), reinterpret_cast<const float16_t *>(input1),

View File

@ -60,7 +60,11 @@ int ArithmeticCPUKernel::ReSize() {
outside_ *= param_->out_shape_[i];
}
}
return ConstTensorBroadCast();
int ret = RET_OK;
if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) {
ret = ConstTensorBroadCast();
}
return ret;
}
int ArithmeticCPUKernel::CheckDataType() {
@ -73,6 +77,47 @@ int ArithmeticCPUKernel::CheckDataType() {
return RET_OK;
}
bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
return true;
} else {
return false;
}
}
bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
if (arithmetic_opt_run_ == nullptr) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < param_->ndim_; i++) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < param_->ndim_) {
for (size_t i = break_axis; i < param_->ndim_; i++) {
if (param_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
bool ArithmeticCPUKernel::isBiasCalc() { // 2 240 240 32, 1 1 1 32
int last_shape0 = param_->in_shape0_[param_->ndim_ - 1];
int last_shape1 = param_->in_shape1_[param_->ndim_ - 1];
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
return param_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1;
} else if (param_->in_elements_num0_ < param_->in_elements_num1_) {
return param_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1;
}
return false;
}
int ArithmeticCPUKernel::ConstTensorBroadCast() {
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
if (!param_->broadcasting_) {
@ -86,11 +131,6 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
param_->in_elements_num1_ != param_->out_elements_num_) {
return RET_OK;
}
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
/* run opt function
* one of input is scalar */
return RET_OK;
}
FreeConstTileBuff();
if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) {
@ -252,32 +292,6 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
return RET_OK;
}
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
if (input0_broadcast_ || input1_broadcast_) {
return false;
}
if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 ||
param_->in_elements_num1_ == 1) {
return false;
}
size_t break_axis = 0;
for (size_t i = 0; i < param_->ndim_; i++) {
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
break_axis = i;
break;
}
}
if (break_axis < param_->ndim_) {
for (size_t i = break_axis; i < param_->ndim_; i++) {
if (param_->in_shape1_[i] != 1) {
return false;
}
}
}
break_pos_ = break_axis;
return true;
}
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
if (break_pos_ < 1) {
return RET_ERROR;
@ -308,6 +322,40 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
return ret;
}
int ArithmeticCPUKernel::BiasCalc(int task_id) {
int last_shape = param_->out_shape_[param_->ndim_ - 1];
int batch = param_->out_elements_num_ / last_shape;
int batch_per_thread = UP_DIV(batch, context_->thread_num_);
int start_batch = batch_per_thread * task_id;
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
int batch_size = end_batch - start_batch;
int stride = last_shape * data_type_len_;
int offset = stride * start_batch;
int ret = RET_OK;
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
for (int i = 0; i < batch_size; i++) {
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
if (ret != RET_OK) {
return ret;
}
offset += stride;
}
} else {
for (int i = 0; i < batch_size; i++) {
ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
if (ret != RET_OK) {
return ret;
}
offset += stride;
}
}
return ret;
}
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();
int stride = UP_DIV(element_num, context_->thread_num_);
@ -315,13 +363,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
if (count <= 0) {
return RET_OK;
}
/* run opt function, every batch one of input is scalar */
if (CanBatchScalar()) {
return BatchScalarCalc(task_id);
}
int offset = stride * task_id * data_type_len_;
/* run opt function, one of input is scalar */
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
if (isScalarClac()) { // 2 32 240 240, 1 1 1 1
if (param_->in_elements_num0_ == 1) {
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
@ -330,6 +374,14 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
}
}
/* run opt function, every batch one of input is scalar */
if (isBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
return BatchScalarCalc(task_id);
}
/* each batch is eltwise calculation */
if (isBiasCalc()) { // 2 240 240 32, 1 1 1 32
return BiasCalc(task_id);
}
/* need broadcast in runtime */
if (param_->broadcasting_) {
stride = UP_DIV(outside_, context_->thread_num_);
@ -339,7 +391,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
}
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
}
/* no broadcast in runtime */
/* all elements eltwise calculation */
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
}

View File

@ -108,9 +108,12 @@ class ArithmeticCPUKernel : public LiteKernel {
int data_type_len_ = sizeof(float);
private:
bool CanBatchScalar();
int BatchScalarCalc(int task_id);
int BiasCalc(int task_id);
void FreeConstTileBuff();
bool isScalarClac();
bool isBatchScalarCalc();
bool isBiasCalc();
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr;