diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
index 5e5660be99f..9cb6e88189f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
@@ -313,26 +313,23 @@ __global__ void NLLLossKernel(const int n, const int c, const T *input, const in
 
 template <typename T, typename S>
 void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-             S *tmp_weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
-  CopyEqual<<<GET_BLOCKS(c), GET_THREADS, 0, stream>>>(weight, tmp_weight, c);
-  Sum(tmp_weight, c, stream);
-
-  // copy sum of weight (tmp_weight[0]) to total_weight
-  CopyEqual<<<1, 1, 0, stream>>>(tmp_weight, total_weight, 1);
-
+             T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
   // if reduction != "none"
   if (reduction != 0) {
     NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, tmp_loss);
-    if (reduction == 1) {
-      // prepare denominator for mean reduction
-      Sum(tmp_target_weight, n, stream);
-    }
+    // sum target weights after populating them
+    Sum(tmp_target_weight, n, stream);
     // reduce tmp_loss
     Reduce(tmp_loss, n, tmp_target_weight, reduction, loss, stream);
   } else {
     // no reduction, output directly to loss
     NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, loss);
+    // sum target weights after populating them
+    Sum(tmp_target_weight, n, stream);
   }
+
+  // copy sum of weight (tmp_target_weight[0]) to total_weight
+  CopyEqual<<<1, 1, 0, stream>>>(tmp_target_weight, total_weight, 1);
 }
 
 template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
@@ -350,13 +347,12 @@ template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const int
                                                 float *dx, cudaStream_t stream);
 
 template void NLLLoss<float, float>(const int n, const int c, const int reduction, const float *input,
-                                    const int32_t *target, const float *weight, float *tmp_weight, float *loss,
-                                    float *total_weight, float *tmp_loss, float *tmp_target_weight,
-                                    cudaStream_t stream);
+                                    const int32_t *target, const float *weight, float *loss, float *total_weight,
+                                    float *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
 
 template void NLLLoss<float, half>(const int n, const int c, const int reduction, const float *input,
-                                   const int32_t *target, const half *weight, half *tmp_weight, float *loss,
-                                   half *total_weight, float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
+                                   const int32_t *target, const half *weight, float *loss, half *total_weight,
+                                   float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
 
 template void KLDivLoss<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
                               half *loss, half *tmp_loss, cudaStream_t stream);
@@ -373,9 +369,9 @@ template void BinaryCrossEntropyLossGrad<half>(const int &input_size, const int
                                                cudaStream_t stream);
 
 template void NLLLoss<half, half>(const int n, const int c, const int reduction, const half *input,
-                                  const int32_t *target, const half *weight, half *tmp_weight, half *loss,
-                                  half *total_weight, half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
+                                  const int32_t *target, const half *weight, half *loss, half *total_weight,
+                                  half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
 
 template void NLLLoss<half, float>(const int n, const int c, const int reduction, const half *input,
-                                   const int32_t *target, const float *weight, float *tmp_weight, half *loss,
-                                   float *total_weight, half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
+                                   const int32_t *target, const float *weight, half *loss, float *total_weight,
+                                   half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
index 328ab855f71..a1390db56ca 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
@@ -31,5 +31,5 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x
                    T *dx, T *dy, cudaStream_t stream);
 template <typename T, typename S>
 void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-             S *tmp_weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
+             T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
index 60dc0c94087..f3ab1a2b90e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
@@ -46,10 +46,9 @@ class NLLLossGpuKernel : public GpuKernel {
 
     T *tmp_loss_device = GetDeviceAddress<T>(workspace, 0);
     S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
-    S *tmp_weight_device = GetDeviceAddress<S>(workspace, 2);
 
-    NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, tmp_weight_device, loss_device,
-            total_weight_device, tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
+    NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, loss_device, total_weight_device,
+            tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
@@ -74,7 +73,6 @@ class NLLLossGpuKernel : public GpuKernel {
       tmp_loss_size_ = sizeof(T) * n_;
     }
 
-    tmp_weight_size_ = c_ * sizeof(S);
     tmp_target_weight_size_ = n_ * sizeof(S);
 
     InitSizeLists();
@@ -88,7 +86,6 @@ class NLLLossGpuKernel : public GpuKernel {
     reduction_ = 1;  // default value
     tmp_loss_size_ = 0;
     tmp_target_weight_size_ = 0;  // tmp_target_weight (N,) array
-    tmp_weight_size_ = 0;
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -108,7 +105,6 @@ class NLLLossGpuKernel : public GpuKernel {
     output_size_list_.push_back(sizeof(S));  // total weight
     workspace_size_list_.push_back(tmp_loss_size_);
     workspace_size_list_.push_back(tmp_target_weight_size_);
-    workspace_size_list_.push_back(tmp_weight_size_);
   }
 
  private:
@@ -116,7 +112,6 @@ class NLLLossGpuKernel : public GpuKernel {
   int reduction_;
   size_t tmp_loss_size_;
   size_t tmp_target_weight_size_;
-  size_t tmp_weight_size_;
   int n_;
   int c_;
   std::vector<size_t> input_size_list_;
diff --git a/tests/st/ops/gpu/test_nll_loss.py b/tests/st/ops/gpu/test_nll_loss.py
index 9c2c632ced4..8b748d56ddd 100644
--- a/tests/st/ops/gpu/test_nll_loss.py
+++ b/tests/st/ops/gpu/test_nll_loss.py
@@ -47,7 +47,7 @@ def nll_loss_template(nptype_input, nptype_weight, reduction):
     loss_np = loss.asnumpy()
     total_weight_np = total_weight.asnumpy()
 
-    expected_tot_weight = np.array(1.34000003)
+    expected_tot_weight = np.array(0.129999995)
 
     if reduction == 'none':
         expected_loss = np.array([-0.238499984, -0.108800001])
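
Sketch (not part of the patch): the change above makes the total_weight output the sum of the per-sample target weights held in tmp_target_weight, rather than the sum of the entire class-weight vector, and it is now also written in the "none" reduction branch. The short NumPy sketch below only illustrates that semantic difference; the weight and target values are hypothetical and are not the ones used in tests/st/ops/gpu/test_nll_loss.py.

import numpy as np

def total_weight_old(weight):
    # pre-patch behaviour: sum of every class weight, independent of the targets
    return np.asarray(weight, dtype=np.float32).sum()

def total_weight_new(weight, target):
    # post-patch behaviour: sum of the weights of the classes selected by `target`,
    # i.e. what the tmp_target_weight buffer holds after NLLLossKernel has run
    weight = np.asarray(weight, dtype=np.float32)
    target = np.asarray(target, dtype=np.int32)
    return weight[target].sum()

# hypothetical values: C = 5 class weights, batch of N = 2 targets
weight = [0.3, 0.5, 0.2, 0.4, 0.1]
target = [1, 3]
print(total_weight_old(weight))          # ~1.5, data-independent result
print(total_weight_new(weight, target))  # ~0.9, depends on the sampled targets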