forked from mindspore-Ecosystem/mindspore
ReLU/ReLUGrad int8/32/64 support
parent 8ddb10fd8a
commit d982fec2ff

@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
  }
}

template <typename T>
void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
  CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
  return;
}

template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
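
The kernel above is a standard grid-stride loop, and since it only compares and copies elements it works for any ordered type; that is what lets the commit extend ReLUGrad to int8/int32/int64. Below is a minimal standalone sketch of the same pattern; the harness is mine (hypothetical names, a fixed <<<1, 128>>> launch shape standing in for the GET_BLOCKS/GET_THREADS helpers from cuda_common.h, int32_t test data), not part of the commit.

// relu_grad_sketch.cu (hypothetical): standalone check of the grid-stride
// ReLU backward pattern used in relu_grad_impl.cu above.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
__global__ void ReluGradSketch(int size, const T *dy, const T *y, T *dx) {
  // Grid-stride loop: each thread handles every (blockDim.x * gridDim.x)-th
  // element, so any problem size works with any launch shape.
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    // y is the saved forward output max(x, 0); it is positive exactly where
    // the input was, so it can gate the gradient in place of x.
    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
  }
}

int main() {
  const int n = 4;
  int32_t y[n] = {0, 5, 0, 7};   // saved forward ReLU outputs
  int32_t dy[n] = {1, 2, 3, 4};  // incoming gradient
  int32_t dx[n] = {0, 0, 0, 0};
  int32_t *d_y, *d_dy, *d_dx;
  cudaMalloc(&d_y, n * sizeof(int32_t));
  cudaMalloc(&d_dy, n * sizeof(int32_t));
  cudaMalloc(&d_dx, n * sizeof(int32_t));
  cudaMemcpy(d_y, y, n * sizeof(int32_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dy, dy, n * sizeof(int32_t), cudaMemcpyHostToDevice);
  ReluGradSketch<<<1, 128>>>(n, d_dy, d_y, d_dx);
  cudaMemcpy(dx, d_dx, n * sizeof(int32_t), cudaMemcpyDeviceToHost);
  printf("%d %d %d %d\n", dx[0], dx[1], dx[2], dx[3]);  // expect: 0 2 0 4
  cudaFree(d_y); cudaFree(d_dy); cudaFree(d_dx);
  return 0;
}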

@@ -0,0 +1,23 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

@@ -33,6 +33,7 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)

template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
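
One detail worth noting about the instantiation lists here and in relu_grad_impl.cu: the template definitions live in .cu files that only nvcc compiles, so ordinary C++ callers link against exactly the symbols instantiated there, and a type with no instantiation line fails at link time rather than compile time. A small plain-C++ sketch of that pattern, with hypothetical names:

// scale_lib.cc (hypothetical): template defined in one translation unit and
// exported for specific element types only, mirroring the CalReLU/CalReLUGrad
// instantiation lists above.
#include <cstdint>
#include <cstdio>

template <typename T>
void Scale(int n, T *data, T factor) {
  for (int i = 0; i < n; ++i) data[i] = data[i] * factor;
}

// Explicit instantiations: only these types get linkable symbols. A caller in
// another file asking for Scale<int16_t> would hit an undefined reference at
// link time.
template void Scale(int, float *, float);
template void Scale(int, int32_t *, int32_t);

int main() {
  int32_t v[3] = {1, 2, 3};
  Scale(3, v, int32_t{2});
  printf("%d %d %d\n", v[0], v[1], v[2]);  // prints: 2 4 6
  return 0;
}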

@@ -22,6 +22,8 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      ActivationGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      ActivationGpuFwdKernel, int32_t)

@@ -26,6 +26,12 @@ MS_REG_GPU_KERNEL_ONE(
  ReluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ActivationGradGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+  ActivationGradGpuKernel, int8_t)

MS_REG_GPU_KERNEL_ONE(
  ReLU6Grad,

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|

@@ -36,7 +37,7 @@ class ActivationGradGpuKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *) override {
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
|

@@ -51,13 +52,18 @@ class ActivationGradGpuKernel : public GpuKernel {
    }
    T *dx = GetDeviceAddress<T>(outputs, 0);

-   const float alpha = 1;
-   const float beta = 0;
-   CHECK_CUDNN_RET_WITH_EXCEPT(
-     kernel_node_,
-     cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
-                             data_descriptor_, y, &beta, data_descriptor_, dx),
-     "cudnnActivationBackward failed");
+   if (mode_ == CUDNN_ACTIVATION_RELU) {
+     const int size = input_size_ / sizeof(T);
+     CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+   } else {
+     const float alpha = 1;
+     const float beta = 0;
+     CHECK_CUDNN_RET_WITH_EXCEPT(
+       kernel_node_,
+       cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                               data_descriptor_, y, &beta, data_descriptor_, dx),
+       "cudnnActivationBackward failed");
+   }

    return true;
  }
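
This dispatch is the heart of the change: when mode_ is CUDNN_ACTIVATION_RELU, Launch bypasses cudnnActivationBackward, which handles floating-point tensors, and calls the hand-written CalReLUGrad, so the integer registrations above become viable. Note the kernel gates on the saved forward output y rather than the original input x; the tiny plain-C++ check below (hypothetical, not part of the commit) spells out why that is equivalent: y = max(x, 0) is positive exactly where x is.

// relu_grad_identity.cc (hypothetical): ReLU'(x) * dy can be computed from
// the saved forward output y alone, which is all the backward kernel is given.
#include <algorithm>
#include <cassert>

int main() {
  const int x[] = {-2, -1, 0, 1, 2};
  const int dy[] = {10, 20, 30, 40, 50};
  for (int i = 0; i < 5; ++i) {
    int y = std::max(x[i], 0);             // saved forward output
    int dx_from_y = y > 0 ? dy[i] : 0;     // what CalReLUGrad computes
    int dx_from_x = x[i] > 0 ? dy[i] : 0;  // textbook ReLU'(x) * dy
    assert(dx_from_y == dx_from_x);
  }
  return 0;
}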

@@ -31,17 +31,14 @@ class NetReluGrad(nn.Cell):
        return self.rekuGrad(dy, x)


-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad():
+def relu_grad_base(dtype):
    x = Tensor(np.array([[[[-1, 1, 1],
                           [1, -1, 1],
-                          [1, 1, -1]]]]).astype(np.float32))
+                          [1, 1, -1]]]]).astype(dtype))
    dy = Tensor(np.array([[[[1, 0, 1],
                            [0, 1, 0],
-                           [1, 1, 1]]]]).astype(np.float32))
-    expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.float32)
+                           [1, 1, 1]]]]).astype(dtype))
+    expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(dtype)
    error = np.ones(shape=[3, 3]) * 1.0e-6

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -49,3 +46,39 @@ def test_relu_grad():
    output = relu_grad(x, dy)
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
+    assert output.asnumpy().dtype == dtype
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_float16():
+    relu_grad_base(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_float32():
+    relu_grad_base(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int8():
+    relu_grad_base(np.int8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int32():
+    relu_grad_base(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_grad_int64():
+    relu_grad_base(np.int64)

@@ -65,6 +65,28 @@ def test_relu_float32():
    assert (output.asnumpy() == expect).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int8():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int8))
+    expect = np.array([[[[0, 1, 10,],
+                         [1, 0, 1,],
+                         [10, 1, 0.]]]]).astype(np.int8)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard