forked from mindspore-Ecosystem/mindspore
add dtype supports for relu/reluv2/relugradv2
This commit is contained in:
parent
09ee838320
commit
2e2c01d6f0
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
|
|
||||||
#include "runtime/device/gpu/cuda_common.h"
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
|
|
||||||
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
|
||||||
dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
|
|
||||||
CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
|
|
||||||
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
|
|
||||||
template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
|
|
||||||
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
|
|
||||||
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
|
|
|
@ -1,23 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
|
|
||||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
|
|
||||||
|
|
||||||
#include "runtime/device/gpu/cuda_common.h"
|
|
||||||
template <typename T>
|
|
||||||
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
|
|
||||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
|
|
|
@ -31,11 +31,14 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template void CalReLU(int size, double *input_addr, double *output_addr, cudaStream_t cuda_stream);
|
||||||
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
|
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
|
||||||
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
|
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
|
||||||
template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
|
template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
|
||||||
|
template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream);
|
||||||
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
|
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
|
||||||
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
|
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
|
||||||
|
template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
|
__global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
|
||||||
|
@ -69,14 +72,26 @@ void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cuda
|
||||||
ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
|
ReluGradV2Kernel<<<kBlocksPerGrid(num), kThreadsPerBlock, 0, cuda_stream>>>(num, dy, mask, dx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template void ReluV2(const size_t num, const double *x, double *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
|
template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
|
template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
|
template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
|
template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
|
template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
|
template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx, cudaStream_t cuda_stream);
|
||||||
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
|
template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
|
||||||
|
template void ReluGradV2(const size_t num, const int8_t *dy, const uint32_t *mask, int8_t *dx,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
template void ReluGradV2(const size_t num, const int16_t *dy, const uint32_t *mask, int16_t *dx,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
|
template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
|
template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
|
||||||
cudaStream_t cuda_stream);
|
cudaStream_t cuda_stream);
|
||||||
|
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
|
|
@ -46,7 +46,8 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
|
||||||
static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
|
static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
|
||||||
{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
|
{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
|
||||||
{"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
|
{"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
|
||||||
{"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}};
|
{"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8},
|
||||||
|
{"kNumberTypeUInt8", CUDNN_DATA_UINT8}};
|
||||||
// Used by mixprecision, cuda dtype select
|
// Used by mixprecision, cuda dtype select
|
||||||
static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
|
static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
|
||||||
{"kNumberTypeFloat16", CUDA_R_16F}};
|
{"kNumberTypeFloat16", CUDA_R_16F}};
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,15 +18,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ActivationGpuFwdKernel, float)
|
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
ActivationGpuFwdKernel, half)
|
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
|
||||||
ActivationGpuFwdKernel, int8_t)
|
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
|
||||||
ActivationGpuFwdKernel, int32_t)
|
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
ActivationGpuFwdKernel, float)
|
ActivationGpuFwdKernel, float)
|
||||||
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
@ -44,17 +44,12 @@ class ActivationGpuFwdKernel : public GpuKernel {
|
||||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
|
||||||
if (mode_ == CUDNN_ACTIVATION_RELU) {
|
const float alpha = 1;
|
||||||
const int size = input_size_ / sizeof(T);
|
const float beta = 0;
|
||||||
CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||||
} else {
|
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
|
||||||
const float alpha = 1;
|
&beta, data_descriptor_, output),
|
||||||
const float beta = 0;
|
"cudnnActivationForward failed");
|
||||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
|
||||||
cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
|
|
||||||
input, &beta, data_descriptor_, output),
|
|
||||||
"cudnnActivationForward failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -125,7 +120,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
|
||||||
void ResetResource() noexcept override {
|
void ResetResource() noexcept override {
|
||||||
cudnn_handle_ = nullptr;
|
cudnn_handle_ = nullptr;
|
||||||
activation_desc_ = nullptr;
|
activation_desc_ = nullptr;
|
||||||
mode_ = CUDNN_ACTIVATION_RELU;
|
mode_ = CUDNN_ACTIVATION_SIGMOID;
|
||||||
data_descriptor_ = nullptr;
|
data_descriptor_ = nullptr;
|
||||||
is_null_input_ = false;
|
is_null_input_ = false;
|
||||||
input_size_list_.clear();
|
input_size_list_.clear();
|
||||||
|
@ -154,11 +149,11 @@ class ActivationGpuFwdKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
input_size_list_.push_back(input_size_);
|
input_size_list_.push_back(input_size_);
|
||||||
output_size_list_.push_back(output_size_);
|
output_size_list_.push_back(output_size_);
|
||||||
|
workspace_size_list_.push_back(workspace_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
|
std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
|
||||||
{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
|
|
||||||
{"Tanh", CUDNN_ACTIVATION_TANH},
|
{"Tanh", CUDNN_ACTIVATION_TANH},
|
||||||
{"Elu", CUDNN_ACTIVATION_ELU},
|
{"Elu", CUDNN_ACTIVATION_ELU},
|
||||||
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
|
{"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
|
||||||
|
@ -179,4 +174,4 @@ class ActivationGpuFwdKernel : public GpuKernel {
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -18,6 +18,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGrad,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
ActivationGradGpuKernel, double)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGrad,
|
ReluGrad,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
@ -26,12 +30,21 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGrad,
|
ReluGrad,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
ActivationGradGpuKernel, half)
|
ActivationGradGpuKernel, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
|
ActivationGradGpuKernel, int64_t)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
ActivationGradGpuKernel, int32_t)
|
ActivationGradGpuKernel, int32_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||||
|
ActivationGradGpuKernel, int16_t)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||||
ActivationGradGpuKernel, int8_t)
|
ActivationGradGpuKernel, int8_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
ActivationGradGpuKernel, uint8_t)
|
||||||
|
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReLU6Grad,
|
ReLU6Grad,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
@ -23,7 +23,6 @@
|
||||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||||
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -52,18 +51,13 @@ class ActivationGradGpuKernel : public GpuKernel {
|
||||||
}
|
}
|
||||||
T *dx = GetDeviceAddress<T>(outputs, 0);
|
T *dx = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
|
||||||
if (mode_ == CUDNN_ACTIVATION_RELU) {
|
const float alpha = 1;
|
||||||
const int size = input_size_ / sizeof(T);
|
const float beta = 0;
|
||||||
CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||||
} else {
|
kernel_node_,
|
||||||
const float alpha = 1;
|
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
|
||||||
const float beta = 0;
|
data_descriptor_, y, &beta, data_descriptor_, dx),
|
||||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
"cudnnActivationBackward failed");
|
||||||
kernel_node_,
|
|
||||||
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
|
|
||||||
data_descriptor_, y, &beta, data_descriptor_, dx),
|
|
||||||
"cudnnActivationBackward failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -179,4 +173,4 @@ class ActivationGradGpuKernel : public GpuKernel {
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
ReLUGpuFwdKernel, double)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
ReLUGpuFwdKernel, float)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
ReLUGpuFwdKernel, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
|
ReLUGpuFwdKernel, int64_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
ReLUGpuFwdKernel, int32_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||||
|
ReLUGpuFwdKernel, int16_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel,
|
||||||
|
int8_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
ReLUGpuFwdKernel, uint8_t)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,98 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
template <typename T>
|
||||||
|
class ReLUGpuFwdKernel : public GpuKernel {
|
||||||
|
public:
|
||||||
|
ReLUGpuFwdKernel() { ResetResource(); }
|
||||||
|
~ReLUGpuFwdKernel() override {}
|
||||||
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||||
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||||
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||||
|
if (is_null_input_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||||
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||||
|
|
||||||
|
const int size = input_size_ / sizeof(T);
|
||||||
|
CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
|
if (input_num != 1) {
|
||||||
|
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||||
|
is_null_input_ = CHECK_NULL_INPUT(input_shape);
|
||||||
|
if (is_null_input_) {
|
||||||
|
MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
|
||||||
|
}
|
||||||
|
size_t size = 1;
|
||||||
|
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||||
|
size *= input_shape[i];
|
||||||
|
}
|
||||||
|
input_size_ = size * sizeof(T);
|
||||||
|
|
||||||
|
InitSizeLists();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResetResource() noexcept override {
|
||||||
|
is_null_input_ = false;
|
||||||
|
input_size_list_.clear();
|
||||||
|
output_size_list_.clear();
|
||||||
|
workspace_size_list_.clear();
|
||||||
|
input_size_ = 0;
|
||||||
|
workspace_size_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InitSizeLists() override {
|
||||||
|
input_size_list_.push_back(input_size_);
|
||||||
|
output_size_list_.push_back(input_size_);
|
||||||
|
workspace_size_list_.push_back(workspace_size_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_null_input_;
|
||||||
|
std::vector<size_t> input_size_list_;
|
||||||
|
std::vector<size_t> output_size_list_;
|
||||||
|
std::vector<size_t> workspace_size_list_;
|
||||||
|
size_t input_size_;
|
||||||
|
size_t workspace_size_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
|
|
@ -18,6 +18,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGradV2,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
ReluGradV2GpuKernel, double)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGradV2,
|
ReluGradV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
@ -26,6 +30,13 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGradV2,
|
ReluGradV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||||
ReluGradV2GpuKernel, half)
|
ReluGradV2GpuKernel, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGradV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8),
|
||||||
|
ReluGradV2GpuKernel, int8_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGradV2,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16),
|
||||||
|
ReluGradV2GpuKernel, int16_t)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGradV2,
|
ReluGradV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
@ -34,5 +45,10 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
ReluGradV2,
|
ReluGradV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
||||||
ReluGradV2GpuKernel, int64_t)
|
ReluGradV2GpuKernel, int64_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReluGradV2,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
ReluGradV2GpuKernel, uint8_t)
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReLUV2,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
ReluV2GpuKernel, double)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReLUV2,
|
ReLUV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
@ -26,12 +30,20 @@ MS_REG_GPU_KERNEL_ONE(
|
||||||
ReLUV2,
|
ReLUV2,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
|
||||||
ReluV2GpuKernel, half)
|
ReluV2GpuKernel, half)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
ReluV2GpuKernel, int8_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
ReluV2GpuKernel, int16_t)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||||
ReluV2GpuKernel, int32_t)
|
ReluV2GpuKernel, int32_t)
|
||||||
MS_REG_GPU_KERNEL_ONE(
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
ReLUV2,
|
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
|
|
||||||
ReluV2GpuKernel, int64_t)
|
ReluV2GpuKernel, int64_t)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(
|
||||||
|
ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
ReluV2GpuKernel, uint8_t)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -79,4 +79,4 @@ class ReluV2GpuKernel : public GpuKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_MASK_GPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
|
||||||
|
|
|
@ -1,84 +0,0 @@
|
||||||
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import mindspore.context as context
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
|
||||||
|
|
||||||
|
|
||||||
class NetReluGrad(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(NetReluGrad, self).__init__()
|
|
||||||
self.rekuGrad = G.ReluGrad()
|
|
||||||
|
|
||||||
def construct(self, x, dy):
|
|
||||||
return self.rekuGrad(dy, x)
|
|
||||||
|
|
||||||
|
|
||||||
def relu_grad_base(dtype):
|
|
||||||
x = Tensor(np.array([[[[-1, 1, 1],
|
|
||||||
[1, -1, 1],
|
|
||||||
[1, 1, -1]]]]).astype(dtype))
|
|
||||||
dy = Tensor(np.array([[[[1, 0, 1],
|
|
||||||
[0, 1, 0],
|
|
||||||
[1, 1, 1]]]]).astype(dtype))
|
|
||||||
expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.dtype)
|
|
||||||
error = np.ones(shape=[3, 3]) * 1.0e-6
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
|
||||||
relu_grad = NetReluGrad()
|
|
||||||
output = relu_grad(x, dy)
|
|
||||||
diff = output.asnumpy() - expect
|
|
||||||
assert np.all(diff < error)
|
|
||||||
assert output.asnumpy().dtype == dtype
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_relu_grad_float16():
|
|
||||||
relu_grad_base(np.float16)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_relu_grad_float32():
|
|
||||||
relu_grad_base(np.float32)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_relu_grad_int8():
|
|
||||||
relu_grad_base(np.int8)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_relu_grad_int32():
|
|
||||||
relu_grad_base(np.int32)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_x86_gpu_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_relu_grad_int64():
|
|
||||||
relu_grad_base(np.int64)
|
|
Loading…
Reference in New Issue