forked from mindspore-Ecosystem/mindspore
!12186 Added Fp16 relu, sigmoid and log grad compute units
From: @louisncu Reviewed-by: @zhang_xue_tong Signed-off-by:
commit 452b393edb
@ -0,0 +1,72 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/fp16_grad/activation_grad.h"
#include "nnacl/errorcode.h"

int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  // float16x8_t holds 8 fp16 lanes, so the vector loop advances by 8 and stops 8 short of the end.
  float16x8_t zero_8 = vdupq_n_f16(0);
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    uint16x8_t mask_8 = vcgtq_f16(src1_8, zero_8);
    float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8);
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // Scalar tail: dReLU/dx is 1 where the input is positive, 0 elsewhere, so dst = dy there.
  for (; i < (int)length; i++) {
    dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
  }
  return NNACL_OK;
}

int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  float16x8_t one_8 = vdupq_n_f16(1);
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8)));
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // dSigmoid/dx = y * (1 - y) where y is the forward output, so dst = dy * y * (1 - y).
  for (; i < (int)length; i++) {
    dst[i] = src0[i] * (src1[i] * (1.0f - src1[i]));
  }
  return NNACL_OK;
}

int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) {
  int i = 0;
#ifdef ENABLE_NEON
  float16x8_t log_10 = vdupq_n_f16(log(10));
  for (; i <= (int)length - 8; i += 8) {
    float16x8_t src0_8 = vld1q_f16(src0 + i);
    float16x8_t src1_8 = vld1q_f16(src1 + i);
    // vrecpeq_f16 is only a reciprocal estimate; its precision is covered by the fp16 test tolerance.
    float16x8_t dst_8 = vmulq_f16(src0_8, vrecpeq_f16(vmulq_f16(src1_8, log_10)));
    vst1q_f16(dst + i, dst_8);
  }
#endif
  // The scalar path divides by ln(10), i.e. the gradient of log10: dst = dy / (x * ln 10).
  for (; i < (int)length; i++) {
    dst[i] = src0[i] * 1.0f / (src1[i] * log(10));
  }
  return NNACL_OK;
}
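Not part of the commit: a minimal fp32 reference sketch of the same three gradient formulas (ReLU, Sigmoid, log10), handy for eyeballing the fp16 kernels above. The RefReluGrad/RefSigmoidGrad/RefLogGrad names and the sample values are illustrative only.

#include <math.h>
#include <stddef.h>
#include <stdio.h>

/* Plain-float reference versions of the three gradients (names are made up for this sketch). */
static void RefReluGrad(const float *dy, const float *x, size_t n, float *dx) {
  for (size_t i = 0; i < n; i++) {
    dx[i] = (x[i] > 0.0f) ? dy[i] : 0.0f;
  }
}

static void RefSigmoidGrad(const float *dy, const float *y, size_t n, float *dx) {
  for (size_t i = 0; i < n; i++) {
    dx[i] = dy[i] * y[i] * (1.0f - y[i]);
  }
}

static void RefLogGrad(const float *dy, const float *x, size_t n, float *dx) {
  /* Matches the scalar path above, which divides by ln(10). */
  for (size_t i = 0; i < n; i++) {
    dx[i] = dy[i] / (x[i] * logf(10.0f));
  }
}

int main(void) {
  const float dy[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  const float x[4] = {-1.0f, 0.5f, 1.0f, 2.0f};
  float dx[4];
  RefReluGrad(dy, x, 4, dx);     /* dx = {0, 0.5, 0.5, 0.5} */
  RefSigmoidGrad(dy, x, 4, dx);  /* here x stands in for the forward sigmoid output y */
  RefLogGrad(dy, x, 4, dx);
  for (int i = 0; i < 4; i++) {
    printf("%f\n", dx[i]);
  }
  return 0;
}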
@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "mindspore/lite/nnacl/int8/fixed_point.h"

typedef struct ActivationGradParameterFp16 {
  OpParameter op_parameter;
  int type_;
  float alpha_;
} ActivationGradParameterFp16;

#ifdef __cplusplus
extern "C" {
#endif

// src0 and src1 carry the incoming gradient and the forward value; see the fp16_grad
// kernel for the per-op argument ordering (the Sigmoid caller swaps them).
int Fp16ReluGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
int Fp16SigmoidGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);
int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst);

#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_LITE_NNACL_FP16_GRAD_ACTIVATION_GRAD_H_
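For orientation only, a small hedged sketch of how this parameter block is typically filled before a kernel reinterprets it; MakeActGradParamFp16 is a made-up helper and the sketch assumes the header above is on the include path.

#include <string.h>
#include "nnacl/fp16_grad/activation_grad.h"

/* Hypothetical helper (not from the commit): the fp16 grad kernel receives a generic
 * OpParameter* and reinterprets it as ActivationGradParameterFp16, then reads
 * type_ (a schema ActivationGradType value) and alpha_. */
ActivationGradParameterFp16 MakeActGradParamFp16(int act_grad_type, float alpha) {
  ActivationGradParameterFp16 p;
  memset(&p, 0, sizeof(p));  /* also zeroes the embedded OpParameter */
  p.type_ = act_grad_type;   /* e.g. the numeric value of schema::ActivationGradType_RELU */
  p.alpha_ = alpha;          /* only meaningful for parameterized activations */
  return p;
}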
@ -17,6 +17,11 @@ list(APPEND SDOT_FILES ${SDOT_SRC})
list(APPEND FP16_FILES ${FP16_C_SRC})
list(APPEND FP16_FILES ${FP16_NEON_SRC})

if(SUPPORT_TRAIN)
    file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c)
    list(APPEND FP16_FILES ${FP16_TRAIN_SRC})
endif()

string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
@ -99,7 +99,8 @@ enum ActivationGradType : byte {
  HSIGMOID = 13,
  THRESHOLDRELU = 14,
  LINEAR = 15,
  UNKNOWN = 16
  UNKNOWN = 16,
  LOG = 17
}
enum ReduceType : byte {
  REDUCE_MAX = 0,
@ -9,6 +9,7 @@ file(GLOB KERNEL_SRC
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)

if(SUPPORT_TRAIN)
    # Glob both grad kernel directories in one call; a second file(GLOB) into the
    # same variable would overwrite the fp16_grad sources instead of adding to them.
    file(GLOB TRAIN_KERNEL_SRC
        ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/fp32_grad/*.cc)
    set(KERNEL_SRC ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
endif()

@ -19,6 +20,9 @@ add_dependencies(cpu_kernel_mid fbs_src)
if(PLATFORM_ARM64)
    if(ENABLE_FP16)
        file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
        if(SUPPORT_TRAIN)
            # Append the fp16 grad kernels; re-globbing into FP16_KERNEL_SRC would drop the fp16 inference kernels.
            file(GLOB FP16_GRAD_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)
            list(APPEND FP16_KERNEL_SRC ${FP16_GRAD_KERNEL_SRC})
        endif()
        add_library(cpu_fp16_kernel_mid OBJECT ${FP16_KERNEL_SRC})
        add_dependencies(cpu_fp16_kernel_mid fbs_src)
    endif()
@ -0,0 +1,93 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16_grad/activation_fp16_grad.h"
#include "nnacl/fp16_grad/activation_grad.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_SIGMOID;
using mindspore::schema::PrimitiveType_ActivationGrad;

namespace mindspore::kernel {
int ActivationGradCPUKernelFp16::Init() {
  if (in_tensors_.size() != 2) {
    MS_LOG(ERROR) << "ActivationGrad should have 2 input tensors";
    return RET_ERROR;
  }
  return RET_OK;
}

int ActivationGradCPUKernelFp16::ReSize() { return RET_OK; }

int ActivationGradCPUKernelFp16::DoActivation(int task_id) {
  auto yt_addr = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData());
  auto input_addr = reinterpret_cast<float16_t *>(in_tensors_.at(1)->MutableData());
  auto output_addr = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
  int length = in_tensors_.at(0)->ElementsNum();

  // Split the flat tensor evenly across threads; each task handles [start, start + count).
  int stride = UP_DIV(length, thread_count_);
  int count = MSMIN(stride, length - stride * task_id);
  if (count <= 0) {
    // More tasks than slices of work; nothing to do for this task_id.
    return RET_OK;
  }
  int start = stride * task_id;

  auto error_code = RET_OK;

  if (param_act_grad_->type_ == schema::ActivationGradType_RELU) {
    error_code = Fp16ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start);
  } else if (param_act_grad_->type_ == schema::ActivationGradType_SIGMOID) {
    // Sigmoid gets the input tensors in reverse order!
    error_code = Fp16SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start);
  } else if (param_act_grad_->type_ == schema::ActivationGradType_LOG) {
    error_code = Fp16LogGrad(yt_addr + start, input_addr + start, count, output_addr + start);
  } else {
    MS_LOG(ERROR) << "Activation type error";
    return RET_ERROR;
  }
  if (error_code != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

int ActivationGradRunFp16(void *cdata, int task_id) {
  MS_ASSERT(cdata != nullptr);
  auto activationGrad_kernel = reinterpret_cast<ActivationGradCPUKernelFp16 *>(cdata);
  auto error_code = activationGrad_kernel->DoActivation(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ActivationGradRun error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

int ActivationGradCPUKernelFp16::Run() {
  int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRunFp16, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "Activation Grad function error, error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ActivationGrad, LiteKernelCreator<ActivationGradCPUKernelFp16>)
}  // namespace mindspore::kernel
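A standalone sketch (not part of the commit) of the UP_DIV/MSMIN work partitioning that DoActivation performs per task_id; UpDiv and MinInt below are local stand-ins for the nnacl UP_DIV and MSMIN macros.

#include <stdio.h>

static int UpDiv(int a, int b) { return (a + b - 1) / b; }  /* stand-in for UP_DIV */
static int MinInt(int a, int b) { return a < b ? a : b; }   /* stand-in for MSMIN */

int main(void) {
  int length = 50, thread_count = 4;
  int stride = UpDiv(length, thread_count);                 /* 13 elements per task */
  for (int task_id = 0; task_id < thread_count; task_id++) {
    int count = MinInt(stride, length - stride * task_id);  /* 13, 13, 13, then 11 for the last task */
    int start = stride * task_id;
    printf("task %d: start=%d count=%d\n", task_id, start, count);
  }
  return 0;
}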
@ -0,0 +1,46 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_ACTIVATION_FP16_GRAD_H
#define MINDSPORE_ACTIVATION_FP16_GRAD_H

#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/fp16_grad/activation_grad.h"

namespace mindspore::kernel {
class ActivationGradCPUKernelFp16 : public LiteKernel {
 public:
  explicit ActivationGradCPUKernelFp16(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                       const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
    param_act_grad_ = reinterpret_cast<ActivationGradParameterFp16 *>(param);
  }
  ~ActivationGradCPUKernelFp16() override = default;

  int Init() override;
  int ReSize() override;
  int Run() override;
  int DoActivation(int task_id);

 private:
  ActivationGradParameterFp16 *param_act_grad_;
  int thread_count_;
};
}  // namespace mindspore::kernel

#endif  // MINDSPORE_ACTIVATION_FP16_GRAD_H
@ -323,6 +323,12 @@ if(ENABLE_FP16)
    )
endif()

if(SUPPORT_TRAIN)
    file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
        ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
    list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
endif()

add_executable(lite-test ${TEST_SRC})
add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl mindspore::gtest)
@ -0,0 +1,199 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <iostream>
#include <vector>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "src/common/file_utils.h"
#include "nnacl/fp16_grad/activation_grad.h"

namespace mindspore {
class TestActGradFp16 : public mindspore::CommonTest {
 public:
  TestActGradFp16() {}
  float error_bound = 1e-3;
};

TEST_F(TestActGradFp16, ReluGradFp16) {
  size_t output_data_size = 50;
  size_t input_size;
  std::string input_path = "./test_data/activationGrad/relu_y_50.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  ASSERT_NE(input_data, nullptr);
  EXPECT_EQ(input_size, output_data_size * sizeof(float));

  std::string yt_path = "./test_data/activationGrad/relu_yt_50.bin";
  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
  ASSERT_NE(yt_data, nullptr);
  EXPECT_EQ(input_size, output_data_size * sizeof(float));

  std::string output_path = "./test_data/activationGrad/relu_out_50.bin";
  auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
  ASSERT_NE(ref_data, nullptr);
  EXPECT_EQ(input_size, output_data_size * sizeof(float));

  auto yt_buf = new float16_t[output_data_size];
  auto input_buf = new float16_t[output_data_size];
  auto output_buf = new float16_t[output_data_size];

  std::cout << "======yt_buf======" << std::endl;
  for (int i = 0; i < output_data_size; i++) {
    yt_buf[i] = (float16_t)yt_data[i];
    input_buf[i] = (float16_t)input_data[i];
  }

  Fp16ReluGrad(yt_buf, input_buf, 50, output_buf);

  int res = 0;
  float error = 0;
  std::cout << "======Compare with reference data======" << std::endl;
  for (int i = 0; i < output_data_size; i++) {
    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
    if (diff > 0.00001) {
      error += diff;
    }
  }
  error /= static_cast<float>(output_data_size);
  if (error > error_bound) {
    printf("error=%f while error_bound=%f\n", error, error_bound);
    res = 1;
  }

  EXPECT_EQ(res, 0);

  delete[] output_buf;
  delete[] yt_buf;
  delete[] input_buf;
  delete[] ref_data;
  delete[] yt_data;
  delete[] input_data;

  MS_LOG(INFO) << "ReluGradFp16 passed";
}

TEST_F(TestActGradFp16, SigmoidGradFp16) {
  size_t output_data_size = 50;
  size_t input_size;
  std::string input_path = "./test_data/activationGrad/sigmoid_y_50.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  ASSERT_NE(input_data, nullptr);

  std::string yt_path = "./test_data/activationGrad/sigmoid_yt_50.bin";
  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
  ASSERT_NE(yt_data, nullptr);

  std::string output_path = "./test_data/activationGrad/sigmoid_out_50.bin";
  auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
  ASSERT_NE(ref_data, nullptr);
  EXPECT_EQ(input_size, output_data_size * sizeof(float));

  auto yt_buf = new float16_t[output_data_size];
  auto input_buf = new float16_t[output_data_size];
  auto output_buf = new float16_t[output_data_size];

  std::cout << "======yt_buf======" << std::endl;
  for (int i = 0; i < output_data_size; i++) {
    yt_buf[i] = (float16_t)yt_data[i];
    input_buf[i] = (float16_t)input_data[i];
  }

  Fp16SigmoidGrad(yt_buf, input_buf, 50, output_buf);

  int res = 0;
  float error = 0;
  std::cout << "======Compare with reference data======" << std::endl;
  for (int i = 0; i < output_data_size; i++) {
    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
    if (diff > 0.00001) {
      error += diff;
    }
  }
  error /= static_cast<float>(output_data_size);
  if (error > error_bound) {
    printf("error=%f while error_bound=%f\n", error, error_bound);
    res = 1;
  }

  EXPECT_EQ(res, 0);

  delete[] output_buf;
  delete[] yt_buf;
  delete[] input_buf;
  delete[] ref_data;
  delete[] yt_data;
  delete[] input_data;

  MS_LOG(INFO) << "SigmoidGradFp16 passed";
}

TEST_F(TestActGradFp16, LogGradFp16) {
  size_t output_data_size = 50;
  size_t input_size;
  std::string input_path = "./test_data/activationGrad/log_x_50.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  ASSERT_NE(input_data, nullptr);

  std::string yt_path = "./test_data/activationGrad/log_yt_50.bin";
  auto yt_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(yt_path.c_str(), &input_size));
  ASSERT_NE(yt_data, nullptr);

  std::string output_path = "./test_data/activationGrad/log_out_50.bin";
  auto ref_data = reinterpret_cast<const float *>(mindspore::lite::ReadFile(output_path.c_str(), &input_size));
  ASSERT_NE(ref_data, nullptr);
  EXPECT_EQ(input_size, output_data_size * sizeof(float));

  auto yt_buf = new float16_t[output_data_size];
  auto input_buf = new float16_t[output_data_size];
  auto output_buf = new float16_t[output_data_size];

  for (int i = 0; i < output_data_size; i++) {
    yt_buf[i] = (float16_t)yt_data[i];
    input_buf[i] = (float16_t)input_data[i];
  }

  Fp16LogGrad(yt_buf, input_buf, 50, output_buf);

  int res = 0;
  float error = 0;
  std::cout << "======Compare with reference data======" << std::endl;
  for (int i = 0; i < output_data_size; i++) {
    float diff = std::fabs(static_cast<float>(output_buf[i]) - ref_data[i]);
    if (diff > 0.00001) {
      error += diff;
    }
  }
  error /= static_cast<float>(output_data_size);
  if (error > error_bound) {
    printf("error=%f while error_bound=%f\n", error, error_bound);
    res = 1;
  }

  EXPECT_EQ(res, 0);

  delete[] output_buf;
  delete[] yt_buf;
  delete[] input_buf;
  delete[] ref_data;
  delete[] yt_data;
  delete[] input_data;

  MS_LOG(INFO) << "LogGradFp16 passed";
}

}  // namespace mindspore
@ -0,0 +1 @@
Binary file not shown.
Binary file not shown.
Binary file not shown.