added nll_loss_grad for gpu

This commit is contained in:
markuskunej 2021-06-08 11:50:17 -04:00
parent 56f6288fa7
commit 2fece8a7c2
5 changed files with 331 additions and 2 deletions

View File

@ -126,6 +126,13 @@ __global__ void LossInitKernel(T *loss) {
loss[0] = static_cast<T>(0.);
}
template <typename T>
__global__ void InitZero(T *array, int size) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
array[i] = static_cast<T>(0.);
}
}
template <typename T>
__global__ void KLDivLossKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, T *loss,
T *tmp_loss) {
@ -332,6 +339,50 @@ void NLLLoss(const int n, const int c, const int reduction, const T *input, cons
CopyEqual<<<1, 1, 0, stream>>>(tmp_target_weight, total_weight, 1);
}
template <typename T, typename S>
__global__ void NLLLossGradKernel(const int n, const int c, const int reduction, const T *input, const int32_t *target,
const S *weight, const S *total_weight, const T *dloss, T *dinput) {
int input_idx;
int target_class;
S tmp_quot;
if (reduction == 0) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
input_idx = (i * c) + target_class;
MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx);
}
} else if (reduction == 1) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
input_idx = (i * c) + target_class;
tmp_quot = (-weight[target_class]) / *total_weight;
MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx);
}
} else if (reduction == 2) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
target_class = static_cast<int>(target[i]);
input_idx = (i * c) + target_class;
MultiplyDevice(-weight[target_class], dloss[0], dinput + input_idx);
}
}
}
template <typename T, typename S>
void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) {
int input_size = n * c;
InitZero<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(dinput, input_size);
NLLLossGradKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, reduction, input, target, weight, total_weight,
dloss, dinput);
}
template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
float *loss, float *tmp_loss, cudaStream_t stream);
@ -354,6 +405,14 @@ template void NLLLoss<float, half>(const int n, const int c, const int reduction
const int32_t *target, const half *weight, float *loss, half *total_weight,
float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
template void NLLLossGrad<float, float>(const int n, const int c, const int reduction, const float *input,
const int32_t *target, const float *weight, const float *total_weight,
const float *dloss, float *dinput, cudaStream_t stream);
template void NLLLossGrad<float, half>(const int n, const int c, const int reduction, const float *input,
const int32_t *target, const half *weight, const half *total_weight,
const float *dloss, float *dinput, cudaStream_t stream);
template void KLDivLoss<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
half *loss, half *tmp_loss, cudaStream_t stream);
@ -375,3 +434,11 @@ template void NLLLoss<half, half>(const int n, const int c, const int reduction,
template void NLLLoss<half, float>(const int n, const int c, const int reduction, const half *input,
const int32_t *target, const float *weight, half *loss, float *total_weight,
half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
template void NLLLossGrad<half, half>(const int n, const int c, const int reduction, const half *input,
const int32_t *target, const half *weight, const half *total_weight,
const half *dloss, half *dinput, cudaStream_t stream);
template void NLLLossGrad<half, float>(const int n, const int c, const int reduction, const half *input,
const int32_t *target, const float *weight, const float *total_weight,
const half *dloss, half *dinput, cudaStream_t stream);

View File

@ -32,4 +32,8 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x
template <typename T, typename S>
void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
template <typename T, typename S>
void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGradGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32),
NLLLossGradGpuKernel, float, half)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGradGpuKernel, half, float)
MS_REG_GPU_KERNEL_TWO(NLLLossGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
NLLLossGradGpuKernel, half, half)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,113 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class NLLLossGradGpuKernel : public GpuKernel {
public:
NLLLossGradGpuKernel() { ResetResource(); }
~NLLLossGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_device = GetDeviceAddress<T>(inputs, 0);
T *dloss_device = GetDeviceAddress<T>(inputs, 1);
int32_t *target_device = GetDeviceAddress<int32_t>(inputs, 2); // nll_loss_grad only supports int32 target
S *weight_device = GetDeviceAddress<S>(inputs, 3);
S *total_weight_device = GetDeviceAddress<S>(inputs, 4);
T *dinput_device = GetDeviceAddress<T>(outputs, 0);
NLLLossGrad(n_, c_, reduction_, input_device, target_device, weight_device, total_weight_device, dloss_device,
dinput_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
n_ = static_cast<int>(input_shape[0]);
c_ = static_cast<int>(input_shape[1]);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
// if reduction is not 'none', tmp_nll is (N,) size
if (reduction == "none") {
reduction_ = 0;
num_dloss_ = n_; // dloss is a vector
} else if (reduction == "sum") {
reduction_ = 2;
} else {
// reduction = 'mean'
reduction_ = 1;
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
n_ = 0;
c_ = 0;
reduction_ = 1; // default value
num_dloss_ = 1; // default size (scalar)
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T)); // input tensor with shape (N, C)
input_size_list_.push_back(num_dloss_ * sizeof(T)); // dloss tensor (either scalar or size N)
input_size_list_.push_back(n_ * sizeof(int32_t)); // target tensor with shape (N)
input_size_list_.push_back(c_ * sizeof(S)); // weight tensor with shape (C)
input_size_list_.push_back(sizeof(S)); // total_weight scalar
output_size_list_.push_back(input_size_ * sizeof(T)); // dinput
}
private:
size_t input_size_;
int reduction_;
int n_;
int c_;
int num_dloss_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_NLL_LOSS_GRAD_GPU_KERNEL_H

