forked from mindspore-Ecosystem/mindspore
softmax activation fp16
This commit is contained in:
parent
e3899c552c
commit
b6b18e477a
mindspore/lite
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
int i;
|
||||
for (i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu_src = vld1q_f16(src + index);
|
||||
float16x8_t zero_src = vdupq_n_f16(0);
|
||||
relu_src = vmaxq_f16(relu_src, zero_src);
|
||||
vst1q_f16(dst + index, relu_src);
|
||||
#else
|
||||
int j;
|
||||
for (j = 0; j < C8NUM; j++) {
|
||||
dst[index + j] = src[index + j] < 0 ? 0 : src[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
dst[j] = src[j] < 0 ? 0 : src[j];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
int i;
|
||||
for (i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
int j;
|
||||
for (j = 0; j < C8NUM; ++j) {
|
||||
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
dst[j] = data[j] < 0 ? 0 : data[j];
|
||||
dst[j] = dst[j] > 6 ? 6 : dst[j];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i]));
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = (float16_t)1.0f - (float16_t)2.0f / (float16_t)(exp(2 * src[i]) + 1);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
float16_t in = src[i];
|
||||
float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
|
||||
dst[i] = in * relu6 / (float16_t)6.0f;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
typedef struct ActivationParameter {
|
||||
OpParameter op_parameter_;
|
||||
int type_;
|
||||
float alpha_;
|
||||
} ActivationParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num);
|
||||
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha);
|
||||
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/fp16/common_func.h"
|
||||
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
relu_data = vmaxq_f16(relu_data, zero_data);
|
||||
vst1q_f16(dst + index, relu_data);
|
||||
#else
|
||||
data[index] = data[index] < 0 ? 0 : data[index];
|
||||
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
|
||||
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
|
||||
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
}
|
||||
}
|
||||
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
for (int j = 0; j < C8NUM; ++j) {
|
||||
data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
data[j] = data[j] > 6 ? 6 : data[j];
|
||||
}
|
||||
}
|
|
@ -41,8 +41,6 @@ void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *w
|
|||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
|
||||
#endif
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp16/softmax_fp16.h"
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
|
||||
// output = exp(input) / reduce_sum(exp(input), axis)
|
||||
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) {
|
||||
int32_t axis = parameter->axis_;
|
||||
int n_dim = parameter->n_dim_;
|
||||
int ele_size = parameter->element_size_;
|
||||
int *input_shape = parameter->input_shape_;
|
||||
|
||||
float16_t max_data = input_ptr[0];
|
||||
for (int i = 0; i < ele_size; i++) {
|
||||
max_data = max_data > input_ptr[i] ? max_data : input_ptr[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < ele_size; i++) {
|
||||
output_ptr[i] = exp(input_ptr[i] - max_data);
|
||||
}
|
||||
int inner_size = 1, outter_size = 1;
|
||||
for (int i = 0; i < axis; i++) {
|
||||
outter_size *= input_shape[i];
|
||||
}
|
||||
for (int i = axis + 1; i < n_dim; i++) {
|
||||
inner_size *= input_shape[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < outter_size; i++) {
|
||||
int outter_offset = i * input_shape[axis] * inner_size;
|
||||
int sum_outter_offset = i * inner_size;
|
||||
for (int k = 0; k < inner_size; k++) {
|
||||
int inner_offset = outter_offset + k;
|
||||
for (int j = 0; j < input_shape[axis]; j++) {
|
||||
int axis_offset = inner_offset + j * inner_size;
|
||||
sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < outter_size; i++) {
|
||||
int outter_offset = i * input_shape[axis] * inner_size;
|
||||
int sum_outter_offset = i * inner_size;
|
||||
for (int j = 0; j < input_shape[axis]; j++) {
|
||||
int axis_offset = outter_offset + j * inner_size;
|
||||
for (int k = 0; k < inner_size; k++) {
|
||||
int inner_offset = axis_offset + k;
|
||||
output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/activation_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::ActivationType_HSWISH;
|
||||
using mindspore::schema::ActivationType_LEAKY_RELU;
|
||||
using mindspore::schema::ActivationType_RELU;
|
||||
using mindspore::schema::ActivationType_RELU6;
|
||||
using mindspore::schema::PrimitiveType_Activation;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int ActivationFp16CPUKernel::Init() { return RET_OK; }
|
||||
|
||||
int ActivationFp16CPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int ActivationFp16CPUKernel::MallocTmpBuffer() {
|
||||
fp16_input_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
|
||||
if (fp16_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
fp16_output_ = MallocOutputFp16(out_tensors_.at(0), context_);
|
||||
if (fp16_output_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ActivationFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
|
||||
if (fp16_input_ != nullptr) {
|
||||
context_->allocator->Free(fp16_input_);
|
||||
fp16_input_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
|
||||
if (fp16_output_ != nullptr) {
|
||||
context_->allocator->Free(fp16_output_);
|
||||
fp16_output_ = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int ActivationFp16CPUKernel::DoActivation(int task_id) {
|
||||
auto length = in_tensors_.at(0)->ElementsNum();
|
||||
|
||||
int stride = UP_DIV(length, thread_count_);
|
||||
int count = MSMIN(stride, length - stride * task_id);
|
||||
|
||||
int error_code;
|
||||
if (type_ == schema::ActivationType_RELU) {
|
||||
error_code = ReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
||||
} else if (type_ == schema::ActivationType_RELU6) {
|
||||
error_code = Relu6Fp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
||||
} else if (type_ == schema::ActivationType_LEAKY_RELU) {
|
||||
error_code = LReluFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count, alpha_);
|
||||
} else if (type_ == schema::ActivationType_SIGMOID) {
|
||||
error_code = SigmoidFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
||||
} else if (type_ == schema::ActivationType_TANH) {
|
||||
error_code = TanhFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
||||
} else if (type_ == schema::ActivationType_HSWISH) {
|
||||
error_code = HSwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return error_code;
|
||||
}
|
||||
|
||||
int ActivationRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto activation_kernel = reinterpret_cast<ActivationFp16CPUKernel *>(cdata);
|
||||
auto error_code = activation_kernel->DoActivation(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "ActivationRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ActivationFp16CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeTmpBuffer();
|
||||
return ret;
|
||||
}
|
||||
|
||||
int error_code = LiteBackendParallelLaunch(ActivationRun, this, thread_count_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Activation function error error_code[" << error_code << "]";
|
||||
FreeTmpBuffer();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto out_tensor = out_tensors_.at(0);
|
||||
if (out_tensor->data_type() == kNumberTypeFloat32) {
|
||||
Float16ToFloat32(fp16_output_, reinterpret_cast<float *>(out_tensor->Data()), out_tensor->ElementsNum());
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuActivationFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_Activation);
|
||||
auto *kernel = new (std::nothrow) ActivationFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Activation, CpuActivationFp16KernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ActivationFp16CPUKernel : public LiteKernel {
|
||||
public:
|
||||
ActivationFp16CPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
|
||||
type_ = (reinterpret_cast<ActivationParameter *>(param))->type_;
|
||||
alpha_ = (float16_t)((reinterpret_cast<ActivationParameter *>(param))->alpha_);
|
||||
}
|
||||
~ActivationFp16CPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoActivation(int task_id);
|
||||
int MallocTmpBuffer();
|
||||
void FreeTmpBuffer();
|
||||
|
||||
private:
|
||||
int thread_count_;
|
||||
int type_;
|
||||
float16_t alpha_;
|
||||
float16_t *fp16_input_ = nullptr;
|
||||
float16_t *fp16_output_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ACTIVATION_FP16_H_
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/arm/fp16/softmax_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
|
||||
#include "nnacl/fp16/softmax_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_SoftMax;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int SoftmaxFp16CPUKernel::Init() {
|
||||
auto ret = SoftmaxBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int SoftmaxFp16CPUKernel::ReSize() {
|
||||
return SoftmaxBaseCPUKernel::ReSize();
|
||||
}
|
||||
|
||||
int SoftmaxFp16CPUKernel::MallocTmpBuffer() {
|
||||
auto n_dim = softmax_param_->n_dim_;
|
||||
auto axis = softmax_param_->axis_;
|
||||
if (axis == -1) {
|
||||
softmax_param_->axis_ += n_dim;
|
||||
axis = softmax_param_->axis_;
|
||||
}
|
||||
auto in_shape = in_tensors_.front()->shape();
|
||||
int out_plane_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
out_plane_size *= in_shape[i];
|
||||
}
|
||||
int in_plane_size = 1;
|
||||
for (int i = axis + 1; i < n_dim; i++) {
|
||||
in_plane_size *= in_shape[i];
|
||||
}
|
||||
|
||||
sum_data_ =
|
||||
reinterpret_cast<float16_t *>(context_->allocator->Malloc(out_plane_size * in_plane_size * sizeof(float16_t)));
|
||||
if (sum_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data for softmax fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float16_t));
|
||||
|
||||
input_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(kInputIndex), context_);
|
||||
if (input_fp16_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_fp16_ = MallocOutputFp16(out_tensors_.at(kOutputIndex), context_);
|
||||
if (output_fp16_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc data failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void SoftmaxFp16CPUKernel::FreeTmpBuffer() {
|
||||
if (sum_data_ != nullptr) {
|
||||
context_->allocator->Free(sum_data_);
|
||||
sum_data_ = nullptr;
|
||||
}
|
||||
if (in_tensors_.at(kInputIndex)->data_type() == kNumberTypeFloat32) {
|
||||
if (input_fp16_ != nullptr) {
|
||||
context_->allocator->Free(input_fp16_);
|
||||
input_fp16_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (out_tensors_.at(kOutputIndex)->data_type() == kNumberTypeFloat32) {
|
||||
if (output_fp16_ != nullptr) {
|
||||
context_->allocator->Free(output_fp16_);
|
||||
output_fp16_ = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int SoftmaxFp16CPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
FreeTmpBuffer();
|
||||
MS_LOG(ERROR) << "MallocTmpBuffer failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
SoftmaxFp16(input_fp16_, output_fp16_, sum_data_, softmax_param_);
|
||||
auto out_tensor = out_tensors_.at(kOutputIndex);
|
||||
if (out_tensor->data_type() == kNumberTypeFloat32) {
|
||||
Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(out_tensor->Data()), out_tensor->ElementsNum());
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuSoftmaxFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Input opParameter is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_SoftMax);
|
||||
auto *kernel = new (std::nothrow) SoftmaxFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new SoftmaxFp16CPUKernel fail!";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
delete kernel;
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SoftMax, CpuSoftmaxFp16KernelCreator)
|
||||
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/softmax_base.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class SoftmaxFp16CPUKernel : public SoftmaxBaseCPUKernel {
|
||||
public:
|
||||
SoftmaxFp16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {}
|
||||
~SoftmaxFp16CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int MallocTmpBuffer();
|
||||
void FreeTmpBuffer();
|
||||
|
||||
private:
|
||||
float16_t *sum_data_ = nullptr;
|
||||
float16_t *input_fp16_ = nullptr;
|
||||
float16_t *output_fp16_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SOFTMAX_FP16_H_
|
Loading…
Reference in New Issue