diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
new file mode 100755
index 00000000000..32eb1684cf8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "unary_op_grad_impl.cuh"
+
+// Computes output[i] = dout[i] / (2 * input[i]) for every i in [0, count).
+// Elements are converted to float for the arithmetic and cast back to T on store.
+// The loop strides by the total number of launched threads, so any 1-D launch
+// configuration covers all `count` elements. No check for input[i] == 0 is made.
+template <typename T>
+__global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  // Widen before multiplying: the built-in index variables are 32-bit.
+  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
+  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < count; i += stride) {
+    const float input_f = static_cast<float>(input[i]);
+    const float dout_f = static_cast<float>(dout[i]);
+    // Float literal keeps the division in single precision.
+    output[i] = static_cast<T>(dout_f / (2.0f * input_f));
+  }
+}
+
+// Computes output[i] = -0.5 * input[i]^3 * dout[i] for every i in [0, count).
+// Same float-intermediate and striding scheme as SqrtGradKernel.
+template <typename T>
+__global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
+  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < count; i += stride) {
+    const float input_f = static_cast<float>(input[i]);
+    const float dout_f = static_cast<float>(dout[i]);
+    const float cube = input_f * input_f * input_f;
+    output[i] = static_cast<T>(-0.5f * cube * dout_f);
+  }
+}
+
+// Host entry point: enqueues SqrtGradKernel on cuda_stream and returns without waiting.
+// NOTE(review): GET_BLOCKS / GET_THREADS are assumed to be the launch-geometry
+// macros provided through unary_op_grad_impl.cuh (cuda_common.h) -- confirm.
+template <typename T>
+void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+}
+
+// Host entry point: enqueues RsqrtGradKernel on cuda_stream and returns without waiting.
+template <typename T>
+void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+}
+
+// Explicit instantiations for the element types used by the GPU kernel registrations.
+template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                              cudaStream_t cuda_stream);
+template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                               cudaStream_t cuda_stream);
+template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                             cudaStream_t cuda_stream);
+template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
new file mode 100755
index 00000000000..61256ac73a0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+// Enqueues output[i] = dout[i] / (2 * input[i]) for i in [0, count) on cuda_stream.
+// Explicitly instantiated for float and half in unary_op_grad_impl.cu.
+template <typename T>
+void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+
+// Enqueues output[i] = -0.5 * input[i]^3 * dout[i] for i in [0, count) on cuda_stream.
+// Explicitly instantiated for float and half in unary_op_grad_impl.cu.
+template <typename T>
+void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..43c5334c2ca
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SqrtGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  SqrtGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  RsqrtGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  RsqrtGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
new file mode 100644
index 00000000000..e78676fd01d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
@@ -0,0 +1,157 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
+
+#include <cuda_runtime_api.h>
+#include <vector>
+#include <string>
+#include <map>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+// Gradient operations this kernel can dispatch to; the last value is the constructor's initial state.
+enum UnaryGradOptype { UNARY_OP_SQRT_GRAD = 0, UNARY_OP_RSQRT_GRAD, UNARY_OP_GRAD_INVALID_TYPE = 255 };
+// Maps the node name read in Init() to the operation dispatched in Launch().
+static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
+                                                                           {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}};
+
+// GPU kernel wrapper for element-wise unary gradient ops: two inputs of equal size, one output.
+// T is the element type handed to the CUDA implementation.
+template <typename T>
+class UnaryGradOpGpuKernel : public GpuKernel {
+ public:
+  UnaryGradOpGpuKernel()
+      : unary_grad_op_type_(UNARY_OP_GRAD_INVALID_TYPE),
+        input_size_(sizeof(T)),
+        dx_size_(sizeof(T)),
+        output_size_(sizeof(T)),
+        is_null_input_(false) {}
+  ~UnaryGradOpGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  // Enqueues the selected gradient kernel on stream_ptr; the element count is taken from inputs[0].
+  // Returns true without launching when Init() flagged a null input shape: the size lists then
+  // hold placeholder sizes, so a launch sized from inputs[0] could overrun the other buffers.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    if (is_null_input_) {
+      return true;
+    }
+    T *input_x_addr = GetDeviceAddress<T>(inputs, 0);
+    T *input_dx_addr = GetDeviceAddress<T>(inputs, 1);
+    T *output_y_addr = GetDeviceAddress<T>(outputs, 0);
+    const size_t count = inputs[0]->size / sizeof(T);
+    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
+
+    switch (unary_grad_op_type_) {
+      case UNARY_OP_SQRT_GRAD: {
+        SqrtGrad(input_x_addr, input_dx_addr, output_y_addr, count, cuda_stream);
+        break;
+      }
+      case UNARY_OP_RSQRT_GRAD: {
+        RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, count, cuda_stream);
+        break;
+      }
+      default: {
+        MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported.";
+      }
+    }
+    return true;
+  }
+
+  // Resolves the operation from the node name, checks the input/output counts and records the
+  // byte size of each buffer. Returns false when the node does not have exactly 2 inputs and
+  // 1 output, or when the two inputs differ in size.
+  bool Init(const CNodePtr &kernel_node) override {
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto iter = kUnaryGradOpTypeMap.find(kernel_name);
+    if (iter == kUnaryGradOpTypeMap.end()) {
+      MS_LOG(EXCEPTION) << "Unary grad operation " << kernel_name << " is not supported.";
+    } else {
+      unary_grad_op_type_ = iter->second;
+    }
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but unary grad op needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but unary grad op needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "UnaryGradOpGpuKernel input 0 is null";
+      InitSizeLists();
+      return true;
+    }
+    for (const auto &dim : input_shape) {
+      input_size_ *= dim;
+    }
+    auto dx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    is_null_input_ = CHECK_NULL_INPUT(dx_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "UnaryGradOpGpuKernel input 1 is null";
+      InitSizeLists();
+      return true;
+    }
+    for (const auto &dim : dx_shape) {
+      dx_size_ *= dim;
+    }
+    if (input_size_ != dx_size_) {
+      // Launch() sizes the kernel from input 0 alone, so a shorter second input or output
+      // buffer would be accessed out of bounds; reject the node instead of continuing.
+      MS_LOG(ERROR) << "UnaryGradOpGpuKernel inputs should be same, but got " << input_size_ << " and " << dx_size_;
+      return false;
+    }
+    output_size_ = input_size_;
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  // Publishes the byte sizes of both inputs and the output; the workspace list is left empty.
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(dx_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  UnaryGradOptype unary_grad_op_type_;
+  size_t input_size_;   // bytes; starts at sizeof(T) and is multiplied by each dimension in Init()
+  size_t dx_size_;      // bytes, same scheme for the second input
+  size_t output_size_;  // bytes; set equal to input_size_ once both inputs are validated
+  bool is_null_input_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
diff --git a/tests/st/ops/gpu/test_rsqrt_grad_op.py b/tests/st/ops/gpu/test_rsqrt_grad_op.py
new file mode 100644
index 00000000000..c84e7a92e76
--- /dev/null
+++ b/tests/st/ops/gpu/test_rsqrt_grad_op.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetRsqrtGrad(nn.Cell):
+    """Cell that passes its two inputs straight to the RsqrtGrad primitive."""
+    def __init__(self):
+        super(NetRsqrtGrad, self).__init__()
+        self.rsqrt_grad = G.RsqrtGrad()
+
+    def construct(self, x, dx):
+        return self.rsqrt_grad(x, dx)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_rsqrt_grad():
+    """Compare GPU RsqrtGrad with precomputed -0.5 * x**3 * dx, scaling tolerance with magnitude."""
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [5.9, 6.1, 6],
+                           [10, 1, -1]]]]).astype(np.float32))
+    dx = Tensor(np.array([[[[1, 1, 1],
+                            [2, 2, 2],
+                            [3, 3, 3]]]]).astype(np.float32))
+    expect = np.array([[[[0.5, -0.5, -500,],
+                         [-205.37901, -226.98099, -216],
+                         [-1500, -1.5, 1.5,]]]]).astype(np.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    rsqrt_grad = NetRsqrtGrad()
+    output = rsqrt_grad(x, dx)
+    np.testing.assert_allclose(output.asnumpy(), expect, rtol=1.0e-6, atol=1.0e-6)
diff --git a/tests/st/ops/gpu/test_sqrt_grad_op.py b/tests/st/ops/gpu/test_sqrt_grad_op.py
new file mode 100644
index 00000000000..df9b2893cca
--- /dev/null
+++ b/tests/st/ops/gpu/test_sqrt_grad_op.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetSqrtGrad(nn.Cell):
+    """Cell that passes its two inputs straight to the SqrtGrad primitive."""
+    def __init__(self):
+        super(NetSqrtGrad, self).__init__()
+        self.sqrt_grad = G.SqrtGrad()
+
+    def construct(self, x, dx):
+        return self.sqrt_grad(x, dx)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sqrt_grad():
+    """GPU SqrtGrad must match precomputed dx / (2 * x) values within 1e-6 absolute error."""
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [5.9, 6.1, 6],
+                           [10, 1, -1]]]]).astype(np.float32))
+    dx = Tensor(np.array([[[[1, 1, 1],
+                            [2, 2, 2],
+                            [3, 3, 3]]]]).astype(np.float32))
+    expect = np.array([[[[-0.5, 0.5, 0.05,],
+                         [0.16949153, 0.16393442, 0.16666667,],
+                         [0.15, 1.5, -1.5,]]]]).astype(np.float32)
+    tolerance = np.full((3, 3), 1.0e-6)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    output = NetSqrtGrad()(x, dx)
+    assert np.all(np.abs(output.asnumpy() - expect) < tolerance)