diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cu
new file mode 100644
index 00000000000..19213071150
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cu
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh"
+
+template <typename T>
+__global__ void ApplyGradientDescent(const size_t size, T *var, const T *alpha, const T *delta, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    const T alpha_value = alpha[0];
+    var[pos] -= alpha_value * delta[pos];
+    output[pos] = var[pos];
+  }
+}
+
+template <typename T>
+void CalApplyGradientDescent(const size_t &size, T *var, const T *alpha, const T *delta, T *output,
+                             cudaStream_t cuda_stream) {
+  ApplyGradientDescent<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, var, alpha, delta, output);
+}
+
+template void CalApplyGradientDescent<float>(const size_t &size, float *var, const float *alpha, const float *delta,
+                                             float *output, cudaStream_t cuda_stream);
+template void CalApplyGradientDescent<half>(const size_t &size, half *var, const half *alpha, const half *delta,
+                                            half *output, cudaStream_t cuda_stream);
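The kernel above uses the standard CUDA grid-stride loop: each thread starts at its global index and advances by the total number of launched threads, so a fixed-size grid covers all `size` elements regardless of the tensor shape. A toy Python emulation of that indexing, with made-up grid and block sizes (illustration only, not MindSpore or CUDA API):

# Toy emulation of the grid-stride loop; grid_dim/block_dim are made-up values.
grid_dim, block_dim, size = 4, 8, 100  # 32 threads must cover 100 elements
covered = []
for block_idx in range(grid_dim):
    for thread_idx in range(block_dim):
        pos = block_idx * block_dim + thread_idx  # blockIdx.x * blockDim.x + threadIdx.x
        while pos < size:                         # same loop condition as the kernel
            covered.append(pos)                   # here: var[pos] -= alpha[0] * delta[pos]
            pos += block_dim * grid_dim           # stride = blockDim.x * gridDim.x
assert sorted(covered) == list(range(size))       # every element visited exactly once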
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh
new file mode 100644
index 00000000000..6a96a898a5d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void CalApplyGradientDescent(const size_t &size, T *var, const T *alpha, const T *delta, T *output,
+                             cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_APPLY_GRADIENT_DESCENT_IMPL_CUH_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.cc
new file mode 100644
index 00000000000..ff88a5cbc91
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.cc
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(ApplyGradientDescent,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      ApplyGradientDescentKernel, float)
+MS_REG_GPU_KERNEL_ONE(ApplyGradientDescent,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      ApplyGradientDescentKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h
new file mode 100644
index 00000000000..9116c2755f6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/apply_gradient_descent_gpu_kernel.h
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/apply_gradient_descent_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ApplyGradientDescentKernel : public GpuKernel {
+ public:
+  ApplyGradientDescentKernel() { ResetResource(); }
+  ~ApplyGradientDescentKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *var = GetDeviceAddress<T>(inputs, 0);
+    T *alpha = GetDeviceAddress<T>(inputs, 1);
+    T *delta = GetDeviceAddress<T>(inputs, 2);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    CalApplyGradientDescent(input_size_, var, alpha, delta, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ApplyGradientDescent needs 3 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ApplyGradientDescent has 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = 1;
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    input_size_ = 1;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    input_size_list_.push_back(sizeof(T));
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_APPLY_GRADIENT_DESCENT_GPU_KERNEL_H_
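InitSizeLists() mirrors the op signature: two tensor-sized inputs (var and delta), one scalar alpha, and one tensor-sized output. A sketch of how the bookkeeping works out, assuming a float32 input of shape (2, 5); the variable names here are hypothetical, plain Python:

import numpy as np

input_shape = (2, 5)
elem = np.dtype(np.float32).itemsize            # sizeof(T) == 4
input_size = int(np.prod(input_shape))          # input_size_ = 10
input_size_list = [input_size * elem,           # var:   40 bytes
                   elem,                        # alpha: a single scalar, 4 bytes
                   input_size * elem]           # delta: 40 bytes
output_size_list = [input_size * elem]          # output has var's shape: 40 bytes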
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index 79e914630cc..fd52d1580d4 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -995,7 +995,7 @@ class BCEWithLogitsLoss(_Loss):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 6920692ea01..d50562768e8 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -3832,7 +3832,7 @@ class BCEWithLogitsLoss(PrimitiveWithInfer):
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
@@ -6268,7 +6268,7 @@ class ApplyGradientDescent(PrimitiveWithInfer):
         TypeError: If `alpha` is neither a Number nor a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``GPU``
 
     Examples:
         >>> import numpy as np
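Both docstring hunks only extend the platform list, so the existing examples should now also run with device_target set to "GPU". A minimal sketch of the BCEWithLogitsLoss case, reusing the predict tensor from the docstring context above; the target, weight, and pos_weight values are illustrative assumptions, not taken from this patch:

# Sketch: exercising BCEWithLogitsLoss on GPU; target/weight/pos_weight are assumed values.
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
loss = P.BCEWithLogitsLoss()
predict = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.1, 0.2, 0.9]]).astype(np.float32))   # assumed
weight = Tensor(np.ones((2, 3)).astype(np.float32))       # assumed all-ones weights
pos_weight = Tensor(np.ones((2, 3)).astype(np.float32))   # assumed all-ones pos_weight
output = loss(predict, target, weight, pos_weight)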
diff --git a/tests/st/ops/gpu/test_apply_gradient_descent_op.py b/tests/st/ops/gpu/test_apply_gradient_descent_op.py
new file mode 100644
index 00000000000..6475ee7b59b
--- /dev/null
+++ b/tests/st/ops/gpu/test_apply_gradient_descent_op.py
@@ -0,0 +1,86 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self, var):
+        super(Net, self).__init__()
+        self.var = Parameter(var, name="var")
+        self.apply_gradient_descent = P.ApplyGradientDescent()
+
+    def construct(self, alpha, delta):
+        return self.apply_gradient_descent(self.var, alpha, delta)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_gradient_descent_float32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float32) / 10)
+    net = Net(var)
+    alpha = Tensor(np.array([0.0001]).astype(np.float32))
+    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float32))
+    output = net(alpha, delta)
+    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
+                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float32)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+    np.testing.assert_almost_equal(net.var.asnumpy(), expect)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float32) / 10)
+    net = Net(var)
+    alpha = Tensor(np.array([0.0001]).astype(np.float32))
+    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float32))
+    output = net(alpha, delta)
+    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
+                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float32)
+    np.testing.assert_almost_equal(output.asnumpy(), expect)
+    np.testing.assert_almost_equal(net.var.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_gradient_descent_float16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float16) / 10)
+    net = Net(var)
+    alpha = Tensor(np.array([0.0001]).astype(np.float16))
+    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float16))
+    output = net(alpha, delta)
+    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.29630002, 0.3962],
+                       [0.4961, 0.596, 0.69589996, 0.79580003, 0.8957]], dtype=np.float16)
+    np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=3)
+    np.testing.assert_almost_equal(net.var.asnumpy(), expect, decimal=3)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    var = Tensor(np.arange(10).reshape(2, 5).astype(np.float16) / 10)
+    net = Net(var)
+    alpha = Tensor(np.array([0.0001]).astype(np.float16))
+    delta = Tensor(np.arange(34, 44).reshape(2, 5).astype(np.float16))
+    output = net(alpha, delta)
+    expect = np.array([[-0.0034, 0.0965, 0.1964, 0.2964, 0.396],
+                       [0.496, 0.596, 0.6963, 0.7954, 0.8955]], dtype=np.float16)
+    np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=3)
+    np.testing.assert_almost_equal(net.var.asnumpy(), expect, decimal=3)
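The hard-coded expect arrays in these tests follow directly from the update rule var -= alpha * delta; they can be re-derived with plain NumPy:

# Hand-check of the float32 expectations above (pure NumPy, outside MindSpore).
import numpy as np

var = np.arange(10).reshape(2, 5).astype(np.float32) / 10
alpha = np.float32(0.0001)
delta = np.arange(34, 44).reshape(2, 5).astype(np.float32)
expect = var - alpha * delta
# e.g. expect[0][0] == 0.0 - 0.0001 * 34 == -0.0034, matching the float32 test above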