diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu index 9cb6e88189f..9e084296b15 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu @@ -126,6 +126,13 @@ __global__ void LossInitKernel(T *loss) { loss[0] = static_cast(0.); } +template +__global__ void InitZero(T *array, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + array[i] = static_cast(0.); + } +} + template __global__ void KLDivLossKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss) { @@ -332,6 +339,50 @@ void NLLLoss(const int n, const int c, const int reduction, const T *input, cons CopyEqual<<<1, 1, 0, stream>>>(tmp_target_weight, total_weight, 1); } +template +__global__ void NLLLossGradKernel(const int n, const int c, const int reduction, const T *input, const int32_t *target, + const S *weight, const S *total_weight, const T *dloss, T *dinput) { + int input_idx; + int target_class; + S tmp_quot; + if (reduction == 0) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + + input_idx = (i * c) + target_class; + + MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx); + } + } else if (reduction == 1) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + + input_idx = (i * c) + target_class; + + tmp_quot = (-weight[target_class]) / *total_weight; + MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx); + } + } else if (reduction == 2) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + target_class = static_cast(target[i]); + + input_idx = (i * c) + target_class; + + MultiplyDevice(-weight[target_class], dloss[0], dinput + input_idx); + } + } +} + +template +void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight, + const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) { + int input_size = n * c; + InitZero<<>>(dinput, input_size); + + NLLLossGradKernel<<>>(n, c, reduction, input, target, weight, total_weight, + dloss, dinput); +} + template void KLDivLoss(const int &input_size, const int &reduction, const float *input_x, const float *input_y, float *loss, float *tmp_loss, cudaStream_t stream); @@ -354,6 +405,14 @@ template void NLLLoss(const int n, const int c, const int reduction const int32_t *target, const half *weight, float *loss, half *total_weight, float *tmp_loss, half *tmp_target_weight, cudaStream_t stream); +template void NLLLossGrad(const int n, const int c, const int reduction, const float *input, + const int32_t *target, const float *weight, const float *total_weight, + const float *dloss, float *dinput, cudaStream_t stream); + +template void NLLLossGrad(const int n, const int c, const int reduction, const float *input, + const int32_t *target, const half *weight, const half *total_weight, + const float *dloss, float *dinput, cudaStream_t stream); + template void KLDivLoss(const int &input_size, const int &reduction, const half *input_x, const half *input_y, half *loss, half *tmp_loss, cudaStream_t stream); @@ -375,3 +434,11 @@ template void NLLLoss(const int n, const int c, const int reduction, template void NLLLoss(const int n, const int c, const int reduction, const half *input, const int32_t *target, const float *weight, half *loss, float *total_weight, half *tmp_loss, float *tmp_target_weight, cudaStream_t stream); + +template void NLLLossGrad(const int n, const int c, const int reduction, const half *input, + const int32_t *target, const half *weight, const half *total_weight, + const half *dloss, half *dinput, cudaStream_t stream); + +template void NLLLossGrad(const int n, const int c, const int reduction, const half *input, + const int32_t *target, const float *weight, const float *total_weight, + const half *dloss, half *dinput, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh index a1390db56ca..135cda68a38 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh @@ -32,4 +32,8 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x template void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream); +template +void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight, + const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream); + #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.cc new file mode 100644 index 00000000000..4d7b2a2a630 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(NLLLossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + NLLLossGradGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(NLLLossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat32), + NLLLossGradGpuKernel, float, half) +MS_REG_GPU_KERNEL_TWO(NLLLossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16), + NLLLossGradGpuKernel, half, float) +MS_REG_GPU_KERNEL_TWO(NLLLossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + NLLLossGradGpuKernel, half, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h new file mode 100644 index 00000000000..3956db52149 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h @@ -0,0 +1,113 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class NLLLossGradGpuKernel : public GpuKernel { + public: + NLLLossGradGpuKernel() { ResetResource(); } + ~NLLLossGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input_device = GetDeviceAddress(inputs, 0); + T *dloss_device = GetDeviceAddress(inputs, 1); + int32_t *target_device = GetDeviceAddress(inputs, 2); // nll_loss_grad only supports int32 target + S *weight_device = GetDeviceAddress(inputs, 3); + S *total_weight_device = GetDeviceAddress(inputs, 4); + + T *dinput_device = GetDeviceAddress(outputs, 0); + + NLLLossGrad(n_, c_, reduction_, input_device, target_device, weight_device, total_weight_device, dloss_device, + dinput_device, reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + n_ = static_cast(input_shape[0]); + c_ = static_cast(input_shape[1]); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + string reduction = GetAttr(kernel_node, "reduction"); + + // if reduction is not 'none', tmp_nll is (N,) size + if (reduction == "none") { + reduction_ = 0; + num_dloss_ = n_; // dloss is a vector + } else if (reduction == "sum") { + reduction_ = 2; + } else { + // reduction = 'mean' + reduction_ = 1; + } + + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_size_ = 1; + n_ = 0; + c_ = 0; + reduction_ = 1; // default value + num_dloss_ = 1; // default size (scalar) + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); // input tensor with shape (N, C) + input_size_list_.push_back(num_dloss_ * sizeof(T)); // dloss tensor (either scalar or size N) + input_size_list_.push_back(n_ * sizeof(int32_t)); // target tensor with shape (N) + input_size_list_.push_back(c_ * sizeof(S)); // weight tensor with shape (C) + input_size_list_.push_back(sizeof(S)); // total_weight scalar + + output_size_list_.push_back(input_size_ * sizeof(T)); // dinput + } + + private: + size_t input_size_; + int reduction_; + int n_; + int c_; + int num_dloss_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H diff --git a/tests/st/ops/gpu/test_nll_loss.py b/tests/st/ops/gpu/test_nll_loss.py index 8b748d56ddd..e4454a1fd2d 100644 --- a/tests/st/ops/gpu/test_nll_loss.py +++ b/tests/st/ops/gpu/test_nll_loss.py @@ -19,6 +19,7 @@ import pytest import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G from mindspore.ops import operations as P @@ -31,12 +32,23 @@ class Net(nn.Cell): return self.loss(predict, target, weight) +class NLLLossGradNet(nn.Cell): + def __init__(self, reduction): + super(NLLLossGradNet, self).__init__() + self.grad = G.NLLLossGrad(reduction=reduction) + + def construct(self, x, dout_x, target, weight, total_weight): + gout = self.grad(x, dout_x, target, weight, total_weight) + return gout + + def nll_loss_template(nptype_input, nptype_weight, reduction): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") nll_loss_net = Net(reduction) - predict = Tensor(np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input)) + predict = Tensor( + np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input)) target = Tensor(np.array([0, 1]).astype(np.int32)) @@ -67,7 +79,48 @@ def nll_loss_template(nptype_input, nptype_weight, reduction): ertol_weight = 1e-03 np.testing.assert_allclose(loss_np, expected_loss, ertol_loss) - np.testing.assert_allclose(total_weight_np, expected_tot_weight, ertol_weight) + np.testing.assert_allclose( + total_weight_np, expected_tot_weight, ertol_weight) + + +def nll_loss_grad_template(nptype_input, nptype_weight, reduction): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + nll_loss_grad_net = NLLLossGradNet(reduction) + + x = Tensor( + np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input)) + + if reduction == "none": + dloss = Tensor( + np.array([3.24, -2.13]).astype(nptype_input)) + else: + dloss = Tensor(np.array(1.23).astype(nptype_input)) + + target = Tensor(np.array([0, 1]).astype(np.int32)) + weight = Tensor(np.array([0.45, -0.32, 1.21]).astype(nptype_weight)) + + total_weight = Tensor(np.array(0.13).astype(nptype_weight)) + + dx = nll_loss_grad_net(x, dloss, target, weight, total_weight) + + dx_np = dx.asnumpy() + + print(dx) + + if reduction == "none": + dx_expected = np.array([[-1.45799994, 0, 0], [0, -0.681600034, 0]]) + elif reduction == "mean": + dx_expected = np.array([[-4.25769234, 0, 0], [0, 3.02769232, 0]]) + else: + dx_expected = np.array([[-0.553499997, 0, 0], [0, 0.393599987, 0]]) + + if nptype_input == np.float32 and nptype_weight == np.float32: + ertol_loss = 1e-06 + else: + ertol_loss = 1e-02 + + np.testing.assert_allclose(dx_np, dx_expected, ertol_loss) @pytest.mark.level0 @@ -91,6 +144,7 @@ def test_nll_loss_mean_reduction(): nll_loss_template(np.float16, np.float32, "mean") nll_loss_template(np.float16, np.float16, "mean") + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -100,3 +154,36 @@ def test_nll_loss_sum_reduction(): nll_loss_template(np.float32, np.float16, "sum") nll_loss_template(np.float16, np.float32, "sum") nll_loss_template(np.float16, np.float16, "sum") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nll_loss_grad_mean_reduction(): + # Four combinations of fp32 and fp16 inputs and weights + nll_loss_grad_template(np.float32, np.float32, "mean") + nll_loss_grad_template(np.float32, np.float16, "mean") + nll_loss_grad_template(np.float16, np.float32, "mean") + nll_loss_grad_template(np.float16, np.float16, "mean") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nll_loss_grad_sum_reduction(): + # Four combinations of fp32 and fp16 inputs and weights + nll_loss_grad_template(np.float32, np.float32, "sum") + nll_loss_grad_template(np.float32, np.float16, "sum") + nll_loss_grad_template(np.float16, np.float32, "sum") + nll_loss_grad_template(np.float16, np.float16, "sum") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nll_loss_grad_no_reduction(): + # Four combinations of fp32 and fp16 inputs and weights + nll_loss_grad_template(np.float32, np.float32, "none") + nll_loss_grad_template(np.float32, np.float16, "none") + nll_loss_grad_template(np.float16, np.float32, "none") + nll_loss_grad_template(np.float16, np.float16, "none")