!12186 Added Fp16 relu, sigmoid and log grad compute units

From: @louisncu
Reviewed-by: @zhang_xue_tong
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-07 17:19:28 +08:00 committed by Gitee
commit 452b393edb
12 changed files with 471 additions and 1 deletion

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/fp16_grad/activation_grad.h"
#include "nnacl/errorcode.h"
int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  // A 128-bit NEON register holds 8 fp16 lanes, so step by 8; keep the bound in
  // signed arithmetic so it stays well-defined when length < 8.
  float16x8_t zero_8 = vdupq_n_f16(0);
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    uint16x8_t mask_8 = vcgtq_f16(src1_8, zero_8);
    float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8);
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // Scalar tail: dx = dy where the activation output is positive, 0 otherwise.
  for (; i < (int)length; i++) {
    dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
  }
  return NNACL_OK;
}

int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  float16x8_t one_8 = vdupq_n_f16(1);
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8)));
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // Scalar tail: dx = dy * y * (1 - y), with y the forward sigmoid output.
  for (; i < (int)length; i++) {
    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
  }
  return NNACL_OK;
}

int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  float16x8_t log_10 = vdupq_n_f16(logf(10.0f));
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    // vrecpeq_f16 returns a reciprocal estimate rather than an exact reciprocal.
    float16x8_t dst_8 = vmulq_f16(src0_8, vrecpeq_f16(vmulq_f16(src1_8, log_10)));
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // Scalar tail: dx = dy / (x * ln(10)).
  for (; i < (int)length; i++) {
    dst[i] = src0[i] / (src1[i] * logf(10.0f));
  }
  return NNACL_OK;
}
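A minimal usage sketch for these compute units (not part of this change, values are illustrative): it assumes an aarch64 toolchain built with -march=armv8.2-a+fp16 and -DENABLE_NEON so arm_neon.h supplies float16_t, and the nnacl headers on the include path. As in the kernel below, src0 is the incoming gradient and src1 is the forward activation output.

/* Illustrative caller for Fp16ReluGrad; not part of the patch. */
#include <cstdio>
#include "nnacl/fp16_grad/activation_grad.h"

int main() {
  const float16_t dy[4] = {1.0f, 1.0f, 1.0f, 1.0f};  // incoming gradient
  const float16_t y[4] = {0.0f, 0.25f, 0.0f, 2.0f};  // forward ReLU output
  float16_t dx[4] = {0};
  Fp16ReluGrad(dy, y, 4, dx);  // dx[i] = dy[i] where y[i] > 0, else 0
  for (int i = 0; i < 4; ++i) {
    std::printf("dx[%d] = %f\n", i, static_cast<float>(dx[i]));
  }
  return 0;
}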

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "mindspore/lite/nnacl/int8/fixed_point.h"
typedef struct ActivationGradParameterFp16 {
OpParameter op_parameter;
int type_;
float alpha_;
} ActivationGradParameterFp16;
#ifdef __cplusplus
extern "C" {
#endif
int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_

View File

@ -17,6 +17,11 @@ list(APPEND SDOT_FILES ${SDOT_SRC})
list(APPEND FP16_FILES ${FP16_C_SRC})
list(APPEND FP16_FILES ${FP16_NEON_SRC})
if(SUPPORT_TRAIN)
file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c)
list(APPEND FP16_FILES ${FP16_TRAIN_SRC})
endif()
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")

View File

@ -99,7 +99,8 @@ enum ActivationGradType : byte {
HSIGMOID = 13,
THRESHOLDRELU = 14,
LINEAR = 15,
UNKNOWN = 16,
LOG = 17
}
enum ReduceType : byte {
REDUCE_MAX = 0,

View File

@ -9,6 +9,7 @@ file(GLOB KERNEL_SRC
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)
if(SUPPORT_TRAIN)
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
file(GLOB TRAIN_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc)
set(KERNEL_SRC ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
endif()
@ -19,6 +20,9 @@ add_dependencies(cpu_kernel_mid fbs_src)
if(PLATFORM_ARM64)
if(ENABLE_FP16)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
if(SUPPORT_TRAIN)
file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
list(APPEND FP16_KERNEL_SRC ${FP16_KERNEL_TRAIN_SRC})
endif()
add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC})
add_dependencies(cpu_fp16_kernel_mid fbs_src)
endif()

View File

@ -0,0 +1,93 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h"
#include "nnacl/fp16_grad/activation_grad.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_SIGMOID;
using mindspore::schema::PrimitiveType_ActivationGrad;
namespace mindspore::kernel {
int ActivationGradCPUKernelFp16::Init() {
if (in_tensors_.size() != 2) {
MS_LOG(ERROR) << "ActivationGrad should have 2 input tensors";
return RET_ERROR;
}
return RET_OK;
}
int ActivationGradCPUKernelFp16::ReSize() { return RET_OK; }
int ActivationGradCPUKernelFp16::DoActivation(int task_id) {
auto yt_addr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData());
auto input_addr = reinterpret_cast<float16_t *>(in_tensors_.at(1)->MutableData());
auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
int length = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}
int start = stride * task_id;
auto error_code = RET_OK;
if (param_act_grad_->type_ == schema::ActivationGradType_RELU) {
error_code = Fp16ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationGradType_SIGMOID) {
// Sigmoid gets the input tensors in reverse order!
error_code = Fp16SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
} else if (param_act_grad_->type_ == schema::ActivationGradType_LOG) {
error_code = Fp16LogGrad(yt_addr + start, input_addr + start, count, output_addr + start);
} else {
MS_LOG(ERROR) << "Activation type error";
return RET_ERROR;
}
if (error_code != RET_OK) {
return RET_ERROR;
}
return RET_OK;
}
int ActivationGradRunFp16(void *cdata, int task_id) {
MS_ASSERT(cdata != nullptr);
auto activationGrad_kernel = reinterpret_cast<ActivationGradCPUKernelFp16 *>(cdata);
auto error_code = activationGrad_kernel->DoActivation(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "ActivationGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int ActivationGradCPUKernelFp16::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRunFp16, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ActivationGrad, LiteKernelCreator<ActivationGradCPUKernelFp16>)
} // namespace mindspore::kernel
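For reference, DoActivation splits the tensor evenly across threads with the UP_DIV/MSMIN arithmetic above; the standalone sketch below reproduces that partitioning with plain expressions (50 elements and 4 threads are illustrative values, not taken from the patch).

#include <algorithm>
#include <cstdio>

int main() {
  const int length = 50;       // total element count (illustrative)
  const int thread_count = 4;  // thread_count_ in the kernel (illustrative)
  const int stride = (length + thread_count - 1) / thread_count;  // UP_DIV(length, thread_count)
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    const int start = stride * task_id;
    const int count = std::min(stride, length - start);  // MSMIN(stride, length - stride * task_id)
    if (count <= 0) {
      continue;  // trailing tasks may receive no work when length is small
    }
    std::printf("task %d: start=%d count=%d\n", task_id, start, count);  // prints 13, 13, 13, 11 elements
  }
  return 0;
}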

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ACTIVATION_FP16_GRAD_H
#define MINDSPORE_ACTIVATION_FP16_GRAD_H
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp16_grad/activation_grad.h"
namespace mindspore::kernel {
class ActivationGradCPUKernelFp16 : public LiteKernel {
public:
explicit ActivationGradCPUKernelFp16(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
param_act_grad_ = reinterpret_cast<ActivationGradParameterFp16 *>(param);
}
~ActivationGradCPUKernelFp16() override = default;
int Init() override;
int ReSize() override;
int Run() override;
int DoActivation(int task_id);
private:
ActivationGradParameterFp16 *param_act_grad_;
int thread_count_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_ACTIVATION_FP16_GRAD_H

View File

@ -323,6 +323,12 @@ if(ENABLE_FP16)
)
endif()
if(SUPPORT_TRAIN)
file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
endif()
add_executable(lite-test ${TEST_SRC})
add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl mindspore::gtest)

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <vector>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "nnacl/fp16_grad/activation_grad.h"
namespace mindspore {
class TestActGradFp16 : public mindspore::CommonTest {
public:
TestActGradFp16() {}
float error_bound = 1e-3;
};
TEST_F(TestActGradFp16, ReluGradFp16) {
size_t output_data_size = 50;
size_t input_size;
std::string input_path = "./test_data/activationGrad/relu_y_50.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
ASSERT_NE(input_data, nullptr);
EXPECT_EQ(input_size, output_data_size * sizeof(float));
std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin";
auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
ASSERT_NE(yt_data, nullptr);
EXPECT_EQ(input_size, output_data_size * sizeof(float));
std::string output_path = "./test_data/activationGrad/relu_out_50.bin";
auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
ASSERT_NE(ref_data, nullptr);
EXPECT_EQ(input_size, output_data_size * sizeof(float));
auto yt_buf = new float16_t[output_data_size];
auto input_buf = new float16_t[output_data_size];
auto output_buf = new float16_t[output_data_size];
std::cout << "======yt_buf======" << std::endl;
for (int i = 0; i < output_data_size; i++) {
yt_buf[i] = (float16_t)yt_data[i];
input_buf[i] = (float16_t)input_data[i];
}
Fp16ReluGrad(yt_buf, input_buf, 50, output_buf);
int res = 0;
float error = 0;
std::cout << "======Compare with reference data======" << std::endl;
for (int i = 0; i < output_data_size; i++) {
float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
if (diff > 0.00001) {
error += diff;
}
}
error /= static_cast<float>(output_data_size);
if (error > error_bound) {
printf("error=%f while error_bound=%f\n", error, error_bound);
res = 1;
}
EXPECT_EQ(res, 0);
delete[] output_buf;
delete[] yt_buf;
delete[] input_buf;
delete[] ref_data;
delete[] yt_data;
delete[] input_data;
MS_LOG(INFO) << "ReluGradFp16 passed";
}
TEST_F(TestActGradFp16, SigmoidGradFp16) {
size_t output_data_size = 50;
size_t input_size;
std::string input_path = "./test_data/activationGrad/sigmoid_y_50.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
ASSERT_NE(input_data, nullptr);
std::string yt_path = "./test_data/activationGrad/sigmoid_yt_50.bin";
auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
ASSERT_NE(yt_data, nullptr);
std::string output_path = "./test_data/activationGrad/sigmoid_out_50.bin";
auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
ASSERT_NE(ref_data, nullptr);
EXPECT_EQ(input_size, output_data_size * sizeof(float));
auto yt_buf = new float16_t[output_data_size];
auto input_buf = new float16_t[output_data_size];
auto output_buf = new float16_t[output_data_size];
std::cout << "======yt_buf======" << std::endl;
for (int i = 0; i < output_data_size; i++) {
yt_buf[i] = (float16_t)yt_data[i];
input_buf[i] = (float16_t)input_data[i];
}
Fp16SigmoidGrad(yt_buf, input_buf, 50, output_buf);
int res = 0;
float error = 0;
std::cout << "======Compare with reference data======" << std::endl;
for (int i = 0; i < output_data_size; i++) {
float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
if (diff > 0.00001) {
error += diff;
}
}
error /= static_cast<float>(output_data_size);
if (error > error_bound) {
printf("error=%f while error_bound=%f\n", error, error_bound);
res = 1;
}
EXPECT_EQ(res, 0);
delete[] output_buf;
delete[] yt_buf;
delete[] input_buf;
delete[] ref_data;
delete[] yt_data;
delete[] input_data;
MS_LOG(INFO) << "SigmoidGradFp16 passed";
}
TEST_F(TestActGradFp16, LogGradFp16) {
size_t output_data_size = 50;
size_t input_size;
std::string input_path = "./test_data/activationGrad/log_x_50.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
ASSERT_NE(input_data, nullptr);
std::string yt_path = "./test_data/activationGrad/log_yt_50.bin";
auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
ASSERT_NE(yt_data, nullptr);
std::string output_path = "./test_data/activationGrad/log_out_50.bin";
auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
ASSERT_NE(ref_data, nullptr);
EXPECT_EQ(input_size, output_data_size * sizeof(float));
auto yt_buf = new float16_t[output_data_size];
auto input_buf = new float16_t[output_data_size];
auto output_buf = new float16_t[output_data_size];
for (int i = 0; i < output_data_size; i++) {
yt_buf[i] = (float16_t)yt_data[i];
input_buf[i] = (float16_t)input_data[i];
}
Fp16LogGrad(yt_buf, input_buf, 50, output_buf);
int res = 0;
float error = 0;
std::cout << "======Compare with reference data======" << std::endl;
for (int i = 0; i < output_data_size; i++) {
float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
if (diff > 0.00001) {
error += diff;
}
}
error /= static_cast<float>(output_data_size);
if (error > error_bound) {
printf("error%f while error_bound=%f\n", error, error_bound);
res = 1;
}
EXPECT_EQ(res, 0);
delete[] output_buf;
delete[] yt_buf;
delete[] input_buf;
delete[] ref_data;
delete[] yt_data;
delete[] input_data;
MS_LOG(INFO) << "LogGradFp16 passed";
}
} // namespace mindspore
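The three tests above repeat the same averaged-absolute-error comparison; a small helper along these lines could factor it out (a sketch only, not part of the patch; it assumes arm_neon.h is already available for float16_t, as in the test file).

#include <cmath>
#include <cstddef>

// Average absolute difference between the fp16 kernel output and the fp32 reference,
// skipping per-element differences below the same 1e-5 threshold used in the tests above.
static float AvgAbsError(const float16_t *out, const float *ref, size_t n) {
  float error = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    float diff = std::fabs(static_cast<float>(out[i]) - ref[i]);
    if (diff > 0.00001f) {
      error += diff;
    }
  }
  return error / static_cast<float>(n);
}

// Usage in a test body: EXPECT_LE(AvgAbsError(output_buf, ref_data, output_data_size), error_bound);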

View File

@ -0,0 +1 @@
(binary test-data file; contents not representable as text)