From 80e86a697a0cb9d7ed507aca032b66131080584d Mon Sep 17 00:00:00 2001
From: markuskunej
Date: Wed, 15 Dec 2021 16:57:35 +0000
Subject: [PATCH] Added ReductionMode enum for loss with reduction gpu kernels.

---
 .../backend/kernel_compiler/common_utils.cc   |  11 --
 .../backend/kernel_compiler/common_utils.h    |   1 -
 .../gpu/cuda_impl/loss_with_reduction_impl.cu | 124 +++++++++---------
 .../cuda_impl/loss_with_reduction_impl.cuh    |  30 +++--
 .../gpu/nn/binary_cross_entropy_gpu_kernel.h  |  10 +-
 .../gpu/nn/binary_cross_entropy_grad_kernel.h |   8 +-
 .../gpu/nn/kl_div_loss_gpu_kernel.h           |  10 +-
 .../gpu/nn/kl_div_loss_grad_kernel.h          |   8 +-
 .../gpu/nn/nll_loss_gpu_kernel.h              |  14 +-
 .../gpu/nn/nll_loss_grad_gpu_kernel.h         |  10 +-
 10 files changed, 111 insertions(+), 115 deletions(-)
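Notes:

The deleted GetReductionInt() mapped the "reduction" attribute to the magic
integers 0/1/2, which every kernel then compared against by hand. The patch
replaces those integers with a scoped enum (ReductionMode) plus a string-keyed
lookup table, so host and device code dispatch on named values. A minimal
standalone sketch of the new lookup pattern (illustrative only; the real
kernels read the attribute through GetAttr, which is elided here):

    #include <cassert>
    #include <map>
    #include <string>

    enum class ReductionMode { kNone, kMean, kSum };

    static std::map<std::string, ReductionMode> kReductionModeMap{
      {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};

    int main() {
      // Named enumerators replace the old 0/1/2 return codes.
      ReductionMode reduction = kReductionModeMap["mean"];
      assert(reduction == ReductionMode::kMean);

      // Caveat of map::operator[]: an unrecognized attribute string inserts a
      // value-initialized entry, i.e. ReductionMode::kNone, instead of failing.
      // A find() with an explicit error path would reject bad strings.
      assert(kReductionModeMap["bogus"] == ReductionMode::kNone);
      return 0;
    }

One behavioural gain worth noting: enum class enumerators do not convert
implicitly to int, so a stray comparison like reduction == 1 now fails to
compile instead of silently passing.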
diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
index d8dbb0e8e1d..b027744b4f5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
@@ -531,17 +531,6 @@ int Sign(float x) {
   return 0;
 }
 
-int GetReductionInt(const std::string &reduction) {
-  if (reduction == "none") {
-    return 0;
-  } else if (reduction == "sum") {
-    return 2;
-  } else {
-    // reduction = 'mean'
-    return 1;
-  }
-}
-
 std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                              const std::vector<AnfNodePtr> &input_list,
                                                                              const std::vector<AnfNodePtr> &output_list) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
index ba7dd5d01ba..2daa0a5ba8a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
@@ -86,7 +86,6 @@ std::string GetProcessor(const AnfNodePtr &anf_node);
 Processor GetProcessor(const string &processor);
 bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
 int Sign(float x);
-int GetReductionInt(const std::string &reduction);
 std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                              const std::vector<AnfNodePtr> &input_list,
                                                                              const std::vector<AnfNodePtr> &output_list);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
index 9e084296b15..cbdd81f114b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
@@ -27,9 +27,9 @@ inline __device__ float maxT(float a, float b) { return fmaxf(a, b); }
 inline __device__ half maxT(half a, half b) { return a > b ? a : b; }
 
 template <typename T>
-__global__ void Copy(T *loss, T *tmp_loss, int reduction, int input_size) {
+__global__ void Copy(T *loss, T *tmp_loss, ReductionMode reduction, int input_size) {
   loss[0] += tmp_loss[0];
-  if (reduction == 1) {
+  if (reduction == ReductionMode::kMean) {
     loss[0] /= castT(loss[0], input_size);
   }
 }
@@ -108,14 +108,14 @@ void Sum(T *array, const int &size, cudaStream_t stream) {
 }
 
 template <typename T, typename S>
-void Reduce(T *tmp_loss, const int &size, S *denom, const int &reduction, T *output, cudaStream_t stream) {
+void Reduce(T *tmp_loss, const int &size, S *denom, const ReductionMode &reduction, T *output, cudaStream_t stream) {
   // sum losses together
   Sum(tmp_loss, size, stream);
-  if (reduction == 1) {
+  if (reduction == ReductionMode::kMean) {
     // mean reduction, divide sum by denominator, store result in output
     Divide<<<1, 1, 0, stream>>>(tmp_loss, denom, output);
-  } else if (reduction == 2) {
+  } else if (reduction == ReductionMode::kSum) {
     // sum reduction, copy sum to output
     CopyEqual<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(tmp_loss, output, size);
   }
 }
@@ -134,10 +134,10 @@ __global__ void InitZero(T *array, int size) {
 }
 
 template <typename T>
-__global__ void KLDivLossKernel(const int input_size, const int reduction, const T *input_x, const T *input_y, T *loss,
-                                T *tmp_loss) {
+__global__ void KLDivLossKernel(const int input_size, const ReductionMode reduction, const T *input_x, const T *input_y,
+                                T *loss, T *tmp_loss) {
   T epsilon = 1e-6;
-  if (reduction == 0) {
+  if (reduction == ReductionMode::kNone) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T denominator = maxT(input_y[i], epsilon);
       T value = input_y[i] * (logT(denominator) - input_x[i]);
@@ -153,12 +153,12 @@ __global__ void KLDivLossKernel(const int input_size, const int reduction, const
 }
 
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
-               cudaStream_t stream) {
+void KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, T *loss,
+               T *tmp_loss, cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
   KLDivLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y, loss,
                                                                       tmp_loss);
-  if (reduction != 0) {
+  if (reduction != ReductionMode::kNone) {
     if (input_size % 2 == 1) {
       AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1);
     }
@@ -173,11 +173,11 @@ void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, co
 }
 
 template <typename T>
-__global__ void KLDivLossGradKernel(const int input_size, const int reduction, const T *input_x, const T *input_y,
-                                    const T *dloss, T *dx, T *dy) {
+__global__ void KLDivLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x,
+                                    const T *input_y, const T *dloss, T *dx, T *dy) {
   T epsilon = 1e-6;
   T one = static_cast<T>(1);
-  if (reduction == 0) {
+  if (reduction == ReductionMode::kNone) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T denominator = maxT(input_y[i], epsilon);
       dx[i] = -input_y[i] * dloss[i];
@@ -185,7 +185,7 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c
     }
   } else {
     T dloss1 = dloss[0];
-    if (reduction == 1) {
+    if (reduction == ReductionMode::kMean) {
       dloss1 = dloss[0] / castT(dloss[0], input_size);
     }
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
@@ -197,29 +197,29 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c
 }
 
 template <typename T>
-void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
-                   T *dx, T *dy, cudaStream_t stream) {
+void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
+                   const T *dloss, T *dx, T *dy, cudaStream_t stream) {
   KLDivLossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y,
                                                                           dloss, dx, dy);
 }
 
 template <typename T>
-__global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x,
+__global__ void BinaryCrossEntropyLossKernel(const int input_size, const ReductionMode reduction, const T *input_x,
                                              const T *input_y, const T *weight, T *loss, T *tmp_loss) {
   T epsilon = 1e-12;
   T one = static_cast<T>(1);
-  if (reduction == 0 && weight != nullptr) {
+  if (reduction == ReductionMode::kNone && weight != nullptr) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T value =
         -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
       loss[i] = value;
     }
-  } else if (reduction == 0 && weight == nullptr) {
+  } else if (reduction == ReductionMode::kNone && weight == nullptr) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T value = -(input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
       loss[i] = value;
     }
-  } else if (reduction != 0 && weight != nullptr) {
+  } else if (reduction != ReductionMode::kNone && weight != nullptr) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T value =
         -weight[i] * (input_y[i] * logT(input_x[i] + epsilon) + (one - input_y[i]) * logT(one - input_x[i] + epsilon));
@@ -234,12 +234,12 @@ __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int red
 }
 
 template <typename T>
-void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
+void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
                             const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
   BinaryCrossEntropyLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
                                                                                    input_y, weight, loss, tmp_loss);
-  if (reduction != 0) {
+  if (reduction != ReductionMode::kNone) {
     if (input_size % 2 == 1) {
       AddTile<<<1, 1, 0, stream>>>(tmp_loss, input_size - 1);
     }
@@ -254,11 +254,11 @@ void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T
 }
 
 template <typename T>
-__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x,
+__global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const ReductionMode reduction, const T *input_x,
                                                  const T *input_y, const T *weight, const T *dloss, T *dx) {
   T epsilon = 1e-12;
   T one = static_cast<T>(1);
-  if (reduction == 0) {
+  if (reduction == ReductionMode::kNone) {
     if (weight != nullptr) {
       for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
         T denominator = maxT(input_x[i] * (one - input_x[i]), epsilon);
@@ -274,7 +274,7 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
     }
   } else {
     T dloss1 = dloss[0];
-    if (reduction == 1) {
+    if (reduction == ReductionMode::kMean) {
       dloss1 = dloss[0] / castT(dloss[0], input_size);
     }
     if (weight != nullptr) {
@@ -294,8 +294,8 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
 }
 
 template <typename T>
-void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                                const T *weight, const T *dloss, T *dx, cudaStream_t stream) {
+void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x,
+                                const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream) {
   BinaryCrossEntropyLossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
                                                                                        input_y, weight, dloss, dx);
 }
@@ -319,10 +319,9 @@ __global__ void NLLLossKernel(const int n, const int c, const T *input, const in
 }
 
 template <typename T, typename S>
-void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-             T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
-  // if reduction != "none"
-  if (reduction != 0) {
+void NLLLoss(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
+             const S *weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream) {
+  if (reduction != ReductionMode::kNone) {
     NLLLossKernel<<<GET_BLOCKS(n), GET_THREADS, 0, stream>>>(n, c, input, target, weight, tmp_target_weight, tmp_loss);
     // sum target weights after populating them
     Sum(tmp_target_weight, n, stream);
@@ -340,12 +339,13 @@ void NLLLoss(const int n, const int c, const int reduction, const T *input, cons
 }
 
 template <typename T, typename S>
-__global__ void NLLLossGradKernel(const int n, const int c, const int reduction, const T *input, const int32_t *target,
-                                  const S *weight, const S *total_weight, const T *dloss, T *dinput) {
+__global__ void NLLLossGradKernel(const int n, const int c, const ReductionMode reduction, const T *input,
+                                  const int32_t *target, const S *weight, const S *total_weight, const T *dloss,
+                                  T *dinput) {
   int input_idx;
   int target_class;
   S tmp_quot;
-  if (reduction == 0) {
+  if (reduction == ReductionMode::kNone) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
       target_class = static_cast<int>(target[i]);
 
@@ -353,7 +353,7 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
 
       MultiplyDevice(-weight[target_class], dloss[i], dinput + input_idx);
     }
-  } else if (reduction == 1) {
+  } else if (reduction == ReductionMode::kMean) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
       target_class = static_cast<int>(target[i]);
 
@@ -362,7 +362,7 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
       tmp_quot = (-weight[target_class]) / *total_weight;
       MultiplyDevice(tmp_quot, dloss[0], dinput + input_idx);
     }
-  } else if (reduction == 2) {
+  } else if (reduction == ReductionMode::kSum) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
       target_class = static_cast<int>(target[i]);
 
@@ -374,8 +374,8 @@ __global__ void NLLLossGradKernel(const int n, const int c, const int reduction,
 }
 
 template <typename T, typename S>
-void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-                 const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) {
+void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
+                 const S *weight, const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream) {
   int input_size = n * c;
   InitZero<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(dinput, input_size);
 
@@ -383,62 +383,62 @@ void NLLLossGrad(const int n, const int c, const int reduction, const T *input,
                                                               dloss, dinput);
 }
 
-template void KLDivLoss(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
-                        float *loss, float *tmp_loss, cudaStream_t stream);
+template void KLDivLoss(const int &input_size, const ReductionMode &reduction, const float *input_x,
+                        const float *input_y, float *loss, float *tmp_loss, cudaStream_t stream);
 
-template void KLDivLossGrad(const int &input_size, const int &reduction, const float *input_x,
+template void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const float *input_x,
                             const float *input_y, const float *dloss, float *dx, float *dy, cudaStream_t stream);
 
-template void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const float *input_x,
+template void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const float *input_x,
                                      const float *input_y, const float *weight, float *loss, float *tmp_loss,
                                      cudaStream_t stream);
 
-template void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const float *input_x,
-                                         const float *input_y, const float *weight, const float *dloss,
-                                         float *dx, cudaStream_t stream);
+template void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction,
+                                         const float *input_x, const float *input_y, const float *weight,
+                                         const float *dloss, float *dx, cudaStream_t stream);
 
-template void NLLLoss(const int n, const int c, const int reduction, const float *input,
+template void NLLLoss(const int n, const int c, const ReductionMode reduction, const float *input,
                       const int32_t *target, const float *weight, float *loss, float *total_weight,
                       float *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
 
-template void NLLLoss(const int n, const int c, const int reduction, const float *input,
+template void NLLLoss(const int n, const int c, const ReductionMode reduction, const float *input,
                       const int32_t *target, const half *weight, float *loss, half *total_weight,
                       float *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
 
-template void NLLLossGrad(const int n, const int c, const int reduction, const float *input,
+template void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const float *input,
                           const int32_t *target, const float *weight, const float *total_weight,
                           const float *dloss, float *dinput, cudaStream_t stream);
 
-template void NLLLossGrad(const int n, const int c, const int reduction, const float *input,
+template void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const float *input,
                           const int32_t *target, const half *weight, const half *total_weight,
                           const float *dloss, float *dinput, cudaStream_t stream);
 
-template void KLDivLoss(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
-                        half *loss, half *tmp_loss, cudaStream_t stream);
+template void KLDivLoss(const int &input_size, const ReductionMode &reduction, const half *input_x,
+                        const half *input_y, half *loss, half *tmp_loss, cudaStream_t stream);
 
-template void KLDivLossGrad(const int &input_size, const int &reduction, const half *input_x, const half *input_y,
-                            const half *dloss, half *dx, half *dy, cudaStream_t stream);
+template void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const half *input_x,
+                            const half *input_y, const half *dloss, half *dx, half *dy, cudaStream_t stream);
 
-template void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const half *input_x,
+template void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const half *input_x,
                                     const half *input_y, const half *weight, half *loss, half *tmp_loss,
                                     cudaStream_t stream);
 
-template void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const half *input_x,
-                                         const half *input_y, const half *weight, const half *dloss, half *dx,
-                                         cudaStream_t stream);
+template void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction,
+                                         const half *input_x, const half *input_y, const half *weight,
+                                         const half *dloss, half *dx, cudaStream_t stream);
 
-template void NLLLoss(const int n, const int c, const int reduction, const half *input,
+template void NLLLoss(const int n, const int c, const ReductionMode reduction, const half *input,
                       const int32_t *target, const half *weight, half *loss, half *total_weight,
                       half *tmp_loss, half *tmp_target_weight, cudaStream_t stream);
 
-template void NLLLoss(const int n, const int c, const int reduction, const half *input,
+template void NLLLoss(const int n, const int c, const ReductionMode reduction, const half *input,
                       const int32_t *target, const float *weight, half *loss, float *total_weight,
                       half *tmp_loss, float *tmp_target_weight, cudaStream_t stream);
 
-template void NLLLossGrad(const int n, const int c, const int reduction, const half *input,
+template void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const half *input,
                           const int32_t *target, const half *weight, const half *total_weight,
                           const half *dloss, half *dinput, cudaStream_t stream);
 
-template void NLLLossGrad(const int n, const int c, const int reduction, const half *input,
+template void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const half *input,
                           const int32_t *target, const float *weight, const float *total_weight,
                           const half *dloss, half *dinput, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
index 135cda68a38..bbb3137e8d7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
@@ -17,23 +17,31 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
 
+#include <map>
+#include <string>
+
+enum class ReductionMode { kNone, kMean, kSum };
+
+static std::map<std::string, ReductionMode> kReductionModeMap{
+  {"none", ReductionMode::kNone}, {"mean", ReductionMode::kMean}, {"sum", ReductionMode::kSum}};
+
 template <typename T>
-void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
+void BinaryCrossEntropyLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
                             const T *weight, T *loss, T *tmp_loss, cudaStream_t stream);
 template <typename T>
-void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                                const T *weight, const T *dloss, T *dx, cudaStream_t stream);
+void BinaryCrossEntropyLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x,
+                                const T *input_y, const T *weight, const T *dloss, T *dx, cudaStream_t stream);
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
-               cudaStream_t stream);
+void KLDivLoss(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y, T *loss,
+               T *tmp_loss, cudaStream_t stream);
 template <typename T>
-void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
-                   T *dx, T *dy, cudaStream_t stream);
+void KLDivLossGrad(const int &input_size, const ReductionMode &reduction, const T *input_x, const T *input_y,
+                   const T *dloss, T *dx, T *dy, cudaStream_t stream);
 template <typename T, typename S>
-void NLLLoss(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-             T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
+void NLLLoss(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
+             const S *weight, T *loss, S *total_weight, T *tmp_loss, S *tmp_target_weight, cudaStream_t stream);
 template <typename T, typename S>
-void NLLLossGrad(const int n, const int c, const int reduction, const T *input, const int32_t *target, const S *weight,
-                 const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream);
+void NLLLossGrad(const int n, const int c, const ReductionMode reduction, const T *input, const int32_t *target,
+                 const S *weight, const S *total_weight, const T *dloss, T *dinput, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
index 110253ec468..deb8531ea70 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
@@ -34,7 +34,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
         is_null_input_(false),
         kernel_name_("BinaryCrossEntropy"),
         input_size_(1),
-        reduction_(1),
+        reduction_(ReductionMode::kMean),
         workspace_size_(1) {}
   ~BinaryCrossEntropyGpuKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -74,9 +74,9 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
+    reduction_ = kReductionModeMap[reduction];
     workspace_size_ = sizeof(T);
-    if (reduction_ != 0) {
+    if (reduction_ != ReductionMode::kNone) {
       workspace_size_ *= input_size_;
     }
     InitSizeLists();
@@ -90,7 +90,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
     if (weight_defined_) {
       input_size_list_.push_back(input_size_ * sizeof(T));
     }
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       output_size_list_.push_back(input_size_ * sizeof(T));
     } else {
       output_size_list_.push_back(sizeof(T));
@@ -103,7 +103,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
   bool is_null_input_;
   std::string kernel_name_;
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
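Aside on the sizing logic above: the kernel headers all repeat one rule that
the enum now makes explicit, "none" keeps one loss element per input element
while "mean"/"sum" reduce to a scalar. A hedged sketch of that rule as a
helper (hypothetical function for illustration, not part of this patch):

    #include <cstddef>

    enum class ReductionMode { kNone, kMean, kSum };

    // Output buffer size as computed in InitSizeLists(): per-element losses
    // for "none", a single scalar for "mean" and "sum".
    inline size_t LossOutputBytes(ReductionMode reduction, size_t input_size, size_t elem_bytes) {
      return reduction == ReductionMode::kNone ? input_size * elem_bytes : elem_bytes;
    }
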
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h
index d51dc875dc1..43e79da6afd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h
@@ -31,7 +31,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
  public:
   BinaryCrossEntropyGradGpuKernel()
       : input_size_(1),
-        reduction_(1),
+        reduction_(ReductionMode::kMean),
         weight_defined_(false),
         is_null_input_(false),
         kernel_name_("BinaryCrossEntropyGrad") {}
@@ -75,7 +75,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
+    reduction_ = kReductionModeMap[reduction];
     InitSizeLists();
     return true;
   }
@@ -84,7 +84,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
     input_size_list_.push_back(input_size_ * sizeof(T));
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       input_size_list_.push_back(input_size_ * sizeof(T));
     } else {
       input_size_list_.push_back(sizeof(T));
@@ -97,7 +97,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   bool weight_defined_;  // true: there are 4 inputs, false: there are 3 inputs(no [weight])
   bool is_null_input_;
   std::string kernel_name_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
index 5e8c49d6773..61c56e89cbd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
@@ -29,7 +29,7 @@ namespace kernel {
 template <typename T>
 class KLDivLossGpuKernel : public GpuKernel {
  public:
-  KLDivLossGpuKernel() : input_size_(1), reduction_(1), is_null_input_(false), workspace_size_(0) {}
+  KLDivLossGpuKernel() : input_size_(1), reduction_(ReductionMode::kMean), is_null_input_(false), workspace_size_(0) {}
   ~KLDivLossGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -61,9 +61,9 @@ class KLDivLossGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
+    reduction_ = kReductionModeMap[reduction];
     workspace_size_ = sizeof(T);
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       workspace_size_ *= input_size_;
     }
     InitSizeLists();
@@ -74,7 +74,7 @@ class KLDivLossGpuKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
     input_size_list_.push_back(input_size_ * sizeof(T));
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       output_size_list_.push_back(input_size_ * sizeof(T));
     } else {
       output_size_list_.push_back(sizeof(T));
@@ -84,7 +84,7 @@ class KLDivLossGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   bool is_null_input_;
   size_t workspace_size_;
   std::vector<size_t> input_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
index 181d25dcd03..f1ff0440ce7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
@@ -29,7 +29,7 @@ namespace kernel {
 template <typename T>
 class KLDivLossGradGpuKernel : public GpuKernel {
  public:
-  KLDivLossGradGpuKernel() : input_size_(1), reduction_(1), is_null_input_(false) {}
+  KLDivLossGradGpuKernel() : input_size_(1), reduction_(ReductionMode::kMean), is_null_input_(false) {}
   ~KLDivLossGradGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -62,7 +62,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
+    reduction_ = kReductionModeMap[reduction];
     InitSizeLists();
     return true;
   }
@@ -73,7 +73,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
     input_size_list_.push_back(input_size_ * sizeof(T));
     output_size_list_.push_back(input_size_ * sizeof(T));
     output_size_list_.push_back(input_size_ * sizeof(T));
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       input_size_list_.push_back(input_size_ * sizeof(T));
     } else {
       input_size_list_.push_back(sizeof(T));
@@ -82,7 +82,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   bool is_null_input_;
   std::vector<size_t> input_size_list_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
index 77f6abc2e68..72d750926f3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
@@ -48,8 +48,8 @@ class NLLLossGpuKernel : public GpuKernel {
     T *loss_device = GetDeviceAddress<T>(outputs, 0);
     S *total_weight_device = GetDeviceAddress<S>(outputs, 1);
 
-    T *tmp_loss_device =
-      reduction_ != 0 ? GetDeviceAddress<T>(workspace, 0) : GetPossiblyNullDeviceAddress<T>(workspace, 0);
+    T *tmp_loss_device = reduction_ != ReductionMode::kNone ? GetDeviceAddress<T>(workspace, 0)
+                                                            : GetPossiblyNullDeviceAddress<T>(workspace, 0);
 
     S *tmp_target_weight_device = GetDeviceAddress<S>(workspace, 1);
 
@@ -76,8 +76,8 @@ class NLLLossGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
-    if ((reduction_ == 2) || (reduction_ == 1)) {
+    reduction_ = kReductionModeMap[reduction];
+    if ((reduction_ == ReductionMode::kSum) || (reduction_ == ReductionMode::kMean)) {
       tmp_loss_size_ = sizeof(T) * n_;
     }
     tmp_target_weight_size_ = n_ * sizeof(S);
@@ -91,7 +91,7 @@ class NLLLossGpuKernel : public GpuKernel {
     n_ = 0;
     c_ = 0;
     is_null_input_ = false;
-    reduction_ = 1;  // default value
+    reduction_ = ReductionMode::kMean;  // default value
     tmp_loss_size_ = 0;
     tmp_target_weight_size_ = 0;  // tmp_target_weight (N,) array
     input_size_list_.clear();
@@ -105,7 +105,7 @@ class NLLLossGpuKernel : public GpuKernel {
     input_size_list_.push_back(n_ * sizeof(int32_t));  // target tensor with shape (N)
     input_size_list_.push_back(c_ * sizeof(S));        // weight tensor with shape (C)
 
-    if (reduction_ == 0) {
+    if (reduction_ == ReductionMode::kNone) {
       output_size_list_.push_back(n_ * sizeof(T));  // loss output of shape (N,)
     } else {
       output_size_list_.push_back(sizeof(T));  // scalar loss output
@@ -117,7 +117,7 @@ class NLLLossGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   size_t tmp_loss_size_;
   size_t tmp_target_weight_size_;
   int n_;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h
index 078204f8a2d..bce21c6f0ab 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h
@@ -73,8 +73,8 @@ class NLLLossGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
-    reduction_ = GetReductionInt(reduction);
-    if (reduction_ == 0) {
+    reduction_ = kReductionModeMap[reduction];
+    if (reduction_ == ReductionMode::kNone) {
       num_dloss_ = n_;
     }
 
@@ -87,8 +87,8 @@ class NLLLossGradGpuKernel : public GpuKernel {
     n_ = 0;
     c_ = 0;
     is_null_input_ = false;
-    reduction_ = 1;  // default value
-    num_dloss_ = 1;  // default size (scalar)
+    reduction_ = ReductionMode::kMean;  // default value
+    num_dloss_ = 1;                     // default size (scalar)
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
@@ -107,7 +107,7 @@ class NLLLossGradGpuKernel : public GpuKernel {
 
  private:
   size_t input_size_;
-  int reduction_;
+  ReductionMode reduction_;
   int n_;
   int c_;
   bool is_null_input_;
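
For completeness, a sketch of how a host-side caller drives one of the updated
entry points declared in loss_with_reduction_impl.cuh. The device buffers
(d_x, d_y, d_loss, d_tmp_loss) and the include path are hypothetical, and
allocation plus error checking are elided:

    #include <cuda_runtime.h>
    #include "loss_with_reduction_impl.cuh"

    void RunKLDivLossMean(const float *d_x, const float *d_y, float *d_loss,
                          float *d_tmp_loss, int input_size, cudaStream_t stream) {
      // The reduction mode is now a typed value, fixed at kernel Init time from
      // the node's "reduction" string attribute, rather than a bare int.
      const ReductionMode reduction = ReductionMode::kMean;
      // d_tmp_loss is the scratch buffer the implementation reduces into;
      // size it per the kernel's workspace list.
      KLDivLoss(input_size, reduction, d_x, d_y, d_loss, d_tmp_loss, stream);
    }

Because the wrappers keep the same parameter order and only retype the
reduction argument, call sites that previously passed GetReductionInt()'s
result migrate by swapping in the enum value; any remaining integer-typed
caller is caught at compile time.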