!14088 [MS][LITE][Develop] optimize fp32 prelu

From: @lx0095
Reviewed-by: @zhang_xue_tong
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2021-03-29 17:28:01 +08:00 committed by Gitee
commit db95ebc79b
4 changed files with 134 additions and 176 deletions
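
For context, PReLU is an element-wise activation with a learnable negative slope: out = in when in > 0, otherwise out = in * slope, where the slope is either one value per channel or a single value shared by all channels. A minimal scalar reference for the per-channel case, assuming the NHWC-style layout this kernel works on (PReluRef is an illustrative name, not part of the patch):

/* Minimal scalar reference for per-channel PReLU (illustrative sketch, not part of this patch).
 * Layout assumption: `plane` rows of `channel` contiguous values, one slope per channel. */
#include <stddef.h>
static void PReluRef(const float *in, float *out, const float *slope, size_t plane, size_t channel) {
  for (size_t i = 0; i < plane; ++i) {
    for (size_t c = 0; c < channel; ++c) {
      float v = in[i * channel + c];
      out[i * channel + c] = v > 0.0f ? v : v * slope[c];
    }
  }
}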


@@ -15,104 +15,130 @@
*/
#include "nnacl/fp32/prelu_fp32.h"
void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) {
int plane_tile = plane / TILE_NUM * TILE_NUM;
int channel_num = prelu_param_->channel_num_;
int plane_index = 0;
for (; plane_index < plane_tile; plane_index += TILE_NUM) {
float *in_plane_ptr = input + plane_index * channel_num;
float *out_plane_ptr = output + plane_index * channel_num;
int channel_index = 0;
#if defined(ENABLE_AVX)
MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f);
MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
float *negetive_slope_value_8 = prelu_param_->slope_;
int div_channel_c8 = prelu_param_->channel_num_ / C8NUM * C8NUM;
for (; channel_index < div_channel_c8; channel_index += C8NUM) {
MS_FLOAT32X8 slope_value_8 = MS_LD256_F32(negetive_slope_value_8 + channel_index);
LOAD256X8_F32(src, in_plane_ptr + channel_index, channel_num)
PRELU_CALCULATE_256X8(dst, src)
STORE256X8_F32(out_plane_ptr + channel_index, channel_num, dst)
#ifdef ENABLE_ARM64
inline void PRelu4x16(const float *in, float *out, float *cur_slope, size_t step) {
asm volatile(
"mov x10, %[in]\n"
"mov x11, %[out]\n"
"mov x12, %[cur_slope]\n"
"ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n"
"fmul v16.4s, v0.4s, v4.4s\n"
"fmul v17.4s, v1.4s, v5.4s\n"
"fmul v18.4s, v2.4s, v6.4s\n"
"fmul v19.4s, v3.4s, v7.4s\n"
"fcmgt v20.4s, v0.4s, #0\n"
"fcmgt v21.4s, v1.4s, #0\n"
"fcmgt v22.4s, v2.4s, #0\n"
"fcmgt v23.4s, v3.4s, #0\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n"
"bif v0.16b, v16.16b, v20.16b\n"
"bif v1.16b, v17.16b, v21.16b\n"
"bif v2.16b, v18.16b, v22.16b\n"
"bif v3.16b, v19.16b, v23.16b\n"
"fmul v8.4s, v24.4s, v4.4s\n"
"fmul v9.4s, v25.4s, v5.4s\n"
"fmul v10.4s, v26.4s, v6.4s\n"
"fmul v11.4s, v27.4s, v7.4s\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n"
"fcmgt v12.4s, v24.4s, #0\n"
"fcmgt v13.4s, v25.4s, #0\n"
"fcmgt v14.4s, v26.4s, #0\n"
"fcmgt v15.4s, v27.4s, #0\n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n"
"bif v24.16b, v8.16b, v12.16b\n"
"bif v25.16b, v9.16b, v13.16b\n"
"bif v26.16b, v10.16b, v14.16b\n"
"bif v27.16b, v11.16b, v15.16b\n"
"fmul v16.4s, v0.4s, v4.4s\n"
"fmul v17.4s, v1.4s, v5.4s\n"
"fmul v18.4s, v2.4s, v6.4s\n"
"fmul v19.4s, v3.4s, v7.4s\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n"
"fcmgt v20.4s, v0.4s, #0\n"
"fcmgt v21.4s, v1.4s, #0\n"
"fcmgt v22.4s, v2.4s, #0\n"
"fcmgt v23.4s, v3.4s, #0\n"
"ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n"
"bif v0.16b, v16.16b, v20.16b\n"
"bif v1.16b, v17.16b, v21.16b\n"
"bif v2.16b, v18.16b, v22.16b\n"
"bif v3.16b, v19.16b, v23.16b\n"
"fmul v8.4s, v24.4s, v4.4s\n"
"fmul v9.4s, v25.4s, v5.4s\n"
"fmul v10.4s, v26.4s, v6.4s\n"
"fmul v11.4s, v27.4s, v7.4s\n"
"fcmgt v12.4s, v24.4s, #0\n"
"fcmgt v13.4s, v25.4s, #0\n"
"fcmgt v14.4s, v26.4s, #0\n"
"fcmgt v15.4s, v27.4s, #0\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n"
"bif v24.16b, v8.16b, v12.16b\n"
"bif v25.16b, v9.16b, v13.16b\n"
"bif v26.16b, v10.16b, v14.16b\n"
"bif v27.16b, v11.16b, v15.16b\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n"
:
: [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step)
: "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
"v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
}
#endif
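
The ARM64 helper above is branch-free: each block of four vector registers is multiplied by the per-channel slopes, fcmgt builds an all-ones mask for lanes greater than zero, and bif (bitwise insert if false) keeps the original lane where the mask is set and substitutes the product elsewhere, with loads and stores interleaved to hide latency. A rough NEON-intrinsics sketch of a single 4-lane step (PReluLane4 is an illustrative name, not part of the patch):

#include <arm_neon.h>
/* One 4-lane step equivalent to an fcmgt/bif pair in the assembly above (sketch only). */
static inline float32x4_t PReluLane4(float32x4_t in, float32x4_t slope) {
  float32x4_t mul = vmulq_f32(in, slope);                 /* in * slope for every lane */
  uint32x4_t gt_zero = vcgtq_f32(in, vdupq_n_f32(0.0f));  /* all-ones mask where in > 0 */
  return vbslq_f32(gt_zero, in, mul);                     /* keep in where positive, else in * slope */
}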
void PRelu(const float *input, float *output, float *slope, int start, int end, int channel) {
int i = start;
#ifdef ENABLE_ARM64
for (; i < end - 3; i += 4) {
const float *cur_in = input + i * channel;
float *cur_out = output + i * channel;
int j = 0;
for (; j < channel - 15; j += 16) {
const float *in = cur_in + j;
float *out = cur_out + j;
float *cur_slope = slope + j;
size_t step = channel * sizeof(float);
PRelu4x16(in, out, cur_slope, step);
}
for (; j < channel; j++) {
cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]);
cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j];
cur_out[j + 2 * channel] =
(cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]);
cur_out[j + 3 * channel] =
(cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]);
}
}
#endif
for (; i < end; i++) {
const float *cur_in = input + i * channel;
float *cur_out = output + i * channel;
int j = 0;
#if defined(ENABLE_ARM)
for (; j < channel - 3; j += 4) {
MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j);
MS_FLOAT32X4 s = MS_LDQ_F32(slope + j);
MS_FLOAT32X4 mul = MS_MULQ_F32(in, s);
MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 res = MS_BLENDQ_F32(mul, in, MS_CMPGTQ_F32(in, zero));
MS_STQ_F32(cur_out + j, res);
}
#endif
// note: First AVX processing, then SSE processing on X86 platform
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f);
MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
float *negetive_slope_value = prelu_param_->slope_;
int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM;
for (; channel_index < div_channel; channel_index += C4NUM) {
MS_FLOAT32X4 slope_value = MS_LDQ_F32(negetive_slope_value + channel_index);
LOAD128X8_F32(src, in_plane_ptr + channel_index, channel_num)
PRELU_CALCULATE_128X8(dst, src)
STORE128X8_F32(out_plane_ptr + channel_index, channel_num, dst)
}
#endif
for (; channel_index < channel_num; channel_index++) {
float *in_c = in_plane_ptr + channel_index;
float *out_c = out_plane_ptr + channel_index;
for (int tile_i = 0; tile_i < TILE_NUM; tile_i++) {
float *in_tile = in_c + tile_i * channel_num;
float *out_tile = out_c + tile_i * channel_num;
const float in_data = in_tile[0];
out_tile[0] = (in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
for (; j < channel; j++) {
if (cur_in[j] > 0) {
cur_out[j] = cur_in[j];
} else {
cur_out[j] = cur_in[j] * slope[j];
}
}
}
}
for (; plane_index < plane; plane_index++) {
float *in_plane_ptr = input + plane_index * channel_num;
float *out_plane_ptr = output + plane_index * channel_num;
for (int channel_index = 0; channel_index < channel_num; channel_index++) {
const float in_data = in_plane_ptr[channel_index];
out_plane_ptr[channel_index] =
(in_data < 0 ? in_data : 0) * prelu_param_->slope_[channel_index] + (in_data > 0 ? in_data : 0);
void PReluShareChannel(const float *input, float *output, float slope, int start, int end) {
for (int i = start; i < end; i++) {
if (input[i] > 0) {
output[i] = input[i];
} else {
output[i] = input[i] * slope;
}
}
}
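
Both new entry points take an explicit [start, end) range so threads can split the work without per-kernel tiling state: PRelu indexes rows of `channel` elements with one slope per channel, while PReluShareChannel walks flat elements with a single slope. A small hypothetical driver (not part of the patch) showing the two calls:

void PReluExample(void) {
  float in[6] = {-1.0f, 2.0f, -3.0f, 4.0f, -5.0f, 6.0f};  /* 2 rows x 3 channels */
  float slope[3] = {0.1f, 0.2f, 0.3f};
  float out[6];
  /* Per-channel path over rows [0, 2): out = {-0.1, 2, -0.9, 4, -1, 6} */
  PRelu(in, out, slope, 0, 2, 3);
  /* Channel-shared path over flat elements [0, 6) with one slope. */
  PReluShareChannel(in, out, 0.25f, 0, 6);
}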
void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id) {
for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
int cal_index;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
cal_index = j * 64;
#else
cal_index = j * 32;
#endif
float *input_ptr = input + cal_index;
float *output_ptr = input + cal_index;
#if defined(ENABLE_AVX)
MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0);
MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
MS_FLOAT32X8 slope_value_8 = MS_MOV256_F32(prelu_param_->slope_[0]);
LOAD256X8_F32(src, input_ptr, 8)
PRELU_CALCULATE_256X8(dst, src)
STORE256X8_F32(output_ptr, 8, dst)
#elif defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0);
MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
MS_FLOAT32X4 slope_value = MS_MOVQ_F32(prelu_param_->slope_[0]);
LOAD128X8_F32(src, input_ptr, 4)
#ifdef ENABLE_ARM64
LOAD128X8_F32(src1, input_ptr + 32, 4)
#endif
PRELU_CALCULATE_128X8(dst, src)
#ifdef ENABLE_ARM64
PRELU_CALCULATE_128X8(dst1, src1)
#endif
STORE128X8_F32(output_ptr, 4, dst)
#ifdef ENABLE_ARM64
STORE128X8_F32(output_ptr + 32, 4, dst1)
#endif
#else
const int cal_per_time = 32;
for (int i = 0; i < cal_per_time; ++i) {
float data = input_ptr[i];
output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0);
}
#endif
}
}


@@ -22,39 +22,11 @@
#ifdef __cplusplus
extern "C" {
#endif
void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int task_id);
void PRelu(const float *input, float *output, float *slope, int start, int end, int channel);
void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id);
void PReluShareChannel(const float *input, float *output, float slope, int start, int end);
#ifdef __cplusplus
}
#endif
#define PRELU_CALCULATE_256X8(dst, src) \
MS_FLOAT32X8 dst##1 = \
MS_MUL256_F32(src##1, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##1, zero_value_8, 30))); \
MS_FLOAT32X8 dst##2 = \
MS_MUL256_F32(src##2, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##2, zero_value_8, 30))); \
MS_FLOAT32X8 dst##3 = \
MS_MUL256_F32(src##3, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##3, zero_value_8, 30))); \
MS_FLOAT32X8 dst##4 = \
MS_MUL256_F32(src##4, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##4, zero_value_8, 30))); \
MS_FLOAT32X8 dst##5 = \
MS_MUL256_F32(src##5, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##5, zero_value_8, 30))); \
MS_FLOAT32X8 dst##6 = \
MS_MUL256_F32(src##6, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##6, zero_value_8, 30))); \
MS_FLOAT32X8 dst##7 = \
MS_MUL256_F32(src##7, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##7, zero_value_8, 30))); \
MS_FLOAT32X8 dst##8 = \
MS_MUL256_F32(src##8, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##8, zero_value_8, 30)));
#define PRELU_CALCULATE_128X8(dst, src) \
MS_FLOAT32X4 dst##1 = MS_MULQ_F32(src##1, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##1, zero_value))); \
MS_FLOAT32X4 dst##2 = MS_MULQ_F32(src##2, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##2, zero_value))); \
MS_FLOAT32X4 dst##3 = MS_MULQ_F32(src##3, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##3, zero_value))); \
MS_FLOAT32X4 dst##4 = MS_MULQ_F32(src##4, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##4, zero_value))); \
MS_FLOAT32X4 dst##5 = MS_MULQ_F32(src##5, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##5, zero_value))); \
MS_FLOAT32X4 dst##6 = MS_MULQ_F32(src##6, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##6, zero_value))); \
MS_FLOAT32X4 dst##7 = MS_MULQ_F32(src##7, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##7, zero_value))); \
MS_FLOAT32X4 dst##8 = MS_MULQ_F32(src##8, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##8, zero_value)));
#endif // MINDSPORE_LITE_NNACL_FP32_PRELU_H_
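
The removed PRELU_CALCULATE_256X8 / PRELU_CALCULATE_128X8 macros implement the same branch-free select at 8- and 4-lane width: comparison predicate 30 is _CMP_GT_OQ, so each vector is multiplied by a factor that is 1.0 where the lane is positive and the slope where it is not. A plain-AVX sketch of one such line, assuming MS_BLEND256_F32 wraps _mm256_blendv_ps (illustrative only, not part of the patch):

#include <immintrin.h>
/* One vector of the removed 256-bit macro, written with raw AVX intrinsics (sketch). */
static inline __m256 PRelu8Lanes(__m256 src, __m256 slope, __m256 one, __m256 zero) {
  __m256 gt_zero = _mm256_cmp_ps(src, zero, _CMP_GT_OQ);  /* predicate 30: ordered, quiet "greater than" */
  __m256 factor = _mm256_blendv_ps(slope, one, gt_zero);  /* 1.0 for positive lanes, slope otherwise */
  return _mm256_mul_ps(src, factor);                      /* src, or src * slope, per lane */
}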


@@ -27,8 +27,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_PReLUFusion;
namespace mindspore::kernel {
namespace {
int PReluRun(void *cdata, int task_id) {
static int PReluRun(void *cdata, int task_id) {
auto PRelu = reinterpret_cast<PReluCPUKernel *>(cdata);
auto ret = PRelu->DoExcute(task_id);
if (ret != RET_OK) {
@@ -37,7 +36,6 @@ int PReluRun(void *cdata, int task_id) {
}
return RET_OK;
}
} // namespace
int PReluCPUKernel::Init() {
if (in_tensors_[1]->ElementsNum() == 1) {
@@ -52,26 +50,22 @@ int PReluCPUKernel::Init() {
}
int PReluCPUKernel::DoExcute(int task_id) {
int thread_num = prelu_param_->op_parameter_.thread_num_;
if (prelu_param_->channelShared) {
PReluShareChannel(input_data_, output_data_, prelu_param_, task_id);
int step = UP_DIV(prelu_param_->input_num_, thread_num);
int start = task_id * step;
int end = MSMIN(start + step, prelu_param_->input_num_);
PReluShareChannel(input_data_, output_data_, prelu_param_->slope_[0], start, end);
} else {
int res_plane = prelu_param_->input_num_ - task_id * prelu_param_->tile_block_;
int plane = MSMIN(prelu_param_->tile_block_, res_plane);
if (plane <= 0) {
return RET_OK;
}
float *in = input_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_;
float *out = output_data_ + task_id * prelu_param_->tile_block_ * prelu_param_->channel_num_;
PRelu(in, out, prelu_param_, plane);
int step = UP_DIV(prelu_param_->tile_block_, thread_num);
int start = task_id * step;
int end = MSMIN(start + step, prelu_param_->tile_block_);
PRelu(input_data_, output_data_, prelu_param_->slope_, start, end, prelu_param_->channel_num_);
}
return RET_OK;
}
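
DoExcute now carves the work into even, contiguous per-thread ranges instead of fixed tiles. A hedged sketch of that split with UP_DIV and MSMIN spelled out (SplitRange is an illustrative name): with a total of 10 rows and 4 threads, step = UP_DIV(10, 4) = 3, so the tasks cover [0,3), [3,6), [6,9) and [9,10); a task whose start passes the total simply gets an empty range and returns immediately.

static void SplitRange(int total, int thread_num, int task_id, int *start, int *end) {
  int step = (total + thread_num - 1) / thread_num;  /* UP_DIV(total, thread_num) */
  *start = task_id * step;
  int stop = *start + step;
  *end = stop < total ? stop : total;                /* MSMIN(start + step, total) */
}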
int PReluCPUKernel::ReSize() {
if (prelu_param_->channelShared) {
return RET_OK;
}
auto input_tensor = in_tensors_.at(0);
auto in_shape = input_tensor->shape();
auto n_dim = in_shape.size();
@@ -81,46 +75,19 @@ int PReluCPUKernel::ReSize() {
input_plane *= in_shape.at(i);
}
prelu_param_->input_num_ = input_plane;
prelu_param_->tile_block_ = UP_DIV(UP_DIV(input_plane, TILE_NUM), op_parameter_->thread_num_) * TILE_NUM;
prelu_param_->input_num_ = input_plane * channel_num;
prelu_param_->tile_block_ = input_plane;
prelu_param_->channel_num_ = channel_num;
return RET_OK;
}
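
For example, an NHWC input of shape (1, 4, 4, 3) leaves input_plane = 1 * 4 * 4 = 16 after the loop above, so ReSize now records channel_num_ = 3, tile_block_ = 16 rows, and input_num_ = 48 elements (assuming, as the unchanged lines of ReSize do, that the last dimension is the channel count).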
int PReluCPUKernel::ProcessShareChannelInput() {
auto input_tensor = in_tensors_.at(0);
prelu_param_->input_num_ = input_tensor->ElementsNum();
int tile = 32;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
tile = 64;
#endif
prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile);
input_data_ =
reinterpret_cast<float *>(context_->allocator->Malloc(prelu_param_->tile_block_ * tile * sizeof(float)));
if (input_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input_data_ failed.";
return RET_ERROR;
}
memcpy(input_data_, ori_input_, prelu_param_->input_num_ * sizeof(float));
return RET_OK;
}
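
For scale, the removed path with input_num_ = 100 on ARM64 or AVX uses tile = 64, so tile_block_ = UP_DIV(100, 64) = 2 and a 2 * 64 = 128-float scratch buffer is allocated with only the first 100 floats copied in; the padding lets the fixed 32/64-element blocks in the old PReluShareChannel run without reading past the real data.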
int PReluCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() >= 2);
auto input_tensor = in_tensors_[0];
ori_input_ = reinterpret_cast<float *>(input_tensor->data_c());
input_data_ = reinterpret_cast<float *>(input_tensor->data_c());
output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c());
MS_ASSERT(ori_input_);
MS_ASSERT(input_data_);
MS_ASSERT(output_data_);
if (prelu_param_->channelShared) {
auto ret = ProcessShareChannelInput();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ProcessShareChannel failed.";
return ret;
}
} else {
input_data_ = ori_input_;
}
// negative slope tensor
auto negative_slope_tensor = in_tensors_.at(1);
@@ -129,14 +96,9 @@ int PReluCPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, PReluRun, this, prelu_param_->op_parameter_.thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "PRelu Run error: error_code[" << ret << "]";
context_->allocator->Free(input_data_);
return RET_ERROR;
}
if (prelu_param_->channelShared) {
memcpy(output_data_, input_data_, prelu_param_->input_num_ * sizeof(float));
context_->allocator->Free(input_data_);
}
return RET_OK;
}


@@ -35,11 +35,9 @@ class PReluCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int DoExcute(int task_id);
int ProcessShareChannelInput();
private:
PReluParameter *prelu_param_;
float *ori_input_ = nullptr;
float *input_data_ = nullptr;
float *output_data_ = nullptr;
};