!13581 [ms][lite][cpu] master softmax exp erf fp16/32 operator optimization

From: @lzkcode
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
mindspore-ci-bot 2021-03-22 09:08:04 +08:00 committed by Gitee
commit fbdd876396
13 changed files with 265 additions and 44 deletions

View File

@@ -118,3 +118,10 @@ int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size)
}
return NNACL_OK;
}
int ElementErfFp16(float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = erff(input[i]);
}
return NNACL_OK;
}
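
A note on the new fp16 erf kernel: ElementErfFp16 promotes each half-precision element to float and evaluates erff, so accuracy is limited by the fp16 round trip rather than by the erf evaluation. A minimal standalone sketch of the same element-wise pattern (plain float standing in for float16_t, buffers are illustrative):

#include <math.h>
#include <stdio.h>

/* Element-wise erf over a small buffer; plain float stands in for float16_t. */
void element_erf(const float *input, float *output, int element_size) {
  for (int i = 0; i < element_size; i++) {
    output[i] = erff(input[i]);
  }
}

int main(void) {
  float in[4] = {-2.0f, -0.5f, 0.5f, 2.0f};
  float out[4];
  element_erf(in, out, 4);
  for (int i = 0; i < 4; i++) {
    printf("erf(%.1f) = %.4f\n", in[i], out[i]); /* erf(2) is about 0.9953 */
  }
  return 0;
}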

View File

@@ -50,6 +50,8 @@ int ElementCeilFp16(float16_t *input, float16_t *output, int number);
int ElementNegativeFp16(float16_t *input, float16_t *output, int element_size);
int ElementReciprocalFp16(float16_t *input, float16_t *output, int element_size);
int ElementErfFp16(float16_t *input, float16_t *output, int element_size);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/exp_fp16.h"
#include <math.h>
#include <string.h>
#include "nnacl/errorcode.h"
void ExpFp16(const float16_t *src, float16_t *dst, int num) {
int i = 0;
#ifdef ENABLE_ARM64
int count = (num / C8NUM) * C8NUM;
for (; i < count; i += C8NUM) {
simd_exp_fp16(vld1q_f16(src + i), dst + i);
}
#endif
for (; i < num; ++i) {
single_exp_fp16(src[i], dst + i);
}
}
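
ExpFp16 handles eight half floats per iteration on the ARM64 path and leaves the remainder (and all elements on non-ARM64 builds) to the scalar helper single_exp_fp16, so num does not need to be a multiple of C8NUM. A hedged sketch of the same body-plus-tail split in portable C (the vector width and names are illustrative, not the NNACL API):

#include <math.h>

#define VEC_WIDTH 8 /* stands in for C8NUM */

/* Illustrative "vector" body: a fixed-width inner loop over one block. */
static void exp_block(const float *src, float *dst) {
  for (int lane = 0; lane < VEC_WIDTH; ++lane) {
    dst[lane] = expf(src[lane]);
  }
}

void exp_buffer(const float *src, float *dst, int num) {
  int i = 0;
  int count = (num / VEC_WIDTH) * VEC_WIDTH; /* largest multiple of 8 not above num */
  for (; i < count; i += VEC_WIDTH) {
    exp_block(src + i, dst + i); /* vectorised body */
  }
  for (; i < num; ++i) {
    dst[i] = expf(src[i]); /* scalar tail */
  }
}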

View File

@@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_EXP_H_
#define MINDSPORE_LITE_NNACL_FP16_EXP_H_
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
void ExpFp16(const float16_t *src, float16_t *dst, int num);
#if defined(ENABLE_ARM64)
static inline float32x4_t exp_fp32(float32x4_t input) {
static float32x4_t param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
{1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
{1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24},
{1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6},
{0.5f, 0.5f, 0.5f, 0.5f},
{1.0f, 1.0f, 1.0f, 1.0f}};
int32x4_t integer = vcvtq_s32_f32(input / param[0]);
float32x4_t decimal = input - vcvtq_f32_s32(integer) * param[0];
int32x4_t int_exp = vshlq_s32((integer + vmovq_n_s32(127)), vmovq_n_s32(23));
float32x4_t decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
decimal_exp = decimal_exp * vld1q_f32((float *)(&int_exp));
return decimal_exp;
}
static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) {
static float16x8_t maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f};
static float16x8_t minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f};
input = vmaxq_f16(minv, vminq_f16(input, maxv));
float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input));
float32x4_t input_high = vcvt_high_f32_f16(input);
vst1q_f16(dst, vcombine_f16(vcvt_f16_f32(exp_fp32(input_low)), vcvt_f16_f32(exp_fp32(input_high))));
}
#endif
static inline void single_exp_fp16(float16_t src, float16_t *dst) {
static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
src = MSMAX(-88.0f, MSMIN(88.0f, src));
int integer = (float)src / param[0];
float decimal = (float)src - integer * param[0];
int int_exp = (integer + 127) << 23;
float decimal_exp =
1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
*dst = (float16_t)(*((float *)&int_exp) * decimal_exp);
}
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_EXP_H_
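
Both the vector path and the scalar fallback above use the same range reduction: clamp x to [-88, 88], split x = k*ln2 + r with k truncated toward zero (so |r| < ln2), and compute e^x = 2^k * e^r, where 2^k is built by writing (k + 127) into the IEEE-754 single-precision exponent field and e^r comes from the degree-5 Taylor polynomial 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120 evaluated in Horner form (param[1..5] hold 1/120, 1/24, 1/6, 1/2, 1). A minimal scalar sketch of that recipe in plain float (memcpy used for the bit reinterpretation), compared against libm:

#include <math.h>
#include <stdio.h>
#include <string.h>

/* Scalar reference for the approximation used above. */
float exp_approx(float x) {
  const float ln2 = 0.693147f;
  x = fmaxf(-88.0f, fminf(88.0f, x));
  int k = (int)(x / ln2);               /* truncated quotient */
  float r = x - (float)k * ln2;         /* reduced argument, |r| < ln2 */
  int bits = (k + 127) << 23;           /* exponent bits of 2^k */
  float pow2k;
  memcpy(&pow2k, &bits, sizeof(pow2k)); /* reinterpret the bits as float */
  float er = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6 + r * (1.0f / 24 + r * (1.0f / 120)))));
  return pow2k * er;
}

int main(void) {
  for (float x = -3.0f; x <= 3.0f; x += 1.5f) {
    printf("x = %5.2f approx = %.6f libm = %.6f\n", x, exp_approx(x), expf(x));
  }
  return 0;
}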

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,79 @@
#include "nnacl/fp16/softmax_fp16.h"
#include <math.h>
#include <float.h>
#include "nnacl/fp16/exp_fp16.h"
void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
int cur_batch_offset = 0;
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
int j = 0;
#ifdef ENABLE_ARM64
float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX);
int count = (channel / C8NUM) * C8NUM;
for (; j < count; j += C8NUM) {
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j);
max_8 = vmaxq_f16(max_8, input_8);
}
float16_t max = vmaxvq_f16(max_8);
#else
float16_t max = -FLT_MAX;
#endif
for (; j < channel; j++) {
float16_t input = src[cur_batch_offset + j];
if (input > max) {
max = input;
}
}
int k = 0;
#ifdef ENABLE_NEON
int count2 = (channel / C8NUM) * C8NUM;
for (; k < count2; k += C8NUM) {
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k);
float16x8_t output_8 = vsubq_f16(input_8, vdupq_n_f16(max));
vst1q_f16(dst + cur_batch_offset + k, output_8);
}
#endif
for (; k < channel; k++) {
int offset = cur_batch_offset + k;
dst[offset] = src[offset] - max;
}
}
}
void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
int cur_batch_offset = 0;
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
float16_t sum = 0;
int j = 0;
#ifdef ENABLE_NEON
float16x8_t sum8 = vdupq_n_f16(0);
int count = (channel / C8NUM) * C8NUM;
for (; j < count; j += C8NUM) {
sum8 = vaddq_f16(sum8, vld1q_f16(src + cur_batch_offset + j));
}
sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7];
#endif
for (; j < channel; j++) {
sum += src[cur_batch_offset + j];
}
int k = 0;
#ifdef ENABLE_NEON
const float16_t div = 1.0f / sum;
for (; k < count; k += C8NUM) {
vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div));
}
#endif
for (; k < channel; k++) {
dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum;
}
}
}
void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
SoftmaxNormFp16(src, dst, batch, channel);
ExpFp16(dst, dst, batch * channel);
SumAndDivFp16(dst, dst, batch, channel);
}
// output = exp(input) / reduce_sum(exp(input), axis)
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) {
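
SoftmaxLastAxisFp16 is the new fast path for softmax over the innermost axis: one pass subtracts the per-row maximum for numerical stability, one pass exponentiates in place, and one pass divides by the per-row sum. A scalar reference of the same three-pass scheme (plain float, no NEON), useful as a correctness baseline for the fp16 kernels:

#include <math.h>

/* Reference last-axis softmax: dst[b][c] = exp(src[b][c] - max_b) / sum_c exp(src[b][c] - max_b). */
void softmax_last_axis_ref(const float *src, float *dst, int batch, int channel) {
  for (int b = 0; b < batch; ++b) {
    const float *in = src + b * channel;
    float *out = dst + b * channel;
    float max = in[0];
    for (int c = 1; c < channel; ++c) {
      if (in[c] > max) max = in[c];
    }
    float sum = 0.0f;
    for (int c = 0; c < channel; ++c) {
      out[c] = expf(in[c] - max); /* stabilised exponent */
      sum += out[c];
    }
    for (int c = 0; c < channel; ++c) {
      out[c] /= sum;
    }
  }
}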

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
extern "C" {
#endif
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter);
void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel);
#ifdef __cplusplus
}
#endif

View File

@@ -53,11 +53,10 @@ static inline void simd_exp(MS_FLOAT32X4 input, float *dst) {
MS_INT32X4 integer = MS_CVTQPS_EPI32(input / param[0]);
MS_FLOAT32X4 decimal = input - MS_CVTQEPI32_PS(integer) * param[0];
MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23);
memcpy(dst, &int_exp, sizeof(int32_t) * 4);
MS_FLOAT32X4 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32(dst));
MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32((float *)(&int_exp)));
}
#endif
@@ -76,11 +75,10 @@ static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) {
MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]);
MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0];
MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23);
memcpy(dst, &int_exp, sizeof(int32_t) * 8);
MS_FLOAT32X8 decimal_exp =
param[5] +
decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
MS_ST256_F32(dst, decimal_exp * MS_LD256_F32(dst));
MS_ST256_F32(dst, decimal_exp * MS_LD256_F32((float *)(&int_exp)));
}
#endif
@@ -90,10 +88,10 @@ static inline void single_exp(float src, float *dst) {
int integer = src / param[0];
float decimal = src - integer * param[0];
int int_exp = (integer + 127) << 23;
memcpy(dst, &int_exp, sizeof(float));
const float decimal_exp =
float decimal_exp =
1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
*dst *= decimal_exp;
float *ptr = (float *)&int_exp;
*dst = *ptr * decimal_exp;
}
#ifdef __cplusplus
}
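
The fp32 exp helpers above drop the memcpy staging through dst and read the freshly built exponent bits straight from int_exp through a float pointer, avoiding one store/load round trip per call. The cast is plain type punning of an int as a float; where strict-aliasing conformance matters, a four-byte memcpy into a local expresses the same reinterpretation and typically compiles to a single register move. A hedged scalar sketch of both variants:

#include <string.h>

/* Build the float 2^k from its exponent bits, (k + 127) << 23. */
float pow2i_cast(int k) {
  int bits = (k + 127) << 23;
  return *(float *)&bits;           /* pointer cast, as in the patch */
}

float pow2i_memcpy(int k) {
  int bits = (k + 127) << 23;
  float out;
  memcpy(&out, &bits, sizeof(out)); /* strictly conforming alternative */
  return out;
}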

View File

@@ -22,14 +22,21 @@ void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
int cur_batch_offset = 0;
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
int j = 0;
#ifdef ENABLE_ARM64
#ifdef ENABLE_NEON
float32x4_t max4 = vdupq_n_f32(-FLT_MAX);
int count = (channel / C4NUM) * C4NUM;
for (; j < count; j += C4NUM) {
float32x4_t input4 = vld1q_f32(src + cur_batch_offset + j);
max4 = vmaxq_f32(max4, input4);
}
#ifdef ENABLE_ARM64
float max = vmaxvq_f32(max4);
#else
float max = max4[0];
for (int m = 1; m < 4; ++m) {
max = MSMAX(max, max4[m]);
}
#endif
#else
float max = -FLT_MAX;
#endif
@@ -66,7 +73,11 @@ void SumAndDiv(const float *src, float *dst, int batch, int channel) {
for (; j < count; j += C4NUM) {
sum4 = vaddq_f32(sum4, vld1q_f32(src + cur_batch_offset + j));
}
#ifdef ENABLE_ARM64
sum = vaddvq_f32(sum4);
#else
sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
#endif
#endif
for (; j < channel; j++) {
sum += src[cur_batch_offset + j];
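
vmaxvq_f32 and vaddvq_f32 are AArch64-only horizontal reductions, which is why the fp32 softmax now keeps the vector loops behind ENABLE_NEON but guards the single-instruction reductions with ENABLE_ARM64 and falls back to a lane-wise fold on 32-bit NEON builds. The fallback is just a scalar reduction over four lanes, as in this plain-C stand-in (an array stands in for float32x4_t):

/* Lane-wise fallbacks for the AArch64 horizontal max/add reductions. */
float hmax4(const float lanes[4]) {
  float max = lanes[0];
  for (int i = 1; i < 4; ++i) {
    if (lanes[i] > max) max = lanes[i];
  }
  return max;
}

float hadd4(const float lanes[4]) {
  return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}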

View File

@@ -61,6 +61,7 @@
#define kInputSize2 3
#define MAX_AXIS_SIZE 6
#define MAX_LEN 256
#define FLT16_MAX 65504
typedef enum LiteDataType {
kDataTypeFloat,
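
FLT16_MAX is defined as 65504, the largest finite IEEE-754 binary16 value, (2 - 2^-10) * 2^15; SoftmaxNormFp16 seeds its running maximum with -FLT16_MAX on the ARM64 path. A quick arithmetic check:

#include <stdio.h>

int main(void) {
  /* Largest finite binary16 value: (2 - 2^-10) * 2^15 = 65504. */
  double flt16_max = (2.0 - 1.0 / 1024.0) * 32768.0;
  printf("%.1f\n", flt16_max); /* prints 65504.0 */
  return 0;
}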

View File

@@ -43,7 +43,8 @@ ArithmeticSelfFp16Func ArithmeticSelfFp16CPUKernel::GetArithmeticSelfFp16Fun(int
{mindspore::schema::PrimitiveType_Ceil, ElementCeilFp16},
{mindspore::schema::PrimitiveType_Round, ElementRoundFp16},
{mindspore::schema::PrimitiveType_Neg, ElementNegativeFp16},
{mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16}};
{mindspore::schema::PrimitiveType_Reciprocal, ElementReciprocalFp16},
{mindspore::schema::PrimitiveType_Erf, ElementErfFp16}};
for (size_t i = 0; i < sizeof(type_func_table) / sizeof(TYPE_FUNC_INFO); i++) {
if (type_func_table[i].primitive_type_ == primitive_type) {
return type_func_table[i].func_;
@@ -98,4 +99,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Ceil, LiteKernelCreator<Arith
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Round, LiteKernelCreator<ArithmeticSelfFp16CPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Neg, LiteKernelCreator<ArithmeticSelfFp16CPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reciprocal, LiteKernelCreator<ArithmeticSelfFp16CPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Erf, LiteKernelCreator<ArithmeticSelfFp16CPUKernel>)
} // namespace mindspore::kernel

View File

@@ -43,59 +43,74 @@ int SoftmaxFp16CPUKernel::Init() {
return ReSize();
}
int SoftmaxFp16CPUKernel::ReSize() { return SoftmaxBaseCPUKernel::ReSize(); }
int SoftmaxFp16CPUKernel::MallocTmpBuffer() {
int SoftmaxFp16CPUKernel::ReSize() {
auto ret = SoftmaxBaseCPUKernel::ReSize();
if (ret != RET_OK) {
return ret;
}
auto n_dim = softmax_param_->n_dim_;
auto axis = softmax_param_->axis_;
if (axis == -1) {
softmax_param_->axis_ += n_dim;
axis = softmax_param_->axis_;
}
auto in_shape = in_tensors_.front()->shape();
int out_plane_size = 1;
out_plane_size_ = 1;
for (int i = 0; i < axis; ++i) {
out_plane_size *= in_shape[i];
out_plane_size_ *= in_shape[i];
}
int in_plane_size = 1;
in_plane_size_ = 1;
for (int i = axis + 1; i < n_dim; i++) {
in_plane_size *= in_shape[i];
in_plane_size_ *= in_shape[i];
}
sum_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(out_plane_size * in_plane_size * sizeof(float16_t)));
if (sum_data_ != nullptr) {
free(sum_data_);
}
sum_data_ = reinterpret_cast<float16_t *>(malloc(out_plane_size_ * in_plane_size_ * sizeof(float16_t)));
if (sum_data_ == nullptr) {
MS_LOG(ERROR) << "malloc data for softmax fail!";
return RET_ERROR;
}
memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float16_t));
return RET_OK;
}
void SoftmaxFp16CPUKernel::FreeTmpBuffer() {
if (sum_data_ != nullptr) {
context_->allocator->Free(sum_data_);
sum_data_ = nullptr;
int SoftmaxFp16CPUKernel::DoSoftmaxLastAxis(int task_id) {
int unit = UP_DIV(out_plane_size_, context_->thread_num_);
int begin = task_id * unit;
int end = MSMIN(begin + unit, out_plane_size_);
int channel = softmax_param_->input_shape_[softmax_param_->axis_];
int offset = begin * channel;
auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_.at(kInputIndex)->MutableData());
auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_.at(kOutputIndex)->MutableData());
SoftmaxLastAxisFp16(input_ptr + offset, output_ptr + offset, end - begin, channel);
return RET_OK;
}
int SoftmaxLastAxisFp16Run(void *cdata, int task_id) {
auto kernel = reinterpret_cast<SoftmaxFp16CPUKernel *>(cdata);
auto ret = kernel->DoSoftmaxLastAxis(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoSoftmaxLastAxisFp16 error task_id: " << task_id << ", ret: " << ret;
}
return ret;
}
int SoftmaxFp16CPUKernel::Run() {
auto ret = MallocTmpBuffer();
if (ret != RET_OK) {
FreeTmpBuffer();
MS_LOG(ERROR) << "MallocTmpBuffer failed";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(0);
MS_ASSERT(input_tensor);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(output_tensor);
input_fp16_ = reinterpret_cast<float16_t *>(input_tensor->data_c());
MS_ASSERT(input_fp16_);
output_fp16_ = reinterpret_cast<float16_t *>(output_tensor->data_c());
SoftmaxFp16(input_fp16_, output_fp16_, sum_data_, softmax_param_);
FreeTmpBuffer();
MS_ASSERT(output_fp16_);
if (in_plane_size_ == 1) {
auto ret = ParallelLaunch(this->context_->thread_pool_, SoftmaxLastAxisFp16Run, this, context_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "SoftmaxFp16CPUKernel ParallelLaunch failed, ret: " << ret;
}
return ret;
} else {
MS_ASSERT(sum_data_);
memset(sum_data_, 0, out_plane_size_ * in_plane_size_ * sizeof(float16_t));
SoftmaxFp16(input_fp16_, output_fp16_, sum_data_, softmax_param_);
}
return RET_OK;
}
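
When in_plane_size_ == 1 (softmax over the innermost axis), Run launches SoftmaxLastAxisFp16Run across the thread pool: each task takes UP_DIV(out_plane_size_, thread_num) rows, clamps its range to [begin, end), and offsets the flat input/output buffers by begin * channel. A small sketch of that partition arithmetic (values are illustrative; UP_DIV is ceiling division as in NNACL):

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* ceiling division, as in NNACL */
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int out_plane_size = 10, channel = 16, thread_num = 4;
  int unit = UP_DIV(out_plane_size, thread_num); /* rows per task */
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int begin = task_id * unit;
    int end = MSMIN(begin + unit, out_plane_size);
    if (begin >= end) continue;                  /* trailing tasks may be empty */
    printf("task %d: rows [%d, %d), offset %d, batch %d\n",
           task_id, begin, end, begin * channel, end - begin);
  }
  return 0;
}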

View File

@@ -28,18 +28,25 @@ class SoftmaxFp16CPUKernel : public SoftmaxBaseCPUKernel {
SoftmaxFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx), sum_data_(nullptr) {}
~SoftmaxFp16CPUKernel() = default;
~SoftmaxFp16CPUKernel() override {
if (sum_data_ != nullptr) {
free(sum_data_);
}
}
int Init() override;
int ReSize() override;
int Run() override;
int MallocTmpBuffer();
void FreeTmpBuffer();
int DoSoftmaxLastAxis(int task_id);
private:
float16_t *sum_data_ = nullptr;
float16_t *input_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
int in_plane_size_ = 0;
int out_plane_size_ = 0;
};
} // namespace mindspore::kernel

View File

@@ -62,3 +62,5 @@ ml_video_edit_oneclick_adaptis.pb 3
# Q_hand_0812.pb is not suitable for float16. Out of float16 range.
Q_hand_0812.pb
tacotron_encoder_stf.pb 5;1:1,62:1,62:1,62:1,62
Q_inception-249970-672-11-16.pb 1
Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb 1