diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h index fa1efafa671..78dc29941e5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class CastGpuKernel : public GpuKernel { public: - CastGpuKernel() : input_size_(1), output_size_(1) {} + CastGpuKernel() { ResetResource(); } ~CastGpuKernel() = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -42,6 +42,7 @@ class CastGpuKernel : public GpuKernel { Cast(input_size_, input_addr, output_addr, reinterpret_cast(stream_ptr)); return true; } + bool Init(const CNodePtr &kernel_node) override { auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); @@ -62,6 +63,14 @@ class CastGpuKernel : public GpuKernel { return true; } + void ResetResource() noexcept override { + input_size_ = 1; + output_size_ = 1; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + protected: void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu new file mode 100644 index 00000000000..6f3a3ad6347 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) { + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output_addr[pos] = input_addr[pos] > static_cast(0) ? input_addr[pos] : static_cast(0); + } +} + +template +void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { + CalReLUKernel<<>>(size, input_addr, output_addr); + return; +} + +template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); +template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); +template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream); +template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh new file mode 100644 index 00000000000..19e10224794 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h index 67f4a9b9a91..371f27437db 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h @@ -45,6 +45,7 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F; // Used by mixprecision, cudnn dtype select static std::map kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF}, + {"kNumberTypeInt64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32}}; // Used by mixprecision, cuda dtype select static std::map kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F}, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc index 27827753266..627a71de8ca 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc @@ -22,6 +22,10 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut ActivationGpuFwdKernel, float) MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ActivationGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ActivationGpuFwdKernel, int32_t) +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ActivationGpuFwdKernel, int64_t) MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ActivationGpuFwdKernel, float) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h index f6934d0416e..8779ed0a664 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h @@ -23,6 +23,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh" namespace mindspore { namespace kernel { @@ -36,18 +37,23 @@ class ActivationGpuFwdKernel : public GpuKernel { const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *) override { + const std::vector &outputs, void *stream_ptr) override { if (is_null_input_) { return true; } T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 0); - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, - &beta, data_descriptor_, output), - "cudnnActivationForward failed"); + if (mode_ == CUDNN_ACTIVATION_RELU) { + const int size = input_size_ / sizeof(T); + CalReLU(size, input, output, reinterpret_cast(stream_ptr)); + } else { + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, + input, &beta, data_descriptor_, output), + "cudnnActivationForward failed"); + } return true; } diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 11a470e1914..d5cfcb89045 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -291,7 +291,7 @@ class Softsign(PrimitiveWithInfer): return input_x -class ReLU(PrimitiveWithInfer): +class ReLU(PrimitiveWithCheck): r""" Computes ReLU (Rectified Linear Unit) of input tensors element-wise. @@ -320,12 +320,11 @@ class ReLU(PrimitiveWithInfer): """Initialize ReLU""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, input_x): - return input_x + def check_shape(self, input_x): + pass - def infer_dtype(self, input_x): + def check_dtype(self, input_x): validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name) - return input_x class ReLU6(PrimitiveWithInfer): diff --git a/tests/st/ops/gpu/test_cast_op.py b/tests/st/ops/gpu/test_cast_op.py index 4cec61dc0b9..c58299de666 100644 --- a/tests/st/ops/gpu/test_cast_op.py +++ b/tests/st/ops/gpu/test_cast_op.py @@ -21,6 +21,7 @@ import mindspore.context as context from mindspore.common.tensor import Tensor from mindspore.nn import Cell from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner class Net(Cell): @@ -36,6 +37,22 @@ class Net(Cell): return output +class NetDynamic(Cell): + def __init__(self, type0, type1): + super(NetDynamic, self).__init__() + self.conv = inner.GpuConvertToDynamicShape() + self.Cast = P.Cast() + self.type0 = type0 + self.type1 = type1 + + def construct(self, x0, x1): + x0_conv = self.conv(x0) + x1_conv = self.conv(x1) + output = (self.Cast(x0_conv, self.type0), + self.Cast(x1_conv, self.type1)) + return output + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -563,3 +580,20 @@ def test_cast30(): assert type0 == 'uint16' type1 = output[1].asnumpy().dtype assert type1 == 'uint32' + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cast31(): + x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32)) + t0 = mstype.uint16 + x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32)) + t1 = mstype.uint32 + + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + net = NetDynamic(t0, t1) + output = net(x0, x1) + type0 = output[0].asnumpy().dtype + assert type0 == 'uint16' + type1 = output[1].asnumpy().dtype + assert type1 == 'uint32' diff --git a/tests/st/ops/gpu/test_relu_op.py b/tests/st/ops/gpu/test_relu_op.py index 3cec34ef5d7..03443c0e65e 100644 --- a/tests/st/ops/gpu/test_relu_op.py +++ b/tests/st/ops/gpu/test_relu_op.py @@ -20,6 +20,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner class NetRelu(nn.Cell): @@ -31,10 +32,21 @@ class NetRelu(nn.Cell): return self.relu(x) +class NetReluDynamic(nn.Cell): + def __init__(self): + super(NetReluDynamic, self).__init__() + self.conv = inner.GpuConvertToDynamicShape() + self.relu = P.ReLU() + + def construct(self, x): + x_conv = self.conv(x) + return self.relu(x_conv) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_relu(): +def test_relu_float32(): x = Tensor(np.array([[[[-1, 1, 10], [1, -1, 1], [10, 1, -1]]]]).astype(np.float32)) @@ -51,3 +63,65 @@ def test_relu(): relu = NetRelu() output = relu(x) assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_int32(): + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.int32)) + expect = np.array([[[[0, 1, 10,], + [1, 0, 1,], + [10, 1, 0.]]]]).astype(np.int32) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + relu = NetRelu() + output = relu(x) + assert (output.asnumpy() == expect).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + relu = NetRelu() + output = relu(x) + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_int64(): + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.int64)) + expect = np.array([[[[0, 1, 10,], + [1, 0, 1,], + [10, 1, 0.]]]]).astype(np.int64) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + relu = NetRelu() + output = relu(x) + print(output.asnumpy(), expect) + assert (output.asnumpy() == expect).all() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + relu = NetRelu() + output = relu(x) + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_int64_dynamic_shape(): + x = Tensor(np.array([[[[-1, 1, 10], + [1, -1, 1], + [10, 1, -1]]]]).astype(np.int64)) + expect = np.array([[[[0, 1, 10,], + [1, 0, 1,], + [10, 1, 0.]]]]).astype(np.int64) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + relu_dynamic = NetReluDynamic() + output = relu_dynamic(x) + assert (output.asnumpy() == expect).all()