[MS][LITE] arm cpu fp32 op: optimize scale op

This commit is contained in:
yangruoqi713 2020-08-25 17:01:30 +08:00
parent 2de216961f
commit 4888300180
3 changed files with 74 additions and 57 deletions

View File

@ -15,35 +15,65 @@
*/
#include "nnacl/scale.h"
#include "nnacl/errorcode.h"
int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) {
if (in_data == NULL || out_data == NULL || scale == NULL || offset == NULL || scale_param == NULL) {
return NNACL_ERR;
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
// Computes out = in * scale[i] + offset[i] for outer indices in [outer_start, outer_end).
// Elements are addressed as out * axis_size * inner_size + i * inner_size + in_index, so one
// scale/offset pair (selected by the axis position i) is shared by a run of inner_size elements.
// in_data/out_data are indexed with the same flat offset; scale and offset are indexed by i only.
// NOTE(review): this listing is a rendered diff; the lines between the NEON loop and `#endif`
// reference scale_param/task_id, which are not parameters here, and look like removed lines of
// the previous DoScale body — confirm against the real file before compiling.
void ScaleInner(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
int axis_size, int inner_size) {
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size * inner_size;
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
#ifdef ENABLE_ARM
// Vector path: four floats per iteration, with scale[i] and offset[i] broadcast to all lanes.
// The bound is `< inner_size - 4`, so when exactly four elements remain they are left to the
// scalar loop below rather than handled here.
for (; in_index < inner_size - 4; in_index += 4) {
int in_offset = axis_offset + in_index;
float32x4_t data = vld1q_f32(in_data + in_offset);
float32x4_t scale_4 = vdupq_n_f32(scale[i]);
float32x4_t offset_4 = vdupq_n_f32(offset[i]);
// vfmaq_f32(a, b, c) yields a + b * c, i.e. offset + data * scale.
float32x4_t reslut = vfmaq_f32(offset_4, data, scale_4);
vst1q_f32(out_data + in_offset, reslut);
}
if (scale_param->has_offset_) {
for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) {
int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_;
for (int i = 0; i < scale_param->axis_size_; i++) {
int axis_offset = out_offset + i * scale_param->inner_size_;
for (int in = 0; in < scale_param->inner_size_; in++) {
int in_offset = axis_offset + in;
#endif
// Scalar path: finishes whatever the vector loop left, or everything when ENABLE_ARM is unset.
for (; in_index < inner_size; in_index++) {
int in_offset = axis_offset + in_index;
out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i];
}
}
}
}
// Computes out[o * axis_size + k] = in[o * axis_size + k] * scale[k] + offset[k]
// for every outer index o in [outer_start, outer_end) and every k in [0, axis_size).
//
// in_data / out_data: flat buffers addressed with the same offset o * axis_size + k.
// scale / offset:     per-axis coefficients; both are read at indices 0 .. axis_size - 1.
// outer_start/end:    half-open range of outer rows to process; nothing is written outside it.
// axis_size:          number of elements per outer row.
void ScaleAxis(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
               int axis_size) {
  for (int out = outer_start; out < outer_end; out++) {
    int out_offset = out * axis_size;
    int index = 0;
#ifdef ENABLE_ARM
    // Vector path: four floats per iteration. The bound is inclusive (`<=`) so that a final full
    // group of four (e.g. axis_size == 4 or 8) is also vectorized instead of being pushed to the
    // scalar tail; the loads still never read past axis_size.
    for (; index <= axis_size - 4; index += 4) {
      int in_offset = out_offset + index;
      float32x4_t data = vld1q_f32(in_data + in_offset);
      float32x4_t scale_4 = vld1q_f32(scale + index);
      float32x4_t offset_4 = vld1q_f32(offset + index);
      // vfmaq_f32(a, b, c) yields a + b * c, i.e. offset + data * scale.
      float32x4_t result = vfmaq_f32(offset_4, data, scale_4);
      vst1q_f32(out_data + in_offset, result);
    }
#endif
    // Scalar tail: at most three leftover elements on ARM, or the whole row otherwise.
    for (; index < axis_size; index++) {
      int in_offset = out_offset + index;
      out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index];
    }
  }
}
// Per-task entry point for the fp32 scale op.
// The outer dimension is divided into contiguous chunks of outer_step rows; task `task_id`
// processes rows [task_id * outer_step, outer_end), with outer_end capped at outer_size_.
// When inner_size_ == 1 the work goes to ScaleAxis, otherwise to ScaleInner.
// NOTE(review): UP_DIV and MSMIN are defined outside this view — presumably ceiling division and
// minimum; confirm.
// NOTE(review): this listing is a rendered diff; the strided `for (int out = task_id; ...)` nest
// and the trailing `return NNACL_OK;` look like removed lines of the previous int-returning
// implementation interleaved with the new body — confirm against the real file.
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) {
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
if (scale_param->inner_size_ == 1) {
// One element per axis position: scale/offset are applied element-wise along the axis.
ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_);
} else {
for (int out = task_id; out < scale_param->outer_size_; out += scale_param->op_parameter_.thread_num_) {
int out_offset = out * scale_param->axis_size_ * scale_param->inner_size_;
for (int i = 0; i < scale_param->axis_size_; i++) {
int axis_offset = out_offset + i * scale_param->inner_size_;
for (int in = 0; in < scale_param->inner_size_; in++) {
int in_offset = axis_offset + in;
out_data[in_offset] = in_data[in_offset] * scale[i];
// Each scale/offset pair covers a run of inner_size_ elements.
ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
scale_param->inner_size_);
}
}
}
}
return NNACL_OK;
}

View File

@ -26,13 +26,12 @@ typedef struct ScaleParameter {
int inner_size_;
int axis_;
bool const_scale_;
bool has_offset_;
} ScaleParameter;
// C linkage guard so the prototype below can be included from C++ translation units.
#ifdef __cplusplus
extern "C" {
#endif
// NOTE(review): two prototypes with different return types appear here because this listing is a
// rendered diff; the `int` form looks like the removed line and the `void` form the added one —
// confirm against the real header.
int DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param);
// Scale entry point for one task: takes input, output, scale and offset buffers, the task id and
// the op parameters.
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param);
#ifdef __cplusplus
}
#endif

View File

@ -35,13 +35,11 @@ ScaleCPUKernel::~ScaleCPUKernel() {
scale_ = nullptr;
}
}
if (scale_param_->has_offset_) {
if (offset_ != nullptr) {
free(offset_);
offset_ = nullptr;
}
}
}
int ScaleCPUKernel::InitScaleOffset() {
auto scale_tensor = in_tensors_.at(1);
@ -59,18 +57,15 @@ int ScaleCPUKernel::InitScaleOffset() {
scale_ = nullptr;
}
if (in_tensors_.size() == 3) {
auto offset_tensor = in_tensors_.at(2);
offset_ = reinterpret_cast<float *>(malloc(offset_tensor->ElementsNum() * sizeof(float)));
offset_ = reinterpret_cast<float *>(malloc(scale_param_->axis_size_ * sizeof(float)));
if (offset_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
memset(offset_, 0, scale_param_->axis_size_ * sizeof(float));
if (in_tensors_.size() == 3) {
auto offset_tensor = in_tensors_.at(2);
memcpy(offset_, offset_tensor->Data(), offset_tensor->ElementsNum() * sizeof(float));
scale_param_->has_offset_ = true;
} else {
offset_ = nullptr;
scale_param_->has_offset_ = false;
}
return RET_OK;
}
@ -101,6 +96,7 @@ int ScaleCPUKernel::InitParameter() {
for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
scale_param_->inner_size_ *= in_shape[i];
}
scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
return RET_OK;
}
@ -114,6 +110,11 @@ int ScaleCPUKernel::Init() {
return RET_OK;
}
ReSize();
return RET_OK;
}
int ScaleCPUKernel::ReSize() {
auto ret = InitParameter();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
@ -128,21 +129,8 @@ int ScaleCPUKernel::Init() {
return RET_OK;
}
// Re-runs InitParameter(); returns RET_OK on success, otherwise logs an error and
// returns RET_ERROR.
int ScaleCPUKernel::ReSize() {
  const bool param_ok = (InitParameter() == RET_OK);
  if (param_ok) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
  return RET_ERROR;
}
// Runs the scale computation for one task by passing the kernel's input_ptr_, output_ptr_,
// scale_, offset_ and scale_param_ to DoScale together with task_id.
// NOTE(review): this listing is a rendered diff; the checked `auto ret = DoScale(...)` call and
// its error branch look like removed lines, and the bare `DoScale(...)` call the added one —
// confirm against the real file (as shown, DoScale would be invoked twice).
int ScaleCPUKernel::Scale(int task_id) {
auto ret = DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
DoScale(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_);
return RET_OK;
}