View File

@ -19,6 +19,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import operations as P
@ -31,12 +32,23 @@ class Net(nn.Cell):
return self.loss(predict, target, weight)
class NLLLossGradNet(nn.Cell):
def __init__(self, reduction):
super(NLLLossGradNet, self).__init__()
self.grad = G.NLLLossGrad(reduction=reduction)
def construct(self, x, dout_x, target, weight, total_weight):
gout = self.grad(x, dout_x, target, weight, total_weight)
return gout
def nll_loss_template(nptype_input, nptype_weight, reduction):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
nll_loss_net = Net(reduction)
predict = Tensor(np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input))
predict = Tensor(
np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input))
target = Tensor(np.array([0, 1]).astype(np.int32))
@ -67,7 +79,48 @@ def nll_loss_template(nptype_input, nptype_weight, reduction):
ertol_weight = 1e-03
np.testing.assert_allclose(loss_np, expected_loss, ertol_loss)
np.testing.assert_allclose(total_weight_np, expected_tot_weight, ertol_weight)
np.testing.assert_allclose(
total_weight_np, expected_tot_weight, ertol_weight)
def nll_loss_grad_template(nptype_input, nptype_weight, reduction):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
nll_loss_grad_net = NLLLossGradNet(reduction)
x = Tensor(
np.array([[0.53, 0.74, -2.12], [1.29, -0.34, -1.13]]).astype(nptype_input))
if reduction == "none":
dloss = Tensor(
np.array([3.24, -2.13]).astype(nptype_input))
else:
dloss = Tensor(np.array(1.23).astype(nptype_input))
target = Tensor(np.array([0, 1]).astype(np.int32))
weight = Tensor(np.array([0.45, -0.32, 1.21]).astype(nptype_weight))
total_weight = Tensor(np.array(0.13).astype(nptype_weight))
dx = nll_loss_grad_net(x, dloss, target, weight, total_weight)
dx_np = dx.asnumpy()
print(dx)
if reduction == "none":
dx_expected = np.array([[-1.45799994, 0, 0], [0, -0.681600034, 0]])
elif reduction == "mean":
dx_expected = np.array([[-4.25769234, 0, 0], [0, 3.02769232, 0]])
else:
dx_expected = np.array([[-0.553499997, 0, 0], [0, 0.393599987, 0]])
if nptype_input == np.float32 and nptype_weight == np.float32:
ertol_loss = 1e-06
else:
ertol_loss = 1e-02
np.testing.assert_allclose(dx_np, dx_expected, ertol_loss)
@pytest.mark.level0
@ -91,6 +144,7 @@ def test_nll_loss_mean_reduction():
nll_loss_template(np.float16, np.float32, "mean")
nll_loss_template(np.float16, np.float16, "mean")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -100,3 +154,36 @@ def test_nll_loss_sum_reduction():
nll_loss_template(np.float32, np.float16, "sum")
nll_loss_template(np.float16, np.float32, "sum")
nll_loss_template(np.float16, np.float16, "sum")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nll_loss_grad_mean_reduction():
# Four combinations of fp32 and fp16 inputs and weights
nll_loss_grad_template(np.float32, np.float32, "mean")
nll_loss_grad_template(np.float32, np.float16, "mean")
nll_loss_grad_template(np.float16, np.float32, "mean")
nll_loss_grad_template(np.float16, np.float16, "mean")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nll_loss_grad_sum_reduction():
# Four combinations of fp32 and fp16 inputs and weights
nll_loss_grad_template(np.float32, np.float32, "sum")
nll_loss_grad_template(np.float32, np.float16, "sum")
nll_loss_grad_template(np.float16, np.float32, "sum")
nll_loss_grad_template(np.float16, np.float16, "sum")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nll_loss_grad_no_reduction():
# Four combinations of fp32 and fp16 inputs and weights
nll_loss_grad_template(np.float32, np.float32, "none")
nll_loss_grad_template(np.float32, np.float16, "none")
nll_loss_grad_template(np.float16, np.float32, "none")
nll_loss_grad_template(np.float16, np.float16, "none")