!18098 Change NLL_Loss total_weight output for gpu

Merge pull request !18098 from markuskunej/nll_loss_total_weight_fix
This commit is contained in:
i-robot 2021-06-10 11:03:23 +08:00 committed by Gitee
commit 63b91904ec
4 changed files with 20 additions and 29 deletions

View File

@ -313,26 +313,23 @@ __global__ void NLLLossKernel(const int n, const int c, const T *input, const in
template <typename T, typename S>
void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
S *tmp_weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
CopyEqual<<<GET_BLOCKS(c), GET_THREADS, 0, stream>>>(weight, tmp_weight, c);
Sum(tmp_weight, c, stream);
// copy sum of weight (tmp_weight[0]) to total_weight
CopyEqual<<<1, 1, 0, stream>>>(tmp_weight, total_weight, 1);
T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
// if reduction != "none"
if (reduction != 0) {
NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, tmp_loss);
if (reduction == 1) {
// prepare denominator for mean reduction
Sum(tmp_target_weight, n, stream);
}
// sum target weights after populating them
Sum(tmp_target_weight, n, stream);
// reduce tmp_loss
Reduce(tmp_loss, n, tmp_target_weight, reduction, loss, stream);
} else {
// no reduction, output directly to loss
NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, loss);
// sum target weights after populatin them
Sum(tmp_target_weight, n, stream);
}
// copy sum of weight (tmp_target_weight[0]) to total_weight
CopyEqual<<<1, 1, 0, stream>>>(tmp_target_weight, total_weight, 1);
}
template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
@ -350,13 +347,12 @@ template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const int
float *dx, cudaStream_t stream);
template void NLLLoss<float, float>(const int n, const int c, const int reduction, const float *input,
const int32_t *target, const float *weight, float *tmp_weight, float *loss,
float *total_weight, float *tmp_loss, float *tmp_target_weight,
cudaStream_t stream);
const int32_t *target, const float *weight, float *loss, float *total_weight,
float *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
template void NLLLoss<float, half>(const int n, const int c, const int reduction, const float *input,
const int32_t *target, const half *weight, half *tmp_weight, float *loss,
half *total_weight, float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
const int32_t *target, const half *weight, float *loss, half *total_weight,
float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
template void KLDivLoss<half>(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
half *loss, half *tmp_loss, cudaStream_t stream);
@ -373,9 +369,9 @@ template void BinaryCrossEntropyLossGrad<half>(const int &input_size, const int
cudaStream_t stream);
template void NLLLoss<half, half>(const int n, const int c, const int reduction, const half *input,
const int32_t *target, const half *weight, half *tmp_weight, half *loss,
half *total_weight, half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
const int32_t *target, const half *weight, half *loss, half *total_weight,
half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
template void NLLLoss<half, float>(const int n, const int c, const int reduction, const half *input,
const int32_t *target, const float *weight, float *tmp_weight, half *loss,
float *total_weight, half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
const int32_t *target, const float *weight, half *loss, float *total_weight,
half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);

View File

@ -31,5 +31,5 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x
T *dx, T *dy, cudaStream_t stream);
template <typename T, typename S>
void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
S *tmp_weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH

View File

@ -46,10 +46,9 @@ class NLLLossGpuKernel : public GpuKernel {
T *tmp_loss_device = GetDeviceAddress<T>(workspace, 0);
S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
S *tmp_weight_device = GetDeviceAddress<S>(workspace, 2);
NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, tmp_weight_device, loss_device,
total_weight_device, tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
NLLLoss(n_, c_, reduction_, input_device, target_device, weight_device, loss_device, total_weight_device,
tmp_loss_device, tmp_target_weight_device, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -74,7 +73,6 @@ class NLLLossGpuKernel : public GpuKernel {
tmp_loss_size_ = sizeof(T) * n_;
}
tmp_weight_size_ = c_ * sizeof(S);
tmp_target_weight_size_ = n_ * sizeof(S);
InitSizeLists();
@ -88,7 +86,6 @@ class NLLLossGpuKernel : public GpuKernel {
reduction_ = 1; // default value
tmp_loss_size_ = 0;
tmp_target_weight_size_ = 0; // tmp_target_weight (N,) array
tmp_weight_size_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
@ -108,7 +105,6 @@ class NLLLossGpuKernel : public GpuKernel {
output_size_list_.push_back(sizeof(S)); // total weight
workspace_size_list_.push_back(tmp_loss_size_);
workspace_size_list_.push_back(tmp_target_weight_size_);
workspace_size_list_.push_back(tmp_weight_size_);
}
private:
@ -116,7 +112,6 @@ class NLLLossGpuKernel : public GpuKernel {
int reduction_;
size_t tmp_loss_size_;
size_t tmp_target_weight_size_;
size_t tmp_weight_size_;
int n_;
int c_;
std::vector<size_t> input_size_list_;

View File

@ -47,7 +47,7 @@ def nll_loss_template(nptype_input, nptype_weight, reduction):
loss_np = loss.asnumpy()
total_weight_np = total_weight.asnumpy()
expected_tot_weight = np.array(1.34000003)
expected_tot_weight = np.array(0.129999995)
if reduction == 'none':
expected_loss = np.array([-0.238499984, -0.108800001])