forked from mindspore-Ecosystem/mindspore
!12998 [MSLITE][Develop] Optimize arithmetic performance and memory
From: @sunsuodong Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
aa9bee0ce3
|
@ -119,10 +119,6 @@ void ArithmeticFP16CPUKernel::TileConstTensor(const void *in_data, void *out_dat
|
|||
|
||||
int ArithmeticFP16CPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) {
|
||||
int ret = RET_OK;
|
||||
if (in_tensors_[0]->data_type() != kNumberTypeFloat16) {
|
||||
MS_LOG(ERROR) << "data type is not fp16";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (is_opt) {
|
||||
CHECK_NULL_RETURN(arithmetic_opt_func_, RET_ERROR);
|
||||
ret = arithmetic_opt_func_(reinterpret_cast<const float16_t *>(input0), reinterpret_cast<const float16_t *>(input1),
|
||||
|
|
|
@ -60,7 +60,11 @@ int ArithmeticCPUKernel::ReSize() {
|
|||
outside_ *= param_->out_shape_[i];
|
||||
}
|
||||
}
|
||||
return ConstTensorBroadCast();
|
||||
int ret = RET_OK;
|
||||
if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) {
|
||||
ret = ConstTensorBroadCast();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::CheckDataType() {
|
||||
|
@ -73,6 +77,47 @@ int ArithmeticCPUKernel::CheckDataType() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1
|
||||
if (arithmetic_opt_run_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < param_->ndim_) {
|
||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reports whether the "bias" fast path applies: the smaller input has exactly as many
// elements as its own last dimension, and both inputs share the same last dimension.
// Example shapes: [2, 240, 240, 32] with [1, 1, 1, 32].
//
// Returns false when both inputs have the same number of elements, and when the
// parameter rank is below 1 (there is no last axis to compare in that case).
bool ArithmeticCPUKernel::isBiasCalc() {
  // Guard before computing `ndim_ - 1`: for rank 0 the index would be out of range.
  // NOTE(review): if ndim_ is an unsigned type the subtraction wraps to a huge index.
  if (param_->ndim_ < 1) {
    return false;
  }

  const auto last_axis = param_->ndim_ - 1;
  const int last_shape0 = param_->in_shape0_[last_axis];
  const int last_shape1 = param_->in_shape1_[last_axis];
  if (last_shape0 != last_shape1) {
    return false;
  }

  if (param_->in_elements_num0_ > param_->in_elements_num1_) {
    // input1 is the smaller operand: it must consist of the last axis only.
    return param_->in_elements_num1_ == last_shape1;
  }
  if (param_->in_elements_num0_ < param_->in_elements_num1_) {
    // input0 is the smaller operand: it must consist of the last axis only.
    return param_->in_elements_num0_ == last_shape0;
  }
  return false;
}
|
||||
|
||||
int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
||||
/* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */
|
||||
if (!param_->broadcasting_) {
|
||||
|
@ -86,11 +131,6 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() {
|
|||
param_->in_elements_num1_ != param_->out_elements_num_) {
|
||||
return RET_OK;
|
||||
}
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
|
||||
/* run opt function
|
||||
* one of input is scalar */
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
FreeConstTileBuff();
|
||||
if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) {
|
||||
|
@ -252,32 +292,6 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1
|
||||
if (input0_broadcast_ || input1_broadcast_) {
|
||||
return false;
|
||||
}
|
||||
if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 ||
|
||||
param_->in_elements_num1_ == 1) {
|
||||
return false;
|
||||
}
|
||||
size_t break_axis = 0;
|
||||
for (size_t i = 0; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape0_[i] != param_->in_shape1_[i]) {
|
||||
break_axis = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (break_axis < param_->ndim_) {
|
||||
for (size_t i = break_axis; i < param_->ndim_; i++) {
|
||||
if (param_->in_shape1_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
break_pos_ = break_axis;
|
||||
return true;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
||||
if (break_pos_ < 1) {
|
||||
return RET_ERROR;
|
||||
|
@ -308,6 +322,40 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::BiasCalc(int task_id) {
|
||||
int last_shape = param_->out_shape_[param_->ndim_ - 1];
|
||||
int batch = param_->out_elements_num_ / last_shape;
|
||||
int batch_per_thread = UP_DIV(batch, context_->thread_num_);
|
||||
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, batch);
|
||||
int batch_size = end_batch - start_batch;
|
||||
|
||||
int stride = last_shape * data_type_len_;
|
||||
int offset = stride * start_batch;
|
||||
int ret = RET_OK;
|
||||
if (param_->in_elements_num0_ > param_->in_elements_num1_) {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_),
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
offset += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
ret = Execute(static_cast<uint8_t *>(input0_ptr_), static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, last_shape, false);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
offset += stride;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||
auto element_num = out_tensors_[0]->ElementsNum();
|
||||
int stride = UP_DIV(element_num, context_->thread_num_);
|
||||
|
@ -315,13 +363,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
if (count <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
/* run opt function, every batch one of input is scalar */
|
||||
if (CanBatchScalar()) {
|
||||
return BatchScalarCalc(task_id);
|
||||
}
|
||||
int offset = stride * task_id * data_type_len_;
|
||||
/* run opt function, one of input is scalar */
|
||||
if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) {
|
||||
if (isScalarClac()) { // 2 32 240 240, 1 1 1 1
|
||||
if (param_->in_elements_num0_ == 1) {
|
||||
return Execute(input0_ptr_, static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
||||
|
@ -330,6 +374,14 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
static_cast<uint8_t *>(output_ptr_) + offset, count, true);
|
||||
}
|
||||
}
|
||||
/* run opt function, every batch one of input is scalar */
|
||||
if (isBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1
|
||||
return BatchScalarCalc(task_id);
|
||||
}
|
||||
/* each batch is eltwise calculation */
|
||||
if (isBiasCalc()) { // 2 240 240 32, 1 1 1 32
|
||||
return BiasCalc(task_id);
|
||||
}
|
||||
/* need broadcast in runtime */
|
||||
if (param_->broadcasting_) {
|
||||
stride = UP_DIV(outside_, context_->thread_num_);
|
||||
|
@ -339,7 +391,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|||
}
|
||||
return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id);
|
||||
}
|
||||
/* no broadcast in runtime */
|
||||
/* all elements eltwise calculation */
|
||||
return Execute(static_cast<uint8_t *>(input0_ptr_) + offset, static_cast<uint8_t *>(input1_ptr_) + offset,
|
||||
static_cast<uint8_t *>(output_ptr_) + offset, count, false);
|
||||
}
|
||||
|
|
|
@ -108,9 +108,12 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
int data_type_len_ = sizeof(float);
|
||||
|
||||
private:
|
||||
bool CanBatchScalar();
|
||||
int BatchScalarCalc(int task_id);
|
||||
int BiasCalc(int task_id);
|
||||
void FreeConstTileBuff();
|
||||
bool isScalarClac();
|
||||
bool isBatchScalarCalc();
|
||||
bool isBiasCalc();
|
||||
ArithmeticRun arithmetic_run_ = nullptr;
|
||||
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
||||
ArithmeticIntRun arithmetic_run_int_ = nullptr;
|
||||
|
|
Loading…
Reference in New Issue