From abdba421e54f8154daa3461ee90c5c4897cbafbe Mon Sep 17 00:00:00 2001 From: markuskunej Date: Thu, 30 Sep 2021 20:22:01 +0000 Subject: [PATCH] added GetReductionInt to common_utils.h and replaced duplicated code in all loss with reduction gpu op kernels (nll loss, kl div loss, and binary cross entropy) --- .../ccsrc/backend/kernel_compiler/common_utils.cc | 11 +++++++++++ .../ccsrc/backend/kernel_compiler/common_utils.h | 1 + .../gpu/nn/binary_cross_entropy_gpu_kernel.h | 7 ++----- .../gpu/nn/binary_cross_entropy_grad_kernel.h | 7 ++----- .../gpu/nn/kl_div_loss_gpu_kernel.h | 7 ++----- .../gpu/nn/kl_div_loss_grad_kernel.h | 7 ++----- .../kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h | 14 +++----------- .../gpu/nn/nll_loss_grad_gpu_kernel.h | 14 ++++---------- 8 files changed, 27 insertions(+), 41 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index 0d46fed362a..20cec768723 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -571,6 +571,17 @@ int Sign(float x) { return 0; } +int GetReductionInt(const std::string &reduction) { + if (reduction == "none") { + return 0; + } else if (reduction == "sum") { + return 2; + } else { + // reduction = 'mean' + return 1; + } +} + std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index) { MS_EXCEPTION_IF_NULL(anf_node); diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index e29e9040007..c94241509aa 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -89,6 +89,7 @@ std::string GetProcessor(const AnfNodePtr &anf_node); Processor GetProcessor(const string &processor); bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); int Sign(float x); +int GetReductionInt(const std::string &reduction); std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); std::vector>> GetInputIndex(const std::vector &node_list, const std::vector &input_list); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h index c8633148e40..bfdd4bcc4cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -67,11 +68,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - if (reduction == "none") { - reduction_ = 0; - } else if (reduction == "sum") { - reduction_ = 2; - } + reduction_ = GetReductionInt(reduction); workspace_size_ = sizeof(T); if (reduction_ != 0) { workspace_size_ *= input_size_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h index c3b95e36cd7..c52835636a9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -69,11 +70,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - if (reduction == "none") { - reduction_ = 0; - } else if (reduction == "sum") { - reduction_ = 2; - } + reduction_ = GetReductionInt(reduction); InitSizeLists(); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h index 3079b9f4767..7bbab716be5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -60,11 +61,7 @@ class KLDivLossGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - if (reduction == "none") { - reduction_ = 0; - } else if (reduction == "sum") { - reduction_ = 2; - } + reduction_ = GetReductionInt(reduction); workspace_size_ = sizeof(T); if (reduction_ == 0) { workspace_size_ *= input_size_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h index 8409364cd3d..62771be64e4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -61,11 +62,7 @@ class KLDivLossGradGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - if (reduction == "none") { - reduction_ = 0; - } else if (reduction == "sum") { - reduction_ = 2; - } + reduction_ = GetReductionInt(reduction); InitSizeLists(); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h index da2edb2178f..0f606b6a932 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -60,19 +61,10 @@ class NLLLossGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - - // if reduction is not 'none', tmp_nll is (N,) size - if (reduction == "none") { - reduction_ = 0; - } else if (reduction == "sum") { - reduction_ = 2; - tmp_loss_size_ = sizeof(T) * n_; - } else { - // reduction = 'mean' - reduction_ = 1; + reduction_ = GetReductionInt(reduction); + if ((reduction_ == 2) || (reduction_ == 1)) { tmp_loss_size_ = sizeof(T) * n_; } - tmp_target_weight_size_ = n_ * sizeof(S); InitSizeLists(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h index 3956db52149..ed988db2f10 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h @@ -22,6 +22,7 @@ #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { @@ -59,16 +60,9 @@ class NLLLossGradGpuKernel : public GpuKernel { input_size_ *= input_shape[i]; } string reduction = GetAttr(kernel_node, "reduction"); - - // if reduction is not 'none', tmp_nll is (N,) size - if (reduction == "none") { - reduction_ = 0; - num_dloss_ = n_; // dloss is a vector - } else if (reduction == "sum") { - reduction_ = 2; - } else { - // reduction = 'mean' - reduction_ = 1; + reduction_ = GetReductionInt(reduction); + if (reduction_ == 0) { + num_dloss_ = n_; } InitSizeLists();