diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu
deleted file mode 100644
index 887515b05e4..00000000000
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
-#include "runtime/device/gpu/cuda_common.h"
-
-template <typename T>
-__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
-  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
-    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
-  }
-}
-
-template <typename T>
-void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
-  CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
-  return;
-}
-
-template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
-template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
-template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
-template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
-template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh
deleted file mode 100644
index 1d1fbbde7c3..00000000000
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh
+++ /dev/null
@@ -1,23 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
-
-#include "runtime/device/gpu/cuda_common.h"
-template <typename T>
-void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
index d0c0b5f5264..d7290dc4a4a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
@@ -31,11 +31,14 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
   return;
 }
 
+template void CalReLU(int size, double *input_addr, double *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int16_t *input_addr, int16_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, uint8_t *input_addr, uint8_t *output_addr, cudaStream_t cuda_stream);
 
 template <typename T>
 __global__ void ReluV2Kernel(const size_t num, const T *x, T *y, uint32_t *mask) {
@@ -69,14 +72,26 @@ void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cuda
   ReluGradV2Kernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, dy, mask, dx);
 }
 
+template void ReluV2(const size_t num, const double *x, double *y, uint32_t *mask, cudaStream_t cuda_stream);
 template void ReluV2(const size_t num, const float *x, float *y, uint32_t *mask, cudaStream_t cuda_stream);
 template void ReluV2(const size_t num, const half *x, half *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const int8_t *x, int8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const int16_t *x, int16_t *y, uint32_t *mask, cudaStream_t cuda_stream);
 template void ReluV2(const size_t num, const int32_t *x, int32_t *y, uint32_t *mask, cudaStream_t cuda_stream);
 template void ReluV2(const size_t num, const int64_t *x, int64_t *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluV2(const size_t num, const uint8_t *x, uint8_t *y, uint32_t *mask, cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const double *dy, const uint32_t *mask, double *dx,
+                         cudaStream_t cuda_stream);
 template void ReluGradV2(const size_t num, const float *dy, const uint32_t *mask, float *dx,
                          cudaStream_t cuda_stream);
 template void ReluGradV2(const size_t num, const half *dy, const uint32_t *mask, half *dx, cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const int8_t *dy, const uint32_t *mask, int8_t *dx,
+                         cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const int16_t *dy, const uint32_t *mask, int16_t *dx,
+                         cudaStream_t cuda_stream);
 template void ReluGradV2(const size_t num, const int32_t *dy, const uint32_t *mask, int32_t *dx,
-                        cudaStream_t cuda_stream);
+                         cudaStream_t cuda_stream);
 template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *mask, int64_t *dx,
-                        cudaStream_t cuda_stream);
+                         cudaStream_t cuda_stream);
+template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
+                         cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
index 8648cfac6b3..87c81183d2e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
@@ -46,7 +46,8 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
 static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
   {"kNumberTypeFloat32", CUDNN_DATA_FLOAT}, {"kNumberTypeFloat16", CUDNN_DATA_HALF},
   {"kNumberTypeFloat64", CUDNN_DATA_DOUBLE}, {"kNumberTypeInt32", CUDNN_DATA_INT32},
-  {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8}};
+  {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8},
+  {"kNumberTypeUInt8", CUDNN_DATA_UINT8}};
 // Used by mixprecision, cuda dtype select
 static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
                                                               {"kNumberTypeFloat16", CUDA_R_16F}};
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
index 436da8bdba1..d91c78c2d80 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,15 +18,6 @@
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      ActivationGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      ActivationGpuFwdKernel, half)
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
-                      ActivationGpuFwdKernel, int8_t)
-MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      ActivationGpuFwdKernel, int32_t)
-
 MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
index 78cfbe0f6ea..64749262b49 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
 
 #include <map>
 #include <string>
@@ -44,17 +44,12 @@
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
 
-    if (mode_ == CUDNN_ACTIVATION_RELU) {
-      const int size = input_size_ / sizeof(T);
-      CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      const float alpha = 1;
-      const float beta = 0;
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                  cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
-                                                         input, &beta, data_descriptor_, output),
-                                  "cudnnActivationForward failed");
-    }
+    const float alpha = 1;
+    const float beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
+                                                       &beta, data_descriptor_, output),
+                                "cudnnActivationForward failed");
     return true;
   }
@@ -125,7 +120,7 @@
   void ResetResource() noexcept override {
     cudnn_handle_ = nullptr;
     activation_desc_ = nullptr;
-    mode_ = CUDNN_ACTIVATION_RELU;
+    mode_ = CUDNN_ACTIVATION_SIGMOID;
     data_descriptor_ = nullptr;
     is_null_input_ = false;
     input_size_list_.clear();
@@ -154,11 +149,11 @@
     }
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(workspace_size_);
   }
 
  private:
-  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU},
-                                                             {"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
+  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU6", CUDNN_ACTIVATION_CLIPPED_RELU},
                                                              {"Tanh", CUDNN_ACTIVATION_TANH},
                                                              {"Elu", CUDNN_ACTIVATION_ELU},
                                                              {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}};
@@ -179,4 +174,4 @@
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
index 63c10a55258..e4d781db171 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,10 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  ActivationGradGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   ActivationGradGpuKernel, float)
@@ -26,12 +30,21 @@
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  ActivationGradGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(
   ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   ActivationGradGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+  ActivationGradGpuKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(
   ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
   ActivationGradGpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+  ActivationGradGpuKernel, uint8_t)
 
 MS_REG_GPU_KERNEL_ONE(
   ReLU6Grad,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index c35fe5a70cd..b97f7b93bd4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
 
 #include <map>
 #include <string>
@@ -23,7 +23,6 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
-#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
@@ -52,18 +51,13 @@
     }
     T *dx = GetDeviceAddress<T>(outputs, 0);
 
-    if (mode_ == CUDNN_ACTIVATION_RELU) {
-      const int size = input_size_ / sizeof(T);
-      CalReLUGrad(size, dy, y, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      const float alpha = 1;
-      const float beta = 0;
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
-                                data_descriptor_, y, &beta, data_descriptor_, dx),
-        "cudnnActivationBackward failed");
-    }
+    const float alpha = 1;
+    const float beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                              data_descriptor_, y, &beta, data_descriptor_, dx),
+      "cudnnActivationBackward failed");
     return true;
   }
@@ -179,4 +173,4 @@
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GRAD_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ACTIVATION_GRAD_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.cc
new file mode 100644
index 00000000000..2556df5bc84
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      ReLUGpuFwdKernel, double)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ReLUGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ReLUGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      ReLUGpuFwdKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      ReLUGpuFwdKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      ReLUGpuFwdKernel, int16_t)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), ReLUGpuFwdKernel,
+                      int8_t)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                      ReLUGpuFwdKernel, uint8_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h
new file mode 100644
index 00000000000..ac408d8ff07
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_gpu_kernel.h
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
+
+#include <vector>
+#include <map>
+#include <string>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ReLUGpuFwdKernel : public GpuKernel {
+ public:
+  ReLUGpuFwdKernel() { ResetResource(); }
+  ~ReLUGpuFwdKernel() override {}
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+
+    const int size = input_size_ / sizeof(T);
+    CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 1.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "ReLUGpuFwdKernel input is null.";
+    }
+    size_t size = 1;
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      size *= input_shape[i];
+    }
+    input_size_ = size * sizeof(T);
+
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    is_null_input_ = false;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+    input_size_ = 0;
+    workspace_size_ = 0;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+    workspace_size_list_.push_back(workspace_size_);
+  }
+
+ private:
+  bool is_null_input_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+  size_t workspace_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
index 2739eac1e2d..3e85ceda308 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_grad_v2_gpu_kernel.cc
@@ -18,6 +18,10 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
+  ReluGradV2GpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   ReluGradV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,6 +30,13 @@
   ReluGradV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
   ReluGradV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8),
+  ReluGradV2GpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16),
+  ReluGradV2GpuKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(
   ReluGradV2,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
@@ -34,5 +45,10 @@
   ReluGradV2,
   KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
   ReluGradV2GpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGradV2,
+  KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8),
+  ReluGradV2GpuKernel, uint8_t)
+
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
index 566c27f252c..9fa07bd1ff9 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.cc
@@ -18,6 +18,10 @@
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   ReLUV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
@@ -26,12 +30,20 @@
   ReLUV2,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
   ReluV2GpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, int8_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(
   ReLUV2,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
   ReluV2GpuKernel, int32_t)
 MS_REG_GPU_KERNEL_ONE(
-  ReLUV2,
-  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
+  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
   ReluV2GpuKernel, int64_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReLUV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32),
+  ReluV2GpuKernel, uint8_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
index c6cf25a2f43..a9d8b1a7012 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/relu_v2_gpu_kernel.h
@@ -79,4 +79,4 @@ class ReluV2GpuKernel : public GpuKernel {
 };
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_MASK_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RELU_V2_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_relu_grad_op.py b/tests/st/ops/gpu/test_relu_grad_op.py
deleted file mode 100644
index 43647c01804..00000000000
--- a/tests/st/ops/gpu/test_relu_grad_op.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-import numpy as np
-import pytest
-
-import mindspore.context as context
-import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.ops.operations import _grad_ops as G
-
-
-class NetReluGrad(nn.Cell):
-    def __init__(self):
-        super(NetReluGrad, self).__init__()
-        self.rekuGrad = G.ReluGrad()
-
-    def construct(self, x, dy):
-        return self.rekuGrad(dy, x)
-
-
-def relu_grad_base(dtype):
-    x = Tensor(np.array([[[[-1, 1, 1],
-                           [1, -1, 1],
-                           [1, 1, -1]]]]).astype(dtype))
-    dy = Tensor(np.array([[[[1, 0, 1],
-                            [0, 1, 0],
-                            [1, 1, 1]]]]).astype(dtype))
-    expect = np.array([[[[0, 0, 1], [0, 0, 0], [1, 1, 0]]]]).astype(dtype)
-    error = np.ones(shape=[3, 3]) * 1.0e-6
-
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    relu_grad = NetReluGrad()
-    output = relu_grad(x, dy)
-    diff = output.asnumpy() - expect
-    assert np.all(diff < error)
-    assert output.asnumpy().dtype == dtype
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad_float16():
-    relu_grad_base(np.float16)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad_float32():
-    relu_grad_base(np.float32)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad_int8():
-    relu_grad_base(np.int8)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad_int32():
-    relu_grad_base(np.int32)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_relu_grad_int64():
-    relu_grad_base(np.int64)
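
Note: this patch deletes tests/st/ops/gpu/test_relu_grad_op.py but does not add a successor test in the same diff. The sketch below is not part of the patch; it is a minimal, hypothetical pytest (file name, parametrization, and test values are assumptions of mine) showing how the widened ReluGrad dtype coverage registered above — double, float, half, int64/int32/int16/int8, and uint8 — could still be exercised. It mirrors the structure of the deleted test; note that uint8 cannot hold negative inputs, so zeros stand in for the masked-out positions.

```python
# Hypothetical replacement test, e.g. tests/st/ops/gpu/test_relu_grad_op.py.
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class NetReluGrad(nn.Cell):
    def __init__(self):
        super(NetReluGrad, self).__init__()
        self.relu_grad = G.ReluGrad()

    def construct(self, x, dy):
        # ReluGrad takes (dy, y) and passes dy through where y > 0.
        return self.relu_grad(dy, x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64,
                                   np.int8, np.int16, np.int32, np.int64, np.uint8])
def test_relu_grad(dtype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # Zeros (rather than negatives) mark the masked-out positions so the
    # same fixture is valid for uint8, which has no negative values.
    x = Tensor(np.array([[0, 1, 2], [3, 0, 5]]).astype(dtype))
    dy = Tensor(np.ones((2, 3)).astype(dtype))
    expect = np.array([[0, 1, 1], [1, 0, 1]]).astype(dtype)
    output = NetReluGrad()(x, dy)
    assert np.allclose(output.asnumpy(), expect)
    assert output.asnumpy().dtype == dtype
```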