!8559 [MSLITE][Develop] optimize_softmax

From: @sunsuodong
Reviewed-by: @zhanghaibo5
Signed-off-by:
mindspore-ci-bot 2020-11-18 11:22:05 +08:00 committed by Gitee
commit 37a722fe77
7 changed files with 203 additions and 4 deletions


@@ -16,6 +16,7 @@
#include "nnacl/fp32/exp_fp32.h"
#include <math.h>
#include <string.h>
#include "nnacl/errorcode.h"
int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id) {
@@ -35,3 +36,40 @@ int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id) {
  }
  return NNACL_OK;
}

void ExpFp32(const float *src, float *dst, int num) {
  int i = 0;
  const float param[] = {log(2.0f), 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
#ifdef ENABLE_ARM64
  float32x4_t maxv = vdupq_n_f32(88.0f);
  float32x4_t minv = vdupq_n_f32(-88.0f);
  float32x4_t param0 = vdupq_n_f32(log(2.0f));
  float32x4_t param1 = vdupq_n_f32(1.0f / 120);
  float32x4_t param2 = vdupq_n_f32(1.0f / 24);
  float32x4_t param3 = vdupq_n_f32(1.0f / 6);
  float32x4_t param4 = vdupq_n_f32(0.5f);
  float32x4_t param5 = vdupq_n_f32(1.0f);
  for (; i < num - C4NUM; i += C4NUM) {
    // Clamp to [-88, 88] so the exponent bit trick below cannot overflow.
    float32x4_t input4 = vmaxq_f32(minv, vminq_f32(maxv, vld1q_f32(src + i)));
    // Range reduction: x = n * ln2 + r, so exp(x) = 2^n * exp(r).
    int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, param0));
    float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), param0));
    // Build 2^n directly as a float bit pattern: exponent field = n + 127.
    int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23));
    vst1q_f32(dst + i, vld1q_f32((float32_t *)(&int_exp4)));
    // Degree-5 Taylor polynomial for exp(r), evaluated in Horner form.
    float32x4_t decimal_exp4 = vaddq_f32(param2, vmulq_f32(decimal4, param1));
    decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(param3, vmulq_f32(decimal4, decimal_exp4)));
    decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, vaddq_f32(param4, decimal_exp4)));
    decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, decimal_exp4));
    vst1q_f32(dst + i, vmulq_f32(vld1q_f32(dst + i), decimal_exp4));
  }
#endif
  for (; i < num; ++i) {
    float input = MSMAX(-88.0f, MSMIN(88.0f, src[i]));
    int integer = input / param[0];
    float decimal = input - integer * param[0];
    int int_exp = (integer + 127) << 23;  // bit pattern of 2^integer
    memcpy(dst + i, &int_exp, sizeof(float));
    float decimal_exp =
      1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
    dst[i] *= decimal_exp;
  }
}
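
ExpFp32 above approximates exp(x) in a few steps: clamp to [-88, 88], split x = n * ln2 + r so that exp(x) = 2^n * exp(r), build 2^n directly as the float bit pattern (n + 127) << 23, and evaluate a degree-5 Taylor polynomial for exp(r). A minimal standalone C++ sketch of the same scalar math, checked against libm expf (the ApproxExp name and the test harness are illustrative, not part of the patch):

#include <cmath>
#include <cstdio>
#include <cstring>

// Illustrative re-statement of the scalar path of ExpFp32.
static float ApproxExp(float x) {
  const float ln2 = 0.6931472f;
  x = fmaxf(-88.0f, fminf(88.0f, x));  // same clamp as MSMAX/MSMIN above
  int n = static_cast<int>(x / ln2);   // x = n * ln2 + r
  float r = x - n * ln2;
  int bits = (n + 127) << 23;          // float bit pattern of 2^n
  float pow2n;
  memcpy(&pow2n, &bits, sizeof(float));
  // Degree-5 Taylor polynomial for exp(r), |r| < ln2.
  float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6 + r * (1.0f / 24 + r * (1.0f / 120)))));
  return pow2n * p;
}

int main() {
  for (float x = -5.0f; x <= 5.0f; x += 2.5f) {
    printf("x=%5.1f  approx=%.6f  expf=%.6f\n", x, ApproxExp(x), expf(x));
  }
  return 0;
}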


@@ -34,6 +34,7 @@ typedef struct ExpParameter {
extern "C" {
#endif
int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id);
void ExpFp32(const float *src, float *dst, int num);
#ifdef __cplusplus
}
#endif


@@ -13,9 +13,79 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/softmax_fp32.h"
#include <math.h>
#include "nnacl/fp32/exp_fp32.h"

void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
  int cur_batch_offset = 0;
  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
    // Pass 1: find the per-row maximum.
    int j = 0;
    float max = src[cur_batch_offset];
#ifdef ENABLE_ARM64
    if (channel >= C4NUM) {  // take the 4-wide path only when a full vector fits, so short rows never read past the end
      float32x4_t max4 = vld1q_f32(src + cur_batch_offset);
      j += C4NUM;
      for (; j < channel - C4NUM; j += C4NUM) {
        float32x4_t input4 = vld1q_f32(src + cur_batch_offset + j);
        max4 = vmaxq_f32(max4, input4);
      }
      max = vmaxvq_f32(max4);  // horizontal max across the four lanes
    }
#endif
    for (; j < channel; j++) {
      float input = src[cur_batch_offset + j];
      if (input > max) {
        max = input;
      }
    }
    // Pass 2: subtract the maximum so the later exp never sees a large argument.
    int k = 0;
#ifdef ENABLE_NEON
    for (; k < channel - C4NUM; k += C4NUM) {
      float32x4_t input4 = vld1q_f32(src + cur_batch_offset + k);
      float32x4_t output4 = vsubq_f32(input4, vdupq_n_f32(max));
      vst1q_f32(dst + cur_batch_offset + k, output4);
    }
#endif
    for (; k < channel; k++) {
      int offset = cur_batch_offset + k;
      dst[offset] = src[offset] - max;
    }
  }
}
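
SoftmaxNorm applies the standard max-subtraction trick: softmax is shift-invariant, and subtracting the per-row maximum keeps every argument to exp at or below zero, so large logits cannot overflow. A small standalone sketch of the failure this pass prevents (input values are illustrative):

#include <cmath>
#include <cstdio>

int main() {
  const float x[3] = {1000.0f, 1001.0f, 1002.0f};
  // Naive: expf(1000.0f) overflows to inf, so every ratio becomes inf/inf = nan.
  printf("expf(%.0f) = %f\n", x[0], expf(x[0]));
  // Stable: subtract the row max first; the arguments become -2, -1, 0.
  float max = x[0];
  for (int i = 1; i < 3; ++i) max = fmaxf(max, x[i]);
  float e[3], sum = 0.0f;
  for (int i = 0; i < 3; ++i) sum += (e[i] = expf(x[i] - max));
  for (int i = 0; i < 3; ++i) printf("softmax[%d] = %f\n", i, e[i] / sum);  // 0.090031 0.244728 0.665241
  return 0;
}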

void SumAndDiv(const float *src, float *dst, int batch, int channel) {
  int cur_batch_offset = 0;
  for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
    float sum = 0;
    int j = 0;
#ifdef ENABLE_NEON
    float32x4_t sum4 = vdupq_n_f32(0);
    for (; j < channel - C4NUM; j += C4NUM) {
      sum4 = vaddq_f32(sum4, vld1q_f32(src + cur_batch_offset + j));
    }
    sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
#endif
    for (; j < channel; j++) {
      sum += src[cur_batch_offset + j];
    }
    int k = 0;
#ifdef ENABLE_NEON
    float div = 1.0f / sum;
    for (; k < channel - C4NUM; k += C4NUM) {
      vst1q_f32(dst + cur_batch_offset + k, vmulq_n_f32(vld1q_f32(src + cur_batch_offset + k), div));
    }
#endif
    for (; k < channel; k++) {
      dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum;
    }
  }
}
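
Note the shared loop bound j < channel - C4NUM: the vector loops deliberately stop early and leave between 1 and C4NUM trailing elements to the scalar loop that follows, even when channel is an exact multiple of four. A tiny sketch of that split, under assumed channel values:

#include <cstdio>

int main() {
  const int C4NUM = 4;
  for (int channel = 4; channel <= 9; ++channel) {
    int j = 0;
    while (j < channel - C4NUM) j += C4NUM;  // elements [0, j) go through NEON
    printf("channel=%d  vectorized=%d  scalar tail=%d\n", channel, j, channel - j);
  }
  return 0;
}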

void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) {
  SoftmaxNorm(src, dst, batch, channel);
  ExpFp32(dst, dst, batch * channel);
  SumAndDiv(dst, dst, batch, channel);
}

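SoftmaxLastAxis chains the three passes over a [batch, channel] view of the tensor. A plain scalar reference with the same contract (SoftmaxLastAxisRef is a hypothetical name; such an oracle is handy when testing the vectorized path):

#include <cmath>
#include <cstdio>

// Scalar reference for SoftmaxLastAxis: same three passes, no SIMD.
void SoftmaxLastAxisRef(const float *src, float *dst, int batch, int channel) {
  for (int b = 0; b < batch; ++b) {
    const float *in = src + b * channel;
    float *out = dst + b * channel;
    float max = in[0];
    for (int c = 1; c < channel; ++c) max = fmaxf(max, in[c]);              // pass 1: row max
    float sum = 0.0f;
    for (int c = 0; c < channel; ++c) sum += (out[c] = expf(in[c] - max));  // pass 2: exp
    for (int c = 0; c < channel; ++c) out[c] /= sum;                        // pass 3: normalize
  }
}

int main() {
  const float in[6] = {1.0f, 2.0f, 3.0f, 1.0f, 1.0f, 1.0f};
  float out[6];
  SoftmaxLastAxisRef(in, out, 2, 3);
  for (int i = 0; i < 6; ++i) printf("%f ", out[i]);  // row 0 sums to 1, row 1 is 1/3 each
  printf("\n");
  return 0;
}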
// output = exp(input) / reduce_sum(exp(input), axis)
void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) {


@@ -23,6 +23,7 @@
extern "C" {
#endif
void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter);
void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel);
#ifdef __cplusplus
}
#endif


@@ -70,12 +70,41 @@ int SoftmaxCPUKernel::ReSize() {
  return RET_OK;
}

int SoftmaxCPUKernel::Run() {
  memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float));

int SoftmaxCPUKernel::DoSoftmaxLastAxis(int task_id) {
  int unit = UP_DIV(out_plane_size_, context_->thread_num_);
  int begin = task_id * unit;
  int end = MSMIN(begin + unit, out_plane_size_);
  int channel = softmax_param_->input_shape_[softmax_param_->axis_];
  int offset = begin * channel;
  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
  Softmax(input_ptr, output_ptr, sum_data_, softmax_param_);
  SoftmaxLastAxis(input_ptr + offset, output_ptr + offset, end - begin, channel);
  return RET_OK;
}

int SoftmaxLastAxisRun(void *cdata, int task_id) {
  auto kernel = reinterpret_cast<SoftmaxCPUKernel *>(cdata);
  auto ret = kernel->DoSoftmaxLastAxis(task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoSoftmaxLastAxis error task_id: " << task_id << ", ret: " << ret;
  }
  return ret;
}

int SoftmaxCPUKernel::Run() {
  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
  int ret = RET_OK;
  if (in_plane_size_ == 1) {
    ret = ParallelLaunch(this->context_->thread_pool_, SoftmaxLastAxisRun, this, context_->thread_num_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "SoftmaxCPUKernel ParallelLaunch failed, ret: " << ret;
    }
  } else {
    memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float));
    Softmax(input_ptr, output_ptr, sum_data_, softmax_param_);
  }
  return ret;
}
} // namespace mindspore::kernel
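
DoSoftmaxLastAxis above splits the outer rows evenly across ParallelLaunch tasks: each task takes UP_DIV(out_plane_size_, thread_num_) consecutive rows, and MSMIN clips the final task. The partition arithmetic in isolation (row and thread counts are assumed for illustration):

#include <algorithm>
#include <cstdio>

int main() {
  const int rows = 10, threads = 4;
  const int unit = (rows + threads - 1) / threads;  // UP_DIV(10, 4) == 3
  for (int task_id = 0; task_id < threads; ++task_id) {
    int begin = task_id * unit;
    int end = std::min(begin + unit, rows);  // MSMIN clips the final task
    if (begin < end) printf("task %d handles rows [%d, %d)\n", task_id, begin, end);
  }
  return 0;
}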


@@ -37,6 +37,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoSoftmaxLastAxis(int task_id);

 private:
  float *sum_data_ = nullptr;


@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "nnacl/softmax_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace mindspore {
class TestSoftmaxFp32 : public mindspore::CommonTest {
 public:
  TestSoftmaxFp32() {}
};

TEST_F(TestSoftmaxFp32, 001) {
  lite::Tensor in_tensor(kNumberTypeFloat32, {2, 1, 1, 5});
  lite::Tensor out_tensor(kNumberTypeFloat32, {2, 1, 1, 5});
  float input_data[] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
  float output_data[10] = {0};
  in_tensor.set_data(input_data);
  out_tensor.set_data(output_data);
  std::vector<lite::Tensor *> inputs = {&in_tensor};
  std::vector<lite::Tensor *> outputs = {&out_tensor};

  SoftmaxParameter parameter = {{}, -1, 10, 4, {2, 1, 1, 5}};
  kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SoftMax};

  auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
  ASSERT_NE(creator, nullptr);

  auto ctx = std::make_shared<lite::InnerContext>();
  ASSERT_EQ(lite::RET_OK, ctx->Init());
  auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
  ASSERT_NE(kernel, nullptr);

  auto ret = kernel->Run();
  EXPECT_EQ(0, ret);

  float expect[] = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f};
  for (size_t i = 0; i < sizeof(expect) / sizeof(expect[0]); ++i) {
    EXPECT_EQ(output_data[i], expect[i]);
  }

  in_tensor.set_data(nullptr);
  out_tensor.set_data(nullptr);
}
} // namespace mindspore
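
Exact EXPECT_EQ comparisons are safe in this particular test only because every input is equal: after max subtraction each argument is 0.0f, ExpFp32 returns exactly 1.0f there (integer part 0 yields the bit pattern (0 + 127) << 23 and the polynomial collapses to 1), and 1.0f / 5.0f rounds to the same float as the literal 0.2f. Non-uniform inputs would call for a tolerance such as EXPECT_NEAR, since ExpFp32 is only a polynomial approximation of expf. A one-line check of the rounding claim:

#include <cstdio>

int main() {
  // 0.2 is not exactly representable, but 1.0f / 5.0f and the literal 0.2f
  // both round to the same nearest float, so the equality below holds.
  printf("%s\n", (1.0f / 5.0f) == 0.2f ? "equal" : "differ");  // prints "equal"
  return 0;
}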