forked from mindspore-Ecosystem/mindspore
!8559 [MSLITE][Develop] optimize_softmax
From: @sunsuodong Reviewed-by: @zhanghaibo5 Signed-off-by:
This commit is contained in:
commit
37a722fe77
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id) {
|
||||
|
@ -35,3 +36,40 @@ int Exp(const float *input_data, float *output_data, ExpParameter *parameter, in
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void ExpFp32(const float *src, float *dst, int num) {
|
||||
int i = 0;
|
||||
const float param[] = {log(2.0f), 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
|
||||
#ifdef ENABLE_ARM64
|
||||
float32x4_t maxv = vdupq_n_f32(88.0f);
|
||||
float32x4_t minv = vdupq_n_f32(-88.0f);
|
||||
float32x4_t param0 = vdupq_n_f32(log(2.0f));
|
||||
float32x4_t param1 = vdupq_n_f32(1.0f / 120);
|
||||
float32x4_t param2 = vdupq_n_f32(1.0f / 24);
|
||||
float32x4_t param3 = vdupq_n_f32(1.0f / 6);
|
||||
float32x4_t param4 = vdupq_n_f32(0.5f);
|
||||
float32x4_t param5 = vdupq_n_f32(1.0f);
|
||||
for (; i < num - C4NUM; i += C4NUM) {
|
||||
float32x4_t input4 = vmaxq_f32(minv, vminq_f32(maxv, vld1q_f32(src + i)));
|
||||
int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, param0));
|
||||
float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), param0));
|
||||
int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23));
|
||||
vst1q_f32(dst + i, vld1q_f32((float32_t *)(&int_exp4)));
|
||||
float32x4_t decimal_exp4 = vaddq_f32(param2, vmulq_f32(decimal4, param1));
|
||||
decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(param3, vmulq_f32(decimal4, decimal_exp4)));
|
||||
decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, vaddq_f32(param4, decimal_exp4)));
|
||||
decimal_exp4 = vaddq_f32(param5, vmulq_f32(decimal4, decimal_exp4));
|
||||
vst1q_f32(dst + i, vmulq_f32(vld1q_f32(dst + i), decimal_exp4));
|
||||
}
|
||||
#endif
|
||||
for (; i < num; ++i) {
|
||||
float input = MSMAX(-88.0f, MSMIN(88.0f, src[i]));
|
||||
int integer = input / param[0];
|
||||
float decimal = input - integer * param[0];
|
||||
int int_exp = (integer + 127) << 23;
|
||||
memcpy(dst + i, &int_exp, sizeof(float));
|
||||
float decimal_exp =
|
||||
1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
|
||||
dst[i] *= decimal_exp;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ typedef struct ExpParameter {
|
|||
extern "C" {
|
||||
#endif
|
||||
int Exp(const float *input_data, float *output_data, ExpParameter *parameter, int task_id);
|
||||
void ExpFp32(const float *src, float *dst, int num);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -13,9 +13,79 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/softmax_fp32.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
|
||||
void SoftmaxNorm(const float *src, float *dst, int batch, int channel) {
|
||||
int cur_batch_offset = 0;
|
||||
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
|
||||
int j = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
float32x4_t max4 = vld1q_f32(src + cur_batch_offset);
|
||||
j += C4NUM;
|
||||
for (; j < channel - C4NUM; j += C4NUM) {
|
||||
float32x4_t input4 = vld1q_f32(src + cur_batch_offset + j);
|
||||
max4 = vmaxq_f32(max4, input4);
|
||||
}
|
||||
float max = channel >= C4NUM ? vmaxvq_f32(max4) : src[cur_batch_offset];
|
||||
#else
|
||||
float max = src[cur_batch_offset];
|
||||
#endif
|
||||
for (; j < channel; j++) {
|
||||
float input = src[cur_batch_offset + j];
|
||||
if (input > max) {
|
||||
max = input;
|
||||
}
|
||||
}
|
||||
int k = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
for (; k < channel - C4NUM; k += C4NUM) {
|
||||
float32x4_t input4 = vld1q_f32(src + cur_batch_offset + k);
|
||||
float32x4_t output4 = vsubq_f32(input4, vdupq_n_f32(max));
|
||||
vst1q_f32(dst + cur_batch_offset + k, output4);
|
||||
}
|
||||
#endif
|
||||
for (; k < channel; k++) {
|
||||
int offset = cur_batch_offset + k;
|
||||
dst[offset] = src[offset] - max;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SumAndDiv(const float *src, float *dst, int batch, int channel) {
|
||||
int cur_batch_offset = 0;
|
||||
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
|
||||
float sum = 0;
|
||||
int j = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t sum4 = vdupq_n_f32(0);
|
||||
for (; j < channel - C4NUM; j += C4NUM) {
|
||||
sum4 = vaddq_f32(sum4, vld1q_f32(src + cur_batch_offset + j));
|
||||
}
|
||||
sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
|
||||
#endif
|
||||
for (; j < channel; j++) {
|
||||
sum += src[cur_batch_offset + j];
|
||||
}
|
||||
int k = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
float div = 1.0f / sum;
|
||||
for (; k < channel - C4NUM; k += C4NUM) {
|
||||
vst1q_f32(dst + cur_batch_offset + k, vmulq_n_f32(vld1q_f32(src + cur_batch_offset + k), div));
|
||||
}
|
||||
#endif
|
||||
for (; k < channel; k++) {
|
||||
dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) {
|
||||
SoftmaxNorm(src, dst, batch, channel);
|
||||
ExpFp32(dst, dst, batch * channel);
|
||||
SumAndDiv(dst, dst, batch, channel);
|
||||
}
|
||||
|
||||
// output = exp(input) / reduce_sum(exp(input), axis)
|
||||
void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter) {
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, SoftmaxParameter *parameter);
|
||||
void SoftmaxLastAxis(const float *src, float *dst, int batch, int channel);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -70,12 +70,41 @@ int SoftmaxCPUKernel::ReSize() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int SoftmaxCPUKernel::Run() {
|
||||
memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float));
|
||||
int SoftmaxCPUKernel::DoSoftmaxLastAxis(int task_id) {
|
||||
int unit = UP_DIV(out_plane_size_, context_->thread_num_);
|
||||
int begin = task_id * unit;
|
||||
int end = MSMIN(begin + unit, out_plane_size_);
|
||||
int channel = softmax_param_->input_shape_[softmax_param_->axis_];
|
||||
int offset = begin * channel;
|
||||
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
|
||||
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
|
||||
Softmax(input_ptr, output_ptr, sum_data_, softmax_param_);
|
||||
SoftmaxLastAxis(input_ptr + offset, output_ptr + offset, end - begin, channel);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SoftmaxLastAxisRun(void *cdata, int task_id) {
|
||||
auto kernel = reinterpret_cast<SoftmaxCPUKernel *>(cdata);
|
||||
auto ret = kernel->DoSoftmaxLastAxis(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoSoftmaxLastAxis error task_id: " << task_id << ", ret: " << ret;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
int SoftmaxCPUKernel::Run() {
|
||||
auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->MutableData());
|
||||
auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
|
||||
int ret = RET_OK;
|
||||
if (in_plane_size_ == 1) {
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, SoftmaxLastAxisRun, this, context_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SoftmaxCPUKernel ParallelLaunch failed, ret: " << ret;
|
||||
}
|
||||
} else {
|
||||
memset(sum_data_, 0, in_plane_size_ * out_plane_size_ * sizeof(float));
|
||||
Softmax(input_ptr, output_ptr, sum_data_, softmax_param_);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -37,6 +37,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
|
|||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoSoftmaxLastAxis(int task_id);
|
||||
|
||||
private:
|
||||
float *sum_data_ = nullptr;
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#include "mindspore/lite/src/kernel_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
class TestSoftmaxFp32 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestSoftmaxFp32() {}
|
||||
};
|
||||
|
||||
TEST_F(TestSoftmaxFp32, 001) {
|
||||
lite::Tensor in_tensor(kNumberTypeFloat32, {2, 1, 1, 5});
|
||||
lite::Tensor out_tensor(kNumberTypeFloat32, {2, 1, 1, 5});
|
||||
float input_data[] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
float output_data[10] = {0};
|
||||
in_tensor.set_data(input_data);
|
||||
out_tensor.set_data(output_data);
|
||||
std::vector<lite::Tensor *> inputs = {&in_tensor};
|
||||
std::vector<lite::Tensor *> outputs = {&out_tensor};
|
||||
|
||||
SoftmaxParameter parameter = {{}, -1, 10, 4, {2, 1, 1, 5}};
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_SoftMax};
|
||||
|
||||
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator, nullptr);
|
||||
|
||||
auto ctx = std::make_shared<lite::InnerContext>();
|
||||
ASSERT_EQ(lite::RET_OK, ctx->Init());
|
||||
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(¶meter), ctx.get(), desc, nullptr);
|
||||
ASSERT_NE(kernel, nullptr);
|
||||
|
||||
auto ret = kernel->Run();
|
||||
EXPECT_EQ(0, ret);
|
||||
|
||||
float expect[] = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f};
|
||||
for (size_t i = 0; i < sizeof(expect) / sizeof(expect[0]); ++i) {
|
||||
EXPECT_EQ(output_data[i], expect[i]);
|
||||
}
|
||||
in_tensor.set_data(nullptr);
|
||||
out_tensor.set_data(nullptr);
|
||||
}
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue