ReLU/ReLUGrad int8/32/64 support

This commit is contained in:
jonwe 2020-12-08 12:36:05 -05:00
parent 8ddb10fd8a
commit d982fec2ff
8 changed files with 145 additions and 15 deletions

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
}
}
template <typename T>
void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
return;
}
template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

View File

@ -33,6 +33,7 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);

View File

@ -22,6 +22,8 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ActivationGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ActivationGpuFwdKernel, int32_t)

View File

@ -26,6 +26,12 @@ MS_REG_GPU_KERNEL_ONE(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ActivationGradGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(
ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ActivationGradGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
ReLU6Grad,

View File

@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
namespace mindspore {
namespace kernel {
@ -36,7 +37,7 @@ class ActivationGradGpuKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) override {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
@ -51,13 +52,18 @@ class ActivationGradGpuKernel : public GpuKernel {
}
T *dx = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"cudnnActivationBackward failed");
if (mode_ == CUDNN_ACTIVATION_RELU) {
const int size = input_size_ / sizeof(T);
CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
data_descriptor_, y, &beta, data_descriptor_, dx),
"cudnnActivationBackward failed");
}
return true;
}

View File

@ -31,17 +31,14 @@ class NetReluGrad(nn.Cell):
return self.rekuGrad(dy, x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad():
def relu_grad_base(dtype):
x = Tensor(np.array([[[[-1, 1, 1],
[1, -1, 1],
[1, 1, -1]]]]).astype(np.float32))
[1, 1, -1]]]]).astype(dtype))
dy = Tensor(np.array([[[[1, 0, 1],
[0, 1, 0],
[1, 1, 1]]]]).astype(np.float32))
expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.float32)
[1, 1, 1]]]]).astype(dtype))
expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.dtype)
error = np.ones(shape=[3, 3]) * 1.0e-6
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -49,3 +46,39 @@ def test_relu_grad():
output = relu_grad(x, dy)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert output.asnumpy().dtype == dtype
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_float16():
relu_grad_base(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_float32():
relu_grad_base(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int8():
relu_grad_base(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int32():
relu_grad_base(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_int64():
relu_grad_base(np.int64)

View File

@ -65,6 +65,28 @@ def test_relu_float32():
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_int8():
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.int8))
expect = np.array([[[[0, 1, 10,],
[1, 0, 1,],
[10, 1, 0.]]]]).astype(np.int8)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard