diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.cc
index b2d8b04243b..2f28fe5bd8f 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.cc
@@ -34,7 +34,6 @@ bool KLDivLossGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
   auto kernel_ptr = std::dynamic_pointer_cast<ops::KLDivLossGrad>(base_operator);
   if (!kernel_ptr) {
     MS_LOG(EXCEPTION) << "cast KLDivLoss ops failed!";
-    return false;
   }
   kernel_name_ = kernel_ptr->name();
   reductionMode_ = kernel_ptr->get_reduction();
@@ -76,21 +75,21 @@ int KLDivLossGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
   std::vector<int64_t> input_grad_shape = inputs[kIndex0]->GetShapeVector();
   if (input_grad_shape.size() >= 1) {
     for (size_t i = 0; i < input_grad_shape.size(); ++i) {
-      input_grad_shape_size_ *= input_grad_shape[i];
+      input_grad_shape_size_ *= LongToSize(input_grad_shape[i]);
     }
   }
 
   std::vector<int64_t> input_x_shape = inputs[kIndex1]->GetShapeVector();
   if (input_x_shape.size() >= 1) {
     for (size_t i = 0; i < input_x_shape.size(); ++i) {
-      input_x_shape_size_ *= input_x_shape[i];
+      input_x_shape_size_ *= LongToSize(input_x_shape[i]);
     }
   }
 
   std::vector<int64_t> input_target_shape = inputs[kIndex2]->GetShapeVector();
   if (input_target_shape.size() >= 1) {
     for (size_t i = 0; i < input_target_shape.size(); ++i) {
-      input_target_shape_size_ *= input_target_shape[i];
+      input_target_shape_size_ *= LongToSize(input_target_shape[i]);
     }
   }
 
@@ -105,7 +104,7 @@ std::vector<KernelAttr> KLDivLossGradCpuKernelMod::GetOpSupport() {
   return support_list;
 }
 
-bool KLDivLossGradCpuKernelMod::CheckParams() {
+bool KLDivLossGradCpuKernelMod::CheckParams() const {
   // for kl div, shape size of input 1 and input 2 must be the same
   if (input_target_shape_size_ != input_x_shape_size_) {
     MS_LOG(ERROR) << kernel_name_ << ": input x shape size = " << input_x_shape_size_
@@ -131,8 +130,7 @@ bool KLDivLossGradCpuKernelMod::CheckParams() {
 }
 
 template <typename T>
-bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                             const std::vector<AddressPtr> &workspace,
+bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                              const std::vector<AddressPtr> &outputs) {
   auto *input_grad = reinterpret_cast<T *>(inputs[kIndex0]->addr);
   auto *input_target = reinterpret_cast<T *>(inputs[kIndex2]->addr);
@@ -142,7 +140,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
   Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_target(input_target, input_target_shape_size_, 1);
   Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> array_y(y, input_x_shape_size_, 1);
 
-  float coefficient = -1.0;
+  double coefficient = -1.0;
   if (reductionMode_ == ops::kMean) {
     coefficient /= input_x_shape_size_;
   } else if (reductionMode_ == ops::kBatchMean) {
@@ -159,7 +157,7 @@ bool KLDivLossGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
 
   auto task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; ++i) {
-      if (static_cast<float>(array_target[i]) <= 0.0) {
+      if (static_cast<double>(array_target[i]) <= 0.0) {
         array_y[i] = static_cast<T>(0);
       } else {
         array_y[i] *= (static_cast<T>(coefficient) * static_cast<T>(bcast));
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.h
index b6af394a12a..79249c1f1d6 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/kl_div_loss_grad_cpu_kernel.h
@@ -48,15 +48,14 @@ class KLDivLossGradCpuKernelMod : public NativeCpuKernelMod {
 
  private:
   template <typename T>
-  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                     const std::vector<AddressPtr> &outputs);
-  bool CheckParams();
+  bool CheckParams() const;
 
   using KLDivLossGradFunc =
     std::function<bool(KLDivLossGradCpuKernelMod *, const std::vector<AddressPtr> &,
                        const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
 
- private:
   static std::vector<std::pair<KernelAttr, KLDivLossGradFunc>> func_list_;
   KLDivLossGradFunc kernel_func_;
   std::string reductionMode_ = "mean";
diff --git a/mindspore/core/ops/grad/kl_div_loss_grad.cc b/mindspore/core/ops/grad/kl_div_loss_grad.cc
index ce08147bb4e..77ffcb33815 100644
--- a/mindspore/core/ops/grad/kl_div_loss_grad.cc
+++ b/mindspore/core/ops/grad/kl_div_loss_grad.cc
@@ -52,13 +52,13 @@ TypePtr KLDivLossGradInferType(const PrimitivePtr &prim, const std::vector<Abstr
   auto input_grad_type = input_args[kInputIndex0]->BuildType();
   auto input_x_type = input_args[kInputIndex1]->BuildType();
   auto input_target_type = input_args[kInputIndex2]->BuildType();
-  CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, op_name);
 
   std::map<std::string, TypePtr> types;
-  types.emplace("grad", input_grad_type);
-  types.emplace("x", input_x_type);
-  types.emplace("target", input_target_type);
-  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
+  (void)types.emplace("grad", input_grad_type);
+  (void)types.emplace("x", input_x_type);
+  (void)types.emplace("target", input_target_type);
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
   return input_x_type;
 }
